Потребление памяти Flax при обратном проходе

недавно я построил свою первую модель во льне. Прямой проход работал нормально, но я столкнулся с ошибками OOM во время обратного прохода.

Изначально я разделил свою модель на несколько небольших классов, каждый из которых реализован как собственная модель льна, наследуемая от linen.module. Для отладки я объединил все эти части в одну модель льна. Это резко уменьшило объем памяти при обратном проходе.

Может ли кто-нибудь объяснить, является ли это ожидаемым поведением, и если да, то каков источник накладных расходов памяти при обратном проходе при использовании нескольких небольших классов вместо одного большого?

Заранее спасибо и с наилучшими пожеланиями.

0 ответов

Другие вопросы по тегам