Как использовать JAX и autorgrad для обратного распространения ошибки, созданного с помощью сокращения?
Некоторое время назад я построил numpy
основанная на машинном обучении "библиотека" как школьное домашнее задание. Он был основан исключительно наnumpy
, но теперь я хочу перевести его на JAX. У меня возникли проблемы с настройкой процесса обратного распространения ошибки.
Эта библиотека издевается над pytorch, поэтому каждый слой представляет собой класс с forward
а также backward
методы. Вnumpy
, мой линейный слой
class Linear:
def __init__(self, in_features: int, out_features: int):
self.W = np.random.randn(in_features, out_features)
self.b = np.random.randn(out_features)
def forward(self, x: Tensor) -> Tensor:
self.input = x
return x @ self.W + self.b
def backward(self, grad: Tensor) -> Tensor:
# in_feat by batch_size @ batch_size by out_feat
self.dydw = self.input.T @ grad
# we sum across batches and get shape (out_features)
self.dydb = grad.sum(axis=0)
# output must be of shape (batch_size, out_features)
return grad @ self.W.T
Теперь у меня проблемы с переводом на JAX. Я попытался определить__matmul__
метод
def __matmul__(
self, weight: Tensor, input_: Tensor, bias: Tensor
) -> Tensor:
return jnp.dot(weight, input_.T) + bias
а затем используйте jax.grad
def __grad__(self,) -> Tuple[Callable, Callable, Callable]:
return (
grad(self.__matmul__, argnums=0),
grad(self.__matmul__, argnums=1),
grad(self.__matmul__, argnums=2),
)
это не работает. Если я просто используюgrad
на выходе relu(linear(x))
результат кажется неверным. Каким будет правильный способ использования автограда JAX здесь?