Восстановление правильности const для прямого прохода NN

Я пытаюсь реализовать простую нейронную сеть, используя pytorch/libtorch. Следующий пример адаптирован из учебника по внешнему интерфейсу libtorch cpp.

#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
    DeepQImpl(size_t N)
        : linear1(2,5),
          linear2(5,3) {}
    torch::Tensor forward(torch::Tensor x) const {
        x = torch::tanh(linear1(x));
        x = linear2(x);
        return x;
    }
    torch::nn::Linear linear1, linear2;
};
TORCH_MODULE(DeepQ);

Обратите внимание, что функция forward объявлен const, Код, который я пишу, требует, чтобы оценка NN была константной функцией, что мне кажется разумным. Этот код не компилируется, хотя. Компилятор кидает

ошибка: нет соответствия для вызова '(const torch::nn::Linear) (at::Tensor&)'
х = линейный 1(х);

Я нашел способ обойти это, определив слои, которые будут mutable:

#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
    /* all the code */
    mutable torch::nn:Linear linear1, linear2;
};

Итак, мой вопрос

  1. Почему нанесение слоя на тензор не const
  2. Использует mutable способ исправить это и это безопасно?

Моя интуиция заключается в том, что на прямом проходе слои собираются в структуру, которая может использоваться для обратного распространения, что требует некоторой операции записи. Если это правда, возникает вопрос, как собрать слои в первый (не const) шаг, а затем оценить структуру в секунду (const) шаг.

0 ответов

Другие вопросы по тегам