Вычисление векторного произведения Гессе выхода льняной нейронной сети относительно входных данных
Я пытаюсь получить вторую производную от вывода относительно ввода нейронной сети, построенной с использованием Flax. Сеть устроена следующим образом:
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import optim
class MLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, x):
for feat in self.features[:-1]:
x = nn.tanh(nn.Dense(feat)(x))
x = nn.Dense(self.features[-1])(x)
return x
model = MLP([20, 20, 20, 20, 20, 1])
batch = jnp.ones((32, 3)) #Dummy input to Initialize the NN
params = model.init(jax.random.PRNGKey(0), batch)
X = jnp.ones((32, 3))
output = model.apply(params, X)
Я могу получить единственную производную, используя vmap вместо grad :
@jit
def u_function(params, X):
u = model.apply(params, X)
return jnp.squeeze(u)
grad_fn = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_X = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
Однако, когда я пытаюсь сделать это снова, чтобы получить вторую производную:
u_X_func = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_XX_func = vmap(grad(u_X_func, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
Я получаю следующую ошибку:
[/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py](https://localhost:8080/#) in __call__(self, inputs)
186 kernel = self.param('kernel',
187 self.kernel_init,
--> 188 (jnp.shape(inputs)[-1], self.features),
189 self.param_dtype)
190 if self.use_bias:
IndexError: tuple index out of range
Я попытался использовать определение hvp из поваренной книги autodiff, но поскольку параметры были входными данными для функции, я просто не знал, что делать дальше.
Любая помощь в этом будет действительно ценной.
1 ответ
Дело в том, что ваш
u_function
отображает вектор длины 3 в скаляр. Первая производная от этого — вектор длины 3, а вторая производная от него — матрица Гессе 3x3, которую вы не можете вычислить с помощью
jax.grad
, который предназначен только для функций скалярного вывода. К счастью, JAX предоставляет
jax.hessian
Преобразуйте, чтобы вычислить эти общие вторые производные:
u_XX = vmap(hessian(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
print(u_XX.shape)
# (32, 3, 3)