Как создать код, похожий на Pytorch, в Jax Flax
Я пытаюсь построить NN с выпадающим слоем, чтобы избежать переоснащения. Но я столкнулся с некоторыми проблемами, когда писал его на Jax Flax.
Вот оригинальная модель, которую я построил в Pytorch:
class MLPModel(nn.Module):
def __init__(self, layer, dp_rate=0.1):
super().__init__()
layers = []
for idx in range(len(layer) - 1):
layers += [
nn.Linear(layer[idx], layer[idx + 1]),
nn.ReLU(inplace=True),
nn.Dropout(dp_rate)
]
self.layers = nn.Sequential(*layers)
def forward(self, x, *args, **kwargs):
return self.layers(x)
Этот код работает хорошо. Но когда я адаптировал его под Flax, что-то пошло не так:
class CNN(nn.Module):
hidden_size: Sequence[int]
dp_rate: float
training: bool
def setup(self):
layers = []
for idx in range(len(self.hidden_size)):
layers.append(nn.Dense(self.hidden_size[idx]))
self.linear_layers = layers
@nn.compact
def __call__(self, x):
for layer in self.linear_layers:
x = layer(x)
x = nn.relu(x)
x = nn.Dropout(self.dp_rate)(x, deterministic=not self.training)
x = nn.Dense(self.hidden_size[-1])(x)
x = nn.log_softmax(x)
return x
Сообщение об ошибке: «Несовместимые формы для трансляции: ((1, 1, 128, 10), (128, 28, 28, 10))» (в качестве набора данных я использовал MNIST). И это происходит в:
@jax.jit
def train_step(state, imgs, gt_labels, key):
def loss_fn(params):
logits = CNN(training=True, hidden_size = [50,50,10], dp_rate = 0.1).apply(params, imgs, rngs={'dropout': random.PRNGKey(2)})
one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)
loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(_, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads) # this is the whole update now! concise!
metrics = compute_metrics(logits=logits, gt_labels=gt_labels) # duplicating loss calculation but it's a bit cleaner
return state, metrics
Размер (1, 1, 128, 10), я думаю, должен быть прогнозом, а (128, 28, 28, 10) должен быть размером ввода. Я следовал руководству в официальной документации (почти те же коды), и я немного запутался в этой ошибке.
Я поделился ссылкой на документ здесь: https://colab.research.google.com/drive/1o6_FgW7AO2XvhuM9NGfLMFOWBFOgbF6G?usp=sharing