JAX: время jit функции становится суперлинейной с доступом к памяти функцией
Вот простой пример, который численно интегрирует произведение двух гауссовских PDF-файлов. Один из гауссианов фиксирован, среднее всегда равно 0. Среднее значение другого гауссиана изменяется:
import time
import jax.numpy as np
from jax import jit
from jax.scipy.stats.norm import pdf
# set up evaluation points for numerical integration
integr_resolution = 6400
lower_bound = -100
upper_bound = 100
integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
proba = pdf(integr_grid)
integration_weight = (upper_bound - lower_bound) / integr_resolution
# integrate with new mean
def integrate(mu_new):
x_new = integr_grid - mu_new
proba_new = pdf(x_new)
total_proba = sum(proba * proba_new * integration_weight)
return total_proba
print('starting jit')
start = time.perf_counter()
integrate = jit(integrate)
integrate(1)
stop = time.perf_counter()
print('took: ', stop - start)
Функция кажется простой, но совершенно не масштабируется. Следующий список содержит пары (значение inter_resolution, время, необходимое для выполнения кода):
- 100 | 0,107 с
- 200 | 0,23 с
- 400 | 0,537 с
- 800 | 1,52 с
- 1600 | 5,2 с
- 3200 | 19 с
- 6400 | 134с
Для справки, функция unjitted, примененная к integr_resolution=6400
занимает 0,02 с.
Я подумал, что это может быть связано с тем, что функция обращается к глобальной переменной. Но перемещение кода для настройки точек интегрирования внутри функции не оказывает заметного влияния на время. Для выполнения следующего кода требуется 5,36 с. Это соответствует записи в таблице с 1600, которая ранее занимала 5,2 секунды:
# integrate with new mean
def integrate(mu_new):
# set up evaluation points for numerical integration
integr_resolution = 1600
lower_bound = -100
upper_bound = 100
integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
proba = pdf(integr_grid)
integration_weight = (upper_bound - lower_bound) / integr_resolution
x_new = integr_grid - mu_new
proba_new = pdf(x_new)
total_proba = sum(proba * proba_new * integration_weight)
return total_proba
Что здесь происходит?
1 ответ
Я также ответил на это на https://github.com/google/jax/issues/1776, но добавил ответ и здесь.
Это потому, что код использует sum
где он должен использовать np.sum
.
sum
- это встроенный Python, который извлекает каждый элемент последовательности и суммирует их один за другим, используя +
оператор. В результате создается большая развернутая цепочка добавлений, компиляция которой XLA занимает много времени.
Если вы используете np.sum
, то JAX создает единственный оператор сокращения XLA, который компилируется намного быстрее.
И чтобы показать, как я это понял: я использовал jax.make_jaxpr
, который выводит внутреннее представление трассировки JAX функции. Здесь показано:
In [3]: import jax
In [4]: jax.make_jaxpr(integrate)(1)
Out[4]:
{ lambda b c ; ; a.
let d = convert_element_type[ new_dtype=float32
old_dtype=int32 ] a
e = sub c d
f = sub e 0.0
g = pow f 2.0
h = div g 1.0
i = add 1.8378770351409912 h
j = neg i
k = div j 2.0
l = exp k
m = mul b l
n = mul m 2.0
o = slice[ start_indices=(0,)
limit_indices=(1,)
strides=(1,)
operand_shape=(100,) ] n
p = reshape[ new_sizes=()
dimensions=None
old_sizes=(1,) ] o
q = add p 0.0
r = slice[ start_indices=(1,)
limit_indices=(2,)
strides=(1,)
operand_shape=(100,) ] n
s = reshape[ new_sizes=()
dimensions=None
old_sizes=(1,) ] r
t = add q s
u = slice[ start_indices=(2,)
limit_indices=(3,)
strides=(1,)
operand_shape=(100,) ] n
v = reshape[ new_sizes=()
dimensions=None
old_sizes=(1,) ] u
w = add t v
x = slice[ start_indices=(3,)
limit_indices=(4,)
strides=(1,)
operand_shape=(100,) ] n
y = reshape[ new_sizes=()
dimensions=None
old_sizes=(1,) ] x
z = add w y
... similarly ...
и тогда становится очевидно, почему это происходит медленно: программа очень большая.
Сравните np.sum
версия:
In [5]: def integrate(mu_new):
...: x_new = integr_grid - mu_new
...:
...: proba_new = pdf(x_new)
...: total_proba = np.sum(proba * proba_new * integration_weight)
...:
...: return total_proba
...:
In [6]: jax.make_jaxpr(integrate)(1)
Out[6]:
{ lambda b c ; ; a.
let d = convert_element_type[ new_dtype=float32
old_dtype=int32 ] a
e = sub c d
f = sub e 0.0
g = pow f 2.0
h = div g 1.0
i = add 1.8378770351409912 h
j = neg i
k = div j 2.0
l = exp k
m = mul b l
n = mul m 2.0
o = reduce_sum[ axes=(0,)
input_shape=(100,) ] n
in [o] }
Надеюсь, это поможет!