Описание тега flax
1
ответ
Flax намного медленнее, чем чистый Jax для нейронных сетей?
для проекта я пытаюсь написать очень простой пример MLP, но я заметил, что реализация во льне примерно в 20 раз медленнее, чем чистая реализация jax. Что я здесь делаю неправильно? import time import jax.numpy as np from jax import random, jit, vmap…
05 мар '22 в 00:40
0
ответов
Pickle меняет тип в jax
У меня есть класс данных структуры льна, содержащий массив jax numpy. Когда я собираю дамп этого объекта и загружаю его снова, массив больше не является массивом jax numpy и преобразуется в массив numpy, вот код для его воспроизведения: import flax …
10 май '22 в 17:58
1
ответ
Вычисление векторного произведения Гессе выхода льняной нейронной сети относительно входных данных
Я пытаюсь получить вторую производную от вывода относительно ввода нейронной сети, построенной с использованием Flax. Сеть устроена следующим образом: import numpy as np import jax import jax.numpy as jnp import flax.linen as nn import optax from fl…
02 июн '22 в 16:17
0
ответов
Потребление памяти Flax при обратном проходе
недавно я построил свою первую модель во льне. Прямой проход работал нормально, но я столкнулся с ошибками OOM во время обратного прохода. Изначально я разделил свою модель на несколько небольших классов, каждый из которых реализован как собственная…
09 ноя '22 в 14:01
1
ответ
AttributeError: модуль «лен» не имеет атрибута «оптимальный»
Мой код выглядит следующим образом: !pip install flax init_params = TransporterNets().init(key, init_img, init_text, init_pix)['params'] print(f'Model parameters: {n_params(init_params):,}') optim = flax.optim.Adam(lr=1e-4).create(init_params) Однак…
25 авг '22 в 14:00
1
ответ
Почему JAX выдает нефильтрованную трассировку стека?
Мне нужно перейти на шаг поезда, но когда я это делаю, я получаю эту ошибку import jax_resnet import jax import jax.numpy as jnp from flax import linen as nn import tensorflow_datasets as tfds from flax.training import train_state import optax impor…
26 сен '22 в 22:58
1
ответ
Jax - vmap по пакету классов данных
В JAX я ищу функцию vmap для списка классов данных фиксированной длины, например: import jax, chex from flax import struct @struct.dataclass class EnvParams: max_steps: int = 500 random_respawn: bool = False def foo(params: EnvParams): ... param_lis…
18 сен '22 в 17:41
1
ответ
Получение неправильного вывода из вызова инициализации льняной модели
Я пытаюсь создать простую нейронную сеть с использованием льна, как показано ниже. Однакоparamsзамороженный дикт я получаю в качестве выводаmodel.initпуст вместо того, чтобы иметь параметры нейронной сети. Такжеtype(predictions)являетсяflax.linen.co…
04 дек '22 в 12:20
0
ответов
AttributeError: модуль «лен» не имеет атрибута «nn»
Я пытаюсь запустить RegNeRF, для чего требуется лен. При установке последней версии flax==0.6.0 я получил сообщение об ошибке, указывающее, что flax не имеет атрибута optim. В этом ответе предлагается понизить версию льна до 0.5.1. При этом теперь я…
02 окт '22 в 09:32
0
ответов
Можете ли вы обновить параметры модуля из nn.compact этого модуля? (самоизменяющиеся сети)
Я новичок в льне, и мне было интересно, как правильно получить такое поведение: param = f.init(key,x) new_param, y = f.apply(param,x) Где f — экземпляр nn.module. Где f может пройти через несколько операций, чтобы получить new_param, и что эти опера…
17 июн '22 в 12:38
0
ответов
Исчезающие параметры в MAML JAX (метаобучение)
Я работаю над реализацией MAML (см. https://arxiv.org/pdf/1703.03400.pdf ) в Jax. При обучении распределению простых задач линейной регрессии кажется, что все работает нормально (требуется некоторое время, чтобы сходиться, но в конечном итоге работа…
17 окт '22 в 02:06
0
ответов
Flax-реализация padding_idx из torch.nn.embedding
Я переписывал некоторые из своих моделей pytorch в jax/flax и столкнулся с проблемой преобразования torch.nn.Embedding в flax.linen.Embed. Прямого перевода слова pytorch не существует.padding_idx. Ключевое слово по существу равно 0 встраиваниям (т. …
18 окт '22 в 16:18
0
ответов
Как преобразовать сохраненные контрольные точки в модель?
https://github.com/google-research/scenic/tree/main/scenic/projects/mbt Я пытаюсь использовать предварительно обученную модель, представленную в git, которая в основном представляет собой контрольную точку Flax. Я хочу преобразовать его обратно в мо…
26 окт '22 в 12:45
0
ответов
Pytorch-эквивалент `register_buffer` во льне/jax
Я ищу способ написать эквивалент следующего модуля Pytorch во Flax, но не нашел способа сделать это. Важно то, что константа должна быть загружаемой и сохраняемой на контрольной точке. class SillyModule(nn.Module): def __init__(self, ): super().__in…
07 авг '22 в 09:58
0
ответов
Как создать код, похожий на Pytorch, в Jax Flax
Я пытаюсь построить NN с выпадающим слоем, чтобы избежать переоснащения. Но я столкнулся с некоторыми проблемами, когда писал его на Jax Flax. Вот оригинальная модель, которую я построил в Pytorch: class MLPModel(nn.Module): def __init__(self, layer…
09 авг '22 в 19:57
0
ответов
Как изменить размер пакета для нейронной сети в JAX
Использование льна для создания сети: def create_train_state(rng, learning_rate, momentum): """Creates initial `TrainState`.""" cnn = CNN() params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] tx = optax.sgd(learning_rate, momentum) return tra…
27 сен '22 в 19:44
1
ответ
Я пытаюсь назначить объект JAX Tracer массиву NumPy, для которого требуются конкретные значения - пожалуйста, обойдите это
Я новичок в Джексе. Я реализую вариационный автоэнкодер (VAE) с использованием Jax и Flax. Во время обучения я сэмплирую скрытый код (из дистрибутива, выведенного энкодером, который я реализую с помощью композиций модулей flax.linen.nn). Важно отмет…
23 авг '22 в 20:24
1
ответ
Как я могу инициализировать скрытое состояние (перенос) GRUCell (льняное полотно) в качестве обучаемого параметра (например, с помощью model.init)
Я создаю модель GRU в Jax с помощью Flax и инициализирую параметры модели с помощью model.init следующим образом: import jax.numpy as np from jax import random import flax.linen as nn from jax.nn import initializers class RNN(nn.Module): n_RNN_units…
26 авг '22 в 23:19
0
ответов
Преобразование модели льна в Pytorch
У меня есть несколько классификаторов изображений во Flax. Для одной из моделей я сохранил состояние, а для двух других я сохранил параметры в виде замороженного словаря с.flaxрасширение. Мой вопрос: как я могу преобразовать целые модели в Pytorch и…
01 фев '23 в 02:16
0
ответов
есть ли способ отслеживать грады с помощью метода self.put_variable во льне?
Я хотел бы отслеживать оценки через переменную self.put_variable. Есть ли способ сделать это возможным? Или другой способ обновить параметр, предоставленный отслеживаемому модулю? import jax from jax import numpy as jnp from jax import grad,random,j…
21 июн '22 в 18:41