Почему в этом коде numba имеет худшую оптимизацию, чем Cython?

Я пытаюсь оптимизировать код с помощью Numba. Проблема в том, что простая оптимизация Cython (просто указание типов данных) в шесть раз быстрее, чем использование autojit, поэтому я не знаю, делаю ли я что-то не так.

Функция для оптимизации:

from numba import autojit

@autojit(nopython=True)
def get_energy(system, i,j,m): 
  #system is an array, (i,j) some indices and m the size of the array
  up=i-1;  down=i+1;  left=j-1;  right=j+1
  if up<0: total=system[m,j]
  else: total=system[up,j]
  if down>m: total+=system[0,j]
  else: total+=system[down,j]
  if left<0: total+=system[i,m]
  else: total+=system[i,left]
  if right>m: total+=system[i,0]
  else: total+=system[i,right]
  return 2*system[i,j]*total

Простой прогон будет примерно таким:

import numpy as np
x=np.random.rand(50,50)
get_energy(x, 3, 5, 50)

Я понял, что numba хорош в циклах, но может не очень хорошо оптимизировать другие вещи. В любом случае, я ожидаю, что производительность, аналогичная Cython, будет медленнее при доступе к массивам или при условных выражениях?

Файл.pyx в Cython:

import numpy as np
cimport cython
cimport numpy as np

def get_energy(np.ndarray[np.float64_t, ndim=2] system, int i,int j,unsigned int m): 
  cdef int up
  cdef int down
  cdef int left
  cdef int right
  cdef np.float64_t total
  up=i-1;  down=i+1;  left=j-1;  right=j+1
  if up<0: total=system[m,j]
  else: total=system[up,j]
  if down>m: total+=system[0,j]
  else: total+=system[down,j]
  if left<0: total+=system[i,m]
  else: total+=system[i,left]
  if right>m: total+=system[i,0]
  else: total+=system[i,right]
  return 2*system[i,j]*total

Пожалуйста, прокомментируйте, если мне нужно дать дополнительную информацию.

0 ответов

Другие вопросы по тегам