Pytorch-lightning - не работает загрузка модели с КПП

Я обучил vanilla vae, который я модифицировал из этого репозитория. Когда я пытаюсь использовать обученную модель, я не могу загрузить веса с помощьюload_from_checkpoint. Кажется, есть несоответствие между моим объектом контрольной точки и моимlightningModule объект.

Я поставил эксперимент (VAEXperiment) с помощью pytorch-lightning LightningModule. Я пытаюсь загрузить веса в сеть с помощью:

#building a new model
model = VanillaVAE(**config['model_params'])
model.build_layers()

#loading the weights
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_from_checkpoint(path_to_checkpoint, config['exp_params'])

Я также пробовал:

checkpoint = torch.load(path_to_checkpoint, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])

Но я получаю ошибкуUnexpected key(s) in state_dict: "model.encoder.0.0.weight", "model.encoder.0.0.bias"...

Я также следил за проблемой на https://github.com/PyTorchLightning/pytorch-lightning/issues/924https://github.com/PyTorchLightning/pytorch-lightning/issues/2798

Почему я получаю эту ошибку? Это из-за модулей кодировщика и декодера в моей модели? Судя по журналу проблем на git, кажется, что ошибка устранена. Что я делаю неправильно?

1 ответ

Решение

Размещение ответа из комментариев:

experiment.load_state_dict(checkpoint['state_dict'])
Другие вопросы по тегам