Вычисление векторного произведения Гессе выхода льняной нейронной сети относительно входных данных

Я пытаюсь получить вторую производную от вывода относительно ввода нейронной сети, построенной с использованием Flax. Сеть устроена следующим образом:

      import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import optim

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.tanh(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([20, 20, 20, 20, 20, 1])
batch = jnp.ones((32, 3)) #Dummy input to Initialize the NN
params = model.init(jax.random.PRNGKey(0), batch)
X =  jnp.ones((32, 3))
output = model.apply(params, X)

Я могу получить единственную производную, используя vmap вместо grad :

      @jit
def u_function(params, X):
  u = model.apply(params, X)
  return jnp.squeeze(u)

grad_fn = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))

u_X = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)

Однако, когда я пытаюсь сделать это снова, чтобы получить вторую производную:

      u_X_func = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_XX_func = vmap(grad(u_X_func, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)

Я получаю следующую ошибку:

      [/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py](https://localhost:8080/#) in __call__(self, inputs)
    186     kernel = self.param('kernel',
    187                         self.kernel_init,
--> 188                         (jnp.shape(inputs)[-1], self.features),
    189                         self.param_dtype)
    190     if self.use_bias:

IndexError: tuple index out of range

Я попытался использовать определение hvp из поваренной книги autodiff, но поскольку параметры были входными данными для функции, я просто не знал, что делать дальше.

Любая помощь в этом будет действительно ценной.

1 ответ

Дело в том, что ваш u_functionотображает вектор длины 3 в скаляр. Первая производная от этого — вектор длины 3, а вторая производная от него — матрица Гессе 3x3, которую вы не можете вычислить с помощью jax.grad, который предназначен только для функций скалярного вывода. К счастью, JAX предоставляет jax.hessianПреобразуйте, чтобы вычислить эти общие вторые производные:

      u_XX = vmap(hessian(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
print(u_XX.shape)
# (32, 3, 3)
Другие вопросы по тегам