vmap над списком в jax
Используя jax, я пытаюсь вычислить градиенты для каждого образца, обработать их, а затем привести их в нормальную форму, чтобы вычислить нормальное обновление параметров. Мой рабочий код выглядит так
differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)
# some code
gradients_summed_over_samples = []
for layer in gradients:
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
gradients_summed_over_samples.append((dw, db))
где gradients
имеет форму list(tuple(DeviceArray(...), DeviceArray(...)), ...)
.
Теперь я попытался переписать цикл как vmap (не уверен, принесет ли он в итоге ускорение)
def sum_samples(layer):
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
vmap(sum_samples)(gradients)
но sum_samples
вызывается только один раз, а не для каждого элемента в списке.
Проблема в списке или я еще что-то не так понимаю?
1 ответ
jax.vmap
будет отображаться только на входы массива jax, а не на входы, которые являются списками массивов или кортежей. Кроме того, функции vmapped не могут изменять входные данные на месте; функции должны возвращать значение, и это возвращаемое значение будет складываться с другими возвращаемыми значениями для построения вывода
Например, вы можете изменить определенную функцию и использовать ее следующим образом:
import jax.numpy as np
from jax import random
def sum_samples(layer):
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
return np.array([dw, db])
key = random.PRNGKey(1701)
data = random.uniform(key, (10, 2, 20))
result = vmap(sum_samples)(data)
print(result.shape)
# (10, 2)
Боковое примечание: если вы используете этот подход, приведенную выше функцию vmapped можно более кратко выразить как:
def sum_samples(layer):
return layer.sum(1)