Pytorch-lightning - не работает загрузка модели с КПП
Я обучил vanilla vae, который я модифицировал из этого репозитория. Когда я пытаюсь использовать обученную модель, я не могу загрузить веса с помощьюload_from_checkpoint
. Кажется, есть несоответствие между моим объектом контрольной точки и моимlightningModule
объект.
Я поставил эксперимент (VAEXperiment
) с помощью pytorch-lightning LightningModule
. Я пытаюсь загрузить веса в сеть с помощью:
#building a new model
model = VanillaVAE(**config['model_params'])
model.build_layers()
#loading the weights
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_from_checkpoint(path_to_checkpoint, config['exp_params'])
Я также пробовал:
checkpoint = torch.load(path_to_checkpoint, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
Но я получаю ошибкуUnexpected key(s) in state_dict: "model.encoder.0.0.weight", "model.encoder.0.0.bias"
...
Я также следил за проблемой на https://github.com/PyTorchLightning/pytorch-lightning/issues/924https://github.com/PyTorchLightning/pytorch-lightning/issues/2798
Почему я получаю эту ошибку? Это из-за модулей кодировщика и декодера в моей модели? Судя по журналу проблем на git, кажется, что ошибка устранена. Что я делаю неправильно?
1 ответ
Размещение ответа из комментариев:
experiment.load_state_dict(checkpoint['state_dict'])