Преобразование модели льна в Pytorch

У меня есть несколько классификаторов изображений во Flax. Для одной из моделей я сохранил состояние, а для двух других я сохранил параметры в виде замороженного словаря с.flaxрасширение. Мой вопрос: как я могу преобразовать целые модели в Pytorch и использовать эти веса, чтобы иметь такую ​​​​же идентичную модель в Pytorch?

Например, одна из моделей такова:

      class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x, training = True):
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.Dropout(0.5, deterministic= not training)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

Другой - ResNet18.

Спасибо.

0 ответов

Другие вопросы по тегам