Ошибка в коде Pytorch DeepMask
В настоящее время я пытаюсь реализовать Deepmask ( ссылка на документ FAIR) с помощью Pytorch, поэтому я уже определил функцию потери потерь в соединении, а также доступные для изучения параметры модели и прямой проход.
Я работал на этапе обучения, и поскольку в документе говорится, что обучение должно проводиться альтернативным способом обратного распространения по двум ветвям, я написал код для этого же.
Но есть некоторая проблема с обучением, я пытался обучить модель с поддельным набором данных (случайно сгенерированным набором данных), для мини-пакетов, отличных от первой мини-партии, потеря модели оказывается нан.
Что может быть причиной этой потери нан?