vmap над списком в jax

Используя jax, я пытаюсь вычислить градиенты для каждого образца, обработать их, а затем привести их в нормальную форму, чтобы вычислить нормальное обновление параметров. Мой рабочий код выглядит так

differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)

# some code

gradients_summed_over_samples = []
    for layer in gradients:
        (dw, db) = layer
        (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
        gradients_summed_over_samples.append((dw, db))

где gradients имеет форму list(tuple(DeviceArray(...), DeviceArray(...)), ...).

Теперь я попытался переписать цикл как vmap (не уверен, принесет ли он в итоге ускорение)

def sum_samples(layer):
    (dw, db) = layer
    (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))

vmap(sum_samples)(gradients)

но sum_samples вызывается только один раз, а не для каждого элемента в списке.

Проблема в списке или я еще что-то не так понимаю?

1 ответ

jax.vmapбудет отображаться только на входы массива jax, а не на входы, которые являются списками массивов или кортежей. Кроме того, функции vmapped не могут изменять входные данные на месте; функции должны возвращать значение, и это возвращаемое значение будет складываться с другими возвращаемыми значениями для построения вывода

Например, вы можете изменить определенную функцию и использовать ее следующим образом:

import jax.numpy as np
from jax import random

def sum_samples(layer):
    (dw, db) = layer
    (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
    return np.array([dw, db])

key = random.PRNGKey(1701)
data = random.uniform(key, (10, 2, 20))

result = vmap(sum_samples)(data)
print(result.shape)
# (10, 2)

Боковое примечание: если вы используете этот подход, приведенную выше функцию vmapped можно более кратко выразить как:

def sum_samples(layer):
    return layer.sum(1)
Другие вопросы по тегам