Потребление памяти Flax при обратном проходе
недавно я построил свою первую модель во льне. Прямой проход работал нормально, но я столкнулся с ошибками OOM во время обратного прохода.
Изначально я разделил свою модель на несколько небольших классов, каждый из которых реализован как собственная модель льна, наследуемая от linen.module. Для отладки я объединил все эти части в одну модель льна. Это резко уменьшило объем памяти при обратном проходе.
Может ли кто-нибудь объяснить, является ли это ожидаемым поведением, и если да, то каков источник накладных расходов памяти при обратном проходе при использовании нескольких небольших классов вместо одного большого?
Заранее спасибо и с наилучшими пожеланиями.