Сравнение двух подходов возведения в степень элементов матрицы
У меня есть два подхода к возведению в степень матрицы в
jnp = jax.numpy
. Простой:
jnp.exp(-X/reg)
И с некоторыми дополнительными действиями:
def exp_reg(X, reg):
K = jnp.empty_like(X)
K = jnp.divide(X, -reg)
return jnp.exp(K)
Однако когда я их протестировал:
%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()
Второй подход оказался лучше, несмотря на внешне некоторые дополнительные накладные расходы. Я провел
%timeit
с матрицей размером 2000 х 2000:
7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Почему это может быть так?
1 ответ
Разница здесь в порядке действий.
В
jnp.exp(-X/reg)
, вы отрицаете каждую запись
X
, а затем разделив каждую запись результата на
reg
. Это два прохода по массиву
X
.
в
exp_reg
вы отрицаете
reg
(что предположительно является скалярным значением?), а затем разделив
X
по результату. Это один проход
X
.
Если
X
большой, я бы ожидал, что первый подход будет немного медленнее, чем второй, из-за нескольких проходов
X
.
К счастью, поскольку вы используете JAX, вы можете
jit
скомпилируйте свой код, и в этом случае XLA обычно может оптимизировать аналогичные порядки операций, подобные этим. Действительно, для ваших двух функций компиляция устраняет несоответствие:
from jax import jit
import jax.numpy as jnp
import numpy as np
def exp_reg1(X, reg):
return jnp.exp(-X/reg)
def exp_reg2(X, reg):
K = jnp.divide(X, -reg)
return jnp.exp(K)
X = jnp.array(np.random.rand(1000, 1000))
reg = 2.0
%timeit exp_reg1(X, reg)
# 100 loops, best of 3: 3.17 ms per loop
%timeit exp_reg2(X, reg)
# 100 loops, best of 3: 2.2 ms per loop
# Trigger compilation
jit(exp_reg1)(X, reg)
jit(exp_reg2)(X, reg)
%timeit jit(exp_reg1)(X, reg)
# 1000 loops, best of 3: 1.92 ms per loop
%timeit jit(exp_reg2)(X, reg)
# 100 loops, best of 3: 1.84 ms per loop
(примечание: нет причин предварительно выделять пустой массив
K
перед присвоением результата операции одноименной переменной).