Производный класс Pytorch nn.Module не может быть загружен с помощью импорта модуля в Python

Использование Python 3.6 с Pytorch 1.3.1. Я заметил, что некоторые сохраненные модули nn.Modules не могут быть загружены, когда весь модуль импортируется в другой модуль. В качестве примера, вот шаблон минимального рабочего примера.

#!/usr/bin/env python3
#encoding:utf-8
# file 'dnn_predict.py'

from torch import nn
class NN(nn.Module):##NN network
    # Initialisation and other class methods

networks=[torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_{fold}.pkl'.format(fold=fold))) for fold in range(5)]
...
if __name__=='__main__':
    # Some testing snippets
    pass

Когда я запускаю его напрямую в оболочке, весь файл работает нормально. Однако когда я хочу использовать класс и загрузить нейронную сеть в другой файл с помощью этого кода, это не удается.

#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *

Ошибка гласит AttributeError: Can't get attribute 'NN' on <module '__main__'>

В Pytorch загрузка сохраненных переменных или импорт модулей происходит иначе, чем в других распространенных библиотеках Python? Некоторая помощь или указатель на основную причину будут действительно признательны.

1 ответ

Решение

Когда вы сохраняете модель с torch.save(model, PATH) весь объект сериализуется с помощью pickle, который сохраняет не сам класс, а путь к файлу, содержащему класс, поэтому при загрузке модели требуются точно такие же каталог и файловая структура, чтобы найти правильный класс. При запуске скрипта Python модуль этого файла__main__, поэтому, если вы хотите загрузить этот модуль, ваш NN class должен быть определен в скрипте, который вы запускаете.

Это очень негибкий подход, поэтому рекомендуется не сохранять всю модель, а вместо этого просто сохранять словарь состояний, который сохраняет только параметры модели.

# Save the state dictionary of the model
torch.save(model.state_dict(), PATH)

После этого словарь состояний можно загрузить и применить к вашей модели.

from dnn_predict import NN

# Create the model (will have randomly initialised parameters)
model = NN()

# Load the previously saved state dictionary
state_dict = torch.load(PATH)

# Apply the state dictionary to the model
model.load_state_dict(state_dict)

Подробнее о словаре состояний и сохранении / загрузке моделей: PyTorch - Сохранение и загрузка моделей

Другие вопросы по тегам