Как создать код, похожий на Pytorch, в Jax Flax

Я пытаюсь построить NN с выпадающим слоем, чтобы избежать переоснащения. Но я столкнулся с некоторыми проблемами, когда писал его на Jax Flax.

Вот оригинальная модель, которую я построил в Pytorch:

      class MLPModel(nn.Module):

def __init__(self, layer, dp_rate=0.1):
    super().__init__()
    layers = []
    for idx in range(len(layer) - 1):
        layers += [
            nn.Linear(layer[idx], layer[idx + 1]),
            nn.ReLU(inplace=True),
            nn.Dropout(dp_rate)
        ]
    self.layers = nn.Sequential(*layers)

def forward(self, x, *args, **kwargs):
    return self.layers(x)

Этот код работает хорошо. Но когда я адаптировал его под Flax, что-то пошло не так:

      class CNN(nn.Module):
hidden_size: Sequence[int]
dp_rate: float
training: bool

def setup(self):
    layers = []
    for idx in range(len(self.hidden_size)):
        layers.append(nn.Dense(self.hidden_size[idx]))
    self.linear_layers = layers
@nn.compact
def __call__(self, x):
    for layer in self.linear_layers:
        x = layer(x)
        x = nn.relu(x)
        x = nn.Dropout(self.dp_rate)(x, deterministic=not self.training)
    x = nn.Dense(self.hidden_size[-1])(x)    
    x = nn.log_softmax(x)
    return x

Сообщение об ошибке: «Несовместимые формы для трансляции: ((1, 1, 128, 10), (128, 28, 28, 10))» (в качестве набора данных я использовал MNIST). И это происходит в:

      @jax.jit
def train_step(state, imgs, gt_labels, key):
    def loss_fn(params):
        logits = CNN(training=True, hidden_size = [50,50,10], dp_rate = 0.1).apply(params, imgs, rngs={'dropout': random.PRNGKey(2)})
        one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)
        loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)  # this is the whole update now! concise!
    metrics = compute_metrics(logits=logits, gt_labels=gt_labels)  # duplicating loss calculation but it's a bit cleaner
    return state, metrics

Размер (1, 1, 128, 10), я думаю, должен быть прогнозом, а (128, 28, 28, 10) должен быть размером ввода. Я следовал руководству в официальной документации (почти те же коды), и я немного запутался в этой ошибке.

Я поделился ссылкой на документ здесь: https://colab.research.google.com/drive/1o6_FgW7AO2XvhuM9NGfLMFOWBFOgbF6G?usp=sharing

0 ответов

Другие вопросы по тегам