Как ускорить эту функцию Python с Numba?
Я пытаюсь ускорить эту функцию Python:
def twoFreq_orig(z, source_z, num, den, matrix, e):
Z1, Z2 = np.meshgrid(source_z, np.conj(z))
Z1 **= num
Z2 **= den - 1
M = (e ** ((num + den - 2) / 2.0)) * Z1 * Z2
return np.sum(matrix * M, 1)
где z
а также source_z
являются np.ndarray
(1d, dtype=np.complex128
), num
а также den
являются np.ndarray
(2d, dtype=np.float64
), matrix
это np.ndarray
(2d, dtype=np.complex128
) а также e
это np.float64
,
У меня нет большого опыта работы с Numba, но после прочтения некоторых уроков я придумал эту реализацию:
@nb.jit(nb.f8[:](nb.c16[:], nb.c16[:], nb.f8[:, :], nb.f8[:, :], nb.c16[:, :], nb.f8))
def twoFreq(z, source_z, num, den, matrix, e):
N1, N2 = len(z), len(source_z)
out = np.zeros(N1)
for r in xrange(N1):
tmp = 0
for c in xrange(N2):
n, d = num[r, c], den[r, c] - 1
z1 = source_z[c] ** n
z2 = z[r] ** d
tmp += matrix[r, c] * e ** ((n + d - 1) / 2.0) * z1 * z2
out[r] = tmp
return out
К сожалению, вместо ускорения реализация Numba в несколько раз медленнее, чем оригинал. Я не могу понять, как правильно использовать Numba. Любой гуру Нумбы, который может мне помочь?
1 ответ
На самом деле, я не думаю, что вы можете многое сделать для ускорения функции numba, не разбираясь в свойствах ваших массивов (есть ли некоторые математические приемы, чтобы ускорить выполнение некоторых вычислений).
Но я заметил одну ошибку: вы не связали свой массив, например, в версии numba, и я отредактировал некоторые строки, чтобы сделать его более упорядоченным (некоторые из которых могут быть только вкусом). Я включил комментарии по соответствующим местам:
@nb.njit
def twoFreq(z, source_z, num, den, matrix, e):
#Replace z with conjugate of z (otherwise the result is wrong!)
z = np.conj(z)
# Size instead of len() don't know if it actually makes a difference but it's cleaner
N1, N2 = z.size, source_z.size
# Must be zeros_like otherwise you create a float array where you want a complex one
out = np.zeros_like(z)
# I'm using python 3 so you need to replace this by xrange later
for r in range(N1):
for c in range(N2):
n, d = num[r, c], den[r, c] - 1
z1 = source_z[c] ** n
z2 = z[r] ** d
# Multiply with 0.5 instead of dividing by 2
# Work on the out array directly instead of a tmp variable
out[r] += matrix[r, c] * e ** ((n + d - 1) * 0.5) * z1 * z2
return out
def twoFreq_orig(z, source_z, num, den, matrix, e):
Z1, Z2 = np.meshgrid(source_z, np.conj(z))
Z1 **= num
Z2 **= den - 1
M = (e ** ((num + den - 2) / 2.0)) * Z1 * Z2
return np.sum(matrix * M, 1)
numb = 1000
z = np.random.uniform(0,1,numb) + 1j*np.random.uniform(0,1,numb)
source_z = np.random.uniform(0,10,numb) + 1j*np.random.uniform(0,1,numb)
num = np.random.uniform(0,1,(numb,numb))
den = np.random.uniform(0,1,(numb,numb))
matrix = np.random.uniform(0,1,(numb,numb)) + 1j*np.random.uniform(0,1,(numb, numb))
e = 5.5
# This failed for your initial version:
np.testing.assert_array_almost_equal(twoFreq(z, source_z, num, den, matrix, e),
twoFreq_orig(z, source_z, num, den, matrix, e))
И время выполнения на моем компьютере было:
%timeit twoFreq(z, source_z, num, den, matrix, e)
1 цикл, лучшее из 3: 246 мс на цикл
%timeit twoFreq_orig(z, source_z, num, den, matrix, e)
1 цикл, лучшее из 3: 344 мс на цикл
Это примерно на 30% быстрее, чем ваш тупой раствор. Но я думаю, что решение можно было бы сделать немного быстрее с умным использованием вещания. Но, тем не менее, большая часть ускорения, которое я получил, была от пропуска подписи: обратите внимание, что вы, вероятно, используете C-смежные массивы, но вы задали произвольный порядок (так что numba может быть немного медленнее в зависимости от архитектуры компьютера). Вероятно, определяя c16[::-1]
вы получите ту же скорость, но обычно просто позволяете numba определять тип, вероятно, он будет настолько быстрым, насколько это возможно. Исключение: вам нужны разные входы точности для каждой переменной (например, вы хотите z
быть complex128
а также complex64
)
Вы получите невероятное ускорение, когда у вашего numy-решения закончится память (потому что ваше numy-решение векторизовано, ему потребуется гораздо больше оперативной памяти!) numb = 5000
версия numba была примерно в 3 раза быстрее, чем numpy.
РЕДАКТИРОВАТЬ:
С умной трансляцией я имею ввиду
np.conj(z[:,None]**(den-1)) * source_z[None, :]**(num)
равно
z1, z2 = np.meshgrid(source_z, np.conj(z))
z1**(num) * z2**(den-1)
но с первым вариантом у вас есть только силовая операция на numb
элементы, тогда как у вас есть (numb, numb)
массив, поэтому вы выполняете гораздо больше "мощных" операций, чем необходимо (хотя я думаю, что для небольших массивов результат, вероятно, в основном кэшируется и не очень дорогой)
Версия для NumPy без mgrid
(который дает тот же результат) выглядит так:
def twoFreq_orig2(z, source_z, num, den, matrix, e):
z1z2 = source_z[None,:]**(num) * np.conj(z)[:, None]**(den-1)
M = (e ** ((num + den - 2) / 2.0)) * z1z2
return np.sum(matrix * M, 1)