Flax-реализация padding_idx из torch.nn.embedding

Я переписывал некоторые из своих моделей pytorch в jax/flax и столкнулся с проблемой преобразования torch.nn.Embedding в flax.linen.Embed.

Прямого перевода слова pytorch не существует.padding_idx. Ключевое слово по существу равно 0 встраиваниям (т. е. без градиентов) для случаев, когдаinput==padding_idx. Работа с (динамической) маскировкой в ​​экосистеме jax/flax довольно сложна, и поэтому я не смог найти способ функционально перевести мою простую модель встраивания с pytorch на flax. Будем очень признательны за любое понимание того, как функционально выполнить аналогичную операцию padding_idx во льне.

Спасибо за любую помощь в вопросе.

0 ответов

Другие вопросы по тегам