Как изменить размер пакета для нейронной сети в JAX
Использование льна для создания сети:
def create_train_state(rng, learning_rate, momentum):
"""Creates initial `TrainState`."""
cnn = CNN()
params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
tx = optax.sgd(learning_rate, momentum)
return train_state.TrainState.create(
apply_fn=cnn.apply, params=params, tx=tx)
Входные размеры[1, 28, 28, 1]
, в моем индивидуальном обучении мне нужно передать ввод с различными формами пакетов, такими как[5, 28, 28, 1]
. Как я могу реализовать это для льна? В JAX вы можете использоватьvmap
но тут не уверен.