Преобразование модели льна в Pytorch
У меня есть несколько классификаторов изображений во Flax. Для одной из моделей я сохранил состояние, а для двух других я сохранил параметры в виде замороженного словаря с.flax
расширение. Мой вопрос: как я могу преобразовать целые модели в Pytorch и использовать эти веса, чтобы иметь такую же идентичную модель в Pytorch?
Например, одна из моделей такова:
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x, training = True):
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.Dropout(0.5, deterministic= not training)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
Другой - ResNet18.
Спасибо.