Как изменить размер пакета для нейронной сети в JAX

Использование льна для создания сети:

      def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

Входные размеры[1, 28, 28, 1], в моем индивидуальном обучении мне нужно передать ввод с различными формами пакетов, такими как[5, 28, 28, 1]. Как я могу реализовать это для льна? В JAX вы можете использоватьvmapно тут не уверен.

0 ответов

Другие вопросы по тегам