Как использовать JAX и autorgrad для обратного распространения ошибки, созданного с помощью сокращения?

Некоторое время назад я построил numpyоснованная на машинном обучении "библиотека" как школьное домашнее задание. Он был основан исключительно наnumpy, но теперь я хочу перевести его на JAX. У меня возникли проблемы с настройкой процесса обратного распространения ошибки.

Эта библиотека издевается над pytorch, поэтому каждый слой представляет собой класс с forward а также backwardметоды. Вnumpy, мой линейный слой

class Linear:
    def __init__(self, in_features: int, out_features: int):
        self.W = np.random.randn(in_features, out_features)
        self.b = np.random.randn(out_features)

    def forward(self, x: Tensor) -> Tensor:
        self.input = x
        return x @ self.W + self.b

    def backward(self, grad: Tensor) -> Tensor:
        # in_feat by batch_size @ batch_size by out_feat
        self.dydw = self.input.T @ grad
        # we sum across batches and get shape (out_features)
        self.dydb = grad.sum(axis=0)
        # output must be of shape (batch_size, out_features)
        return grad @ self.W.T

Теперь у меня проблемы с переводом на JAX. Я попытался определить__matmul__ метод

def __matmul__(
        self, weight: Tensor, input_: Tensor, bias: Tensor
    ) -> Tensor:
        return jnp.dot(weight, input_.T) + bias

а затем используйте jax.grad

def __grad__(self,) -> Tuple[Callable, Callable, Callable]:
        return (
            grad(self.__matmul__, argnums=0),
            grad(self.__matmul__, argnums=1),
            grad(self.__matmul__, argnums=2),
        )

это не работает. Если я просто используюgrad на выходе relu(linear(x))результат кажется неверным. Каким будет правильный способ использования автограда JAX здесь?

0 ответов

Другие вопросы по тегам