Flax-реализация padding_idx из torch.nn.embedding
Я переписывал некоторые из своих моделей pytorch в jax/flax и столкнулся с проблемой преобразования torch.nn.Embedding в flax.linen.Embed.
Прямого перевода слова pytorch не существует.padding_idx
. Ключевое слово по существу равно 0 встраиваниям (т. е. без градиентов) для случаев, когдаinput==padding_idx
. Работа с (динамической) маскировкой в экосистеме jax/flax довольно сложна, и поэтому я не смог найти способ функционально перевести мою простую модель встраивания с pytorch на flax. Будем очень признательны за любое понимание того, как функционально выполнить аналогичную операцию padding_idx во льне.
Спасибо за любую помощь в вопросе.