jax vmap: обеспечить правильную форму
Я использую vmap для векторизации частей моего кода. Вот минимальный пример до векторизации:
dim = 2
def sum(x):
a = np.ones((dim,))
return np.dot(x, a)
num_samples = 100
samples = np.ones((num_samples, dim))
sum(samples[0]) # 2
с vmap:
sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2
Но после векторизации это может пойти не так:
sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1
Здесь происходит то, что samples[0]
имеет форму (2,)
. Вызов векторизованной функции разбивает свой входной аргумент по первой оси и, следовательно, получает 2 массива формы(1,)
. Из-за трансляции сa
, результат будет иметь форму (2,)
снова и складывается в (2,2)
массив.
Мне это кажется опасным. Код выглядит нормально, и результирующий вывод будет легко использован некоторыми другими правилами широковещательной передачи, которые скрывают его нарушенную форму.
Можно ли добиться правильной формы?
1 ответ
"Это кажется мне опасным. Код выглядит нормально, и результат может быть легко использован некоторыми другими правилами вещания, которые скрывают его нарушенную форму".
Заметить, что vmap
выполняет именно то, что здесь предполагается, то есть векторизует нулевое измерение, а широковещательная передача numpy делает именно то, что должна делать. Проблема, конечно, в том, что пользователь дает массив неправильной формы, посколькуvmap
ожидает векторизованный ввод в нулевом измерении x. Вместо этого пользователь должен написать
sum(samples[0:1])
который сохраняет правильную форму.
Другими словами: если вы собираетесь применить vmap к функции, вы не можете использовать эту функцию точно так же, как если бы вы никогда не применяли vmap. Вам необходимо учесть изменение поведения функции.
"Можно ли добиться правильной формы?"
vmap
сам по себе не имеет возможности принудительно изменять форму ввода. Если вас особенно беспокоит то, что пользователь придает функции неправильную форму, вы можете встроить это в исходную функцию. Например,
def sum(x):
if (x.shape[-1] != dim):
raise Exception()
a = np.ones((dim,))
return np.dot(x, a)
сломается, если вы не придаете ему правильную форму даже после нанесения vmap
.