Производный класс Pytorch nn.Module не может быть загружен с помощью импорта модуля в Python
Использование Python 3.6 с Pytorch 1.3.1. Я заметил, что некоторые сохраненные модули nn.Modules не могут быть загружены, когда весь модуль импортируется в другой модуль. В качестве примера, вот шаблон минимального рабочего примера.
#!/usr/bin/env python3
#encoding:utf-8
# file 'dnn_predict.py'
from torch import nn
class NN(nn.Module):##NN network
# Initialisation and other class methods
networks=[torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_{fold}.pkl'.format(fold=fold))) for fold in range(5)]
...
if __name__=='__main__':
# Some testing snippets
pass
Когда я запускаю его напрямую в оболочке, весь файл работает нормально. Однако когда я хочу использовать класс и загрузить нейронную сеть в другой файл с помощью этого кода, это не удается.
#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *
Ошибка гласит AttributeError: Can't get attribute 'NN' on <module '__main__'>
В Pytorch загрузка сохраненных переменных или импорт модулей происходит иначе, чем в других распространенных библиотеках Python? Некоторая помощь или указатель на основную причину будут действительно признательны.
1 ответ
Когда вы сохраняете модель с torch.save(model, PATH)
весь объект сериализуется с помощью pickle
, который сохраняет не сам класс, а путь к файлу, содержащему класс, поэтому при загрузке модели требуются точно такие же каталог и файловая структура, чтобы найти правильный класс. При запуске скрипта Python модуль этого файла__main__
, поэтому, если вы хотите загрузить этот модуль, ваш NN
class должен быть определен в скрипте, который вы запускаете.
Это очень негибкий подход, поэтому рекомендуется не сохранять всю модель, а вместо этого просто сохранять словарь состояний, который сохраняет только параметры модели.
# Save the state dictionary of the model
torch.save(model.state_dict(), PATH)
После этого словарь состояний можно загрузить и применить к вашей модели.
from dnn_predict import NN
# Create the model (will have randomly initialised parameters)
model = NN()
# Load the previously saved state dictionary
state_dict = torch.load(PATH)
# Apply the state dictionary to the model
model.load_state_dict(state_dict)
Подробнее о словаре состояний и сохранении / загрузке моделей: PyTorch - Сохранение и загрузка моделей