Ошибка в коде Pytorch DeepMask

В настоящее время я пытаюсь реализовать Deepmask ( ссылка на документ FAIR) с помощью Pytorch, поэтому я уже определил функцию потери потерь в соединении, а также доступные для изучения параметры модели и прямой проход.

Я работал на этапе обучения, и поскольку в документе говорится, что обучение должно проводиться альтернативным способом обратного распространения по двум ветвям, я написал код для этого же.

Но есть некоторая проблема с обучением, я пытался обучить модель с поддельным набором данных (случайно сгенерированным набором данных), для мини-пакетов, отличных от первой мини-партии, потеря модели оказывается нан.

Что может быть причиной этой потери нан?

Ссылка на текущую версию моего кода

0 ответов

Другие вопросы по тегам