Как ускорить эту функцию Python с Numba?

Я пытаюсь ускорить эту функцию Python:

def twoFreq_orig(z, source_z, num, den, matrix, e):
    Z1, Z2 = np.meshgrid(source_z, np.conj(z))
    Z1 **= num
    Z2 **= den - 1
    M = (e ** ((num + den - 2) / 2.0)) * Z1 * Z2
    return np.sum(matrix * M, 1)

где z а также source_z являются np.ndarray (1d, dtype=np.complex128), num а также den являются np.ndarray (2d, dtype=np.float64), matrix это np.ndarray (2d, dtype=np.complex128) а также e это np.float64,

У меня нет большого опыта работы с Numba, но после прочтения некоторых уроков я придумал эту реализацию:

@nb.jit(nb.f8[:](nb.c16[:], nb.c16[:], nb.f8[:, :], nb.f8[:, :], nb.c16[:, :], nb.f8))
def twoFreq(z, source_z, num, den, matrix, e):
    N1, N2 = len(z), len(source_z)
    out = np.zeros(N1)
    for r in xrange(N1):
        tmp = 0
        for c in xrange(N2):
            n, d = num[r, c], den[r, c] - 1
            z1 = source_z[c] ** n
            z2 = z[r] ** d
            tmp += matrix[r, c] * e ** ((n + d - 1) / 2.0) * z1 * z2
        out[r] = tmp
    return out

К сожалению, вместо ускорения реализация Numba в несколько раз медленнее, чем оригинал. Я не могу понять, как правильно использовать Numba. Любой гуру Нумбы, который может мне помочь?

1 ответ

Решение

На самом деле, я не думаю, что вы можете многое сделать для ускорения функции numba, не разбираясь в свойствах ваших массивов (есть ли некоторые математические приемы, чтобы ускорить выполнение некоторых вычислений).

Но я заметил одну ошибку: вы не связали свой массив, например, в версии numba, и я отредактировал некоторые строки, чтобы сделать его более упорядоченным (некоторые из которых могут быть только вкусом). Я включил комментарии по соответствующим местам:

@nb.njit
def twoFreq(z, source_z, num, den, matrix, e):
    #Replace z with conjugate of z (otherwise the result is wrong!)
    z = np.conj(z)
    # Size instead of len() don't know if it actually makes a difference but it's cleaner
    N1, N2 = z.size, source_z.size
    # Must be zeros_like otherwise you create a float array where you want a complex one
    out = np.zeros_like(z)
    # I'm using python 3 so you need to replace this by xrange later
    for r in range(N1):
        for c in range(N2):
            n, d = num[r, c], den[r, c] - 1
            z1 = source_z[c] ** n
            z2 = z[r] ** d
            # Multiply with 0.5 instead of dividing by 2
            # Work on the out array directly instead of a tmp variable
            out[r] += matrix[r, c] * e ** ((n + d - 1) * 0.5) * z1 * z2
    return out

def twoFreq_orig(z, source_z, num, den, matrix, e):
    Z1, Z2 = np.meshgrid(source_z, np.conj(z))
    Z1 **= num
    Z2 **= den - 1
    M = (e ** ((num + den - 2) / 2.0)) * Z1 * Z2
    return np.sum(matrix * M, 1)


numb = 1000
z = np.random.uniform(0,1,numb) + 1j*np.random.uniform(0,1,numb)
source_z = np.random.uniform(0,10,numb) + 1j*np.random.uniform(0,1,numb)
num = np.random.uniform(0,1,(numb,numb))
den = np.random.uniform(0,1,(numb,numb))
matrix = np.random.uniform(0,1,(numb,numb)) + 1j*np.random.uniform(0,1,(numb, numb))
e = 5.5

# This failed for your initial version:
np.testing.assert_array_almost_equal(twoFreq(z, source_z, num, den, matrix, e),
                                     twoFreq_orig(z, source_z, num, den, matrix, e))

И время выполнения на моем компьютере было:

%timeit twoFreq(z, source_z, num, den, matrix, e)

1 цикл, лучшее из 3: 246 мс на цикл

%timeit twoFreq_orig(z, source_z, num, den, matrix, e)

1 цикл, лучшее из 3: 344 мс на цикл

Это примерно на 30% быстрее, чем ваш тупой раствор. Но я думаю, что решение можно было бы сделать немного быстрее с умным использованием вещания. Но, тем не менее, большая часть ускорения, которое я получил, была от пропуска подписи: обратите внимание, что вы, вероятно, используете C-смежные массивы, но вы задали произвольный порядок (так что numba может быть немного медленнее в зависимости от архитектуры компьютера). Вероятно, определяя c16[::-1] вы получите ту же скорость, но обычно просто позволяете numba определять тип, вероятно, он будет настолько быстрым, насколько это возможно. Исключение: вам нужны разные входы точности для каждой переменной (например, вы хотите z быть complex128 а также complex64)

Вы получите невероятное ускорение, когда у вашего numy-решения закончится память (потому что ваше numy-решение векторизовано, ему потребуется гораздо больше оперативной памяти!) numb = 5000 версия numba была примерно в 3 раза быстрее, чем numpy.


РЕДАКТИРОВАТЬ:

С умной трансляцией я имею ввиду

np.conj(z[:,None]**(den-1)) * source_z[None, :]**(num)

равно

z1, z2 = np.meshgrid(source_z, np.conj(z))
z1**(num) * z2**(den-1)

но с первым вариантом у вас есть только силовая операция на numb элементы, тогда как у вас есть (numb, numb) массив, поэтому вы выполняете гораздо больше "мощных" операций, чем необходимо (хотя я думаю, что для небольших массивов результат, вероятно, в основном кэшируется и не очень дорогой)

Версия для NumPy без mgrid (который дает тот же результат) выглядит так:

def twoFreq_orig2(z, source_z, num, den, matrix, e):
    z1z2 = source_z[None,:]**(num) * np.conj(z)[:, None]**(den-1)
    M = (e ** ((num + den - 2) / 2.0)) * z1z2
    return np.sum(matrix * M, 1)
Другие вопросы по тегам