Восстановление правильности const для прямого прохода NN
Я пытаюсь реализовать простую нейронную сеть, используя pytorch/libtorch. Следующий пример адаптирован из учебника по внешнему интерфейсу libtorch cpp.
#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
DeepQImpl(size_t N)
: linear1(2,5),
linear2(5,3) {}
torch::Tensor forward(torch::Tensor x) const {
x = torch::tanh(linear1(x));
x = linear2(x);
return x;
}
torch::nn::Linear linear1, linear2;
};
TORCH_MODULE(DeepQ);
Обратите внимание, что функция forward
объявлен const
, Код, который я пишу, требует, чтобы оценка NN была константной функцией, что мне кажется разумным. Этот код не компилируется, хотя. Компилятор кидает
ошибка: нет соответствия для вызова '(const torch::nn::Linear) (at::Tensor&)'
х = линейный 1(х);
Я нашел способ обойти это, определив слои, которые будут mutable
:
#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
/* all the code */
mutable torch::nn:Linear linear1, linear2;
};
Итак, мой вопрос
- Почему нанесение слоя на тензор не
const
- Использует
mutable
способ исправить это и это безопасно?
Моя интуиция заключается в том, что на прямом проходе слои собираются в структуру, которая может использоваться для обратного распространения, что требует некоторой операции записи. Если это правда, возникает вопрос, как собрать слои в первый (не const
) шаг, а затем оценить структуру в секунду (const
) шаг.