numba медленнее для numpy.bitwise_and для логических массивов

Я пытаюсь Numba в этом фрагменте кода

from numba import jit
import numpy as np
from time import time
db  = np.array(np.random.randint(2, size=(400e3, 4)), dtype=bool)
out = np.zeros((int(400e3), 1))

@jit()
def check_mask(db, out, mask=[1, 0, 1]):
    for idx, line in enumerate(db):
        target, vector = line[0], line[1:]
        if (mask == np.bitwise_and(mask, vector)).all():
            if target == 1:
                out[idx] = 1
    return out

st = time()
res = check_mask(db, out, [1, 0, 1])
print 'with jit: {:.4} sec'.format(time() - st)

С декоратором numba @jit() этот код работает медленнее!

  • без джита: 3,16 сек
  • с джитом: 3,81 сек

просто чтобы лучше понять цель этого кода:

db = np.array([           # out value for mask = [1, 0, 1]
    # target,  vector     #
      [1,      1, 0, 1],  # 1
      [0,      1, 1, 1],  # 0 (fit to mask but target == 0)
      [0,      0, 1, 0],  # 0
      [1,      1, 0, 1],  # 1
      [0,      1, 1, 0],  # 0
      [1,      0, 0, 0],  # 0
      ])

3 ответа

Решение

Numba имеет два режима компиляции для jit: режим nopython и режим объекта. Режим Nopython (по умолчанию) поддерживает только ограниченный набор функций Python и Numpy, обратитесь к документации для вашей версии. Если функция jited содержит неподдерживаемый код, Numba вынуждена переключаться в режим объекта, который намного, намного медленнее.

Я не уверен, что режим objcet должен дать ускорение по сравнению с чистым Python, но вы все равно всегда захотите использовать режим nopython. Чтобы убедиться, что используется режим nopython, укажите nopython=True и придерживаться очень простого кода (практическое правило: выписать все циклы и использовать только скаляры и массивы Numpy):

@jit(nopython=True)
def check_mask_2(db, out, mask=np.array([1, 0, 1])):
    for idx in range(db.shape[0]):
        if db[idx,0] != 1:
            continue
        check = 1
        for j in range(db.shape[1]):
            if mask[j] and not db[idx,j+1]:
                check = 0
                break
        out[idx] = check
    return out

Явное написание внутреннего цикла также имеет то преимущество, что мы можем выйти из него, как только условие не выполнится.

Тайминги:

%time _ = check_mask(db, out, np.array([1, 0, 1]))
# Wall time: 1.91 s
%time _ = check_mask_2(db, out, np.array([1, 0, 1]))
# Wall time: 310 ms  # slow because of compilation
%time _ = check_mask_2(db, out, np.array([1, 0, 1]))
# Wall time: 3 ms

Кстати, функция также легко векторизована с помощью Numpy, что дает приличную скорость:

def check_mask_vectorized(db, mask=[1, 0, 1]):
    check = (db[:,1:] == mask).all(axis=1)
    out = (db[:,0] == 1) & check
    return out

%time _ = check_mask_vectorized(db, [1, 0, 1])
# Wall time: 14 ms

Кроме того, вы можете попробовать Pythran(отказ от ответственности: я разработчик Pythran).

С одной аннотацией компилируется следующий код

#pythran export check_mask(bool[][], bool[])

import numpy as np
def check_mask(db, out, mask=[1, 0, 1]):
    for idx, line in enumerate(db):
        target, vector = line[0], line[1:]
        if (mask == np.bitwise_and(mask, vector)).all():
            if target == 1:
                out[idx] = 1
    return out

с призывом к pythran check_call.py,

И согласно timeitполученный модуль работает довольно быстро:

python -m timeit -s 'n=1e4; import numpy as np; db  = np.array(np.random.randint(2, size=(n, 4)), dtype=bool); out = np.zeros(int(n), dtype=bool); from eq import check_mask' 'check_mask(db, out)'

говорит мне, что версия CPython работает в 136ms в то время как Pythran-скомпилированная версия работает в 450us,

Я бы порекомендовал удалить пустой вызов array_equal из внутреннего цикла. Numba не обязательно достаточно умна, чтобы превратить это в часть встроенного C; и если он не сможет заменить этот вызов, доминирующая стоимость вашей функции останется сопоставимой, что объясняет ваш результат.

В то время как numba может рассуждать о значительном числе numpy-конструкций, это только код в стиле C, работающий с numpy-массивами, который можно рассчитывать на ускорение.

Другие вопросы по тегам