Как загрузить контрольную точку обучения Chainer из npz?

Я использую Chainer для обучения (точной настройки) модели Resnet, а затем использую контрольную точку для оценки. Контрольная точка представляет собой файл npz со следующей структурой:

Когда я загружаю модель для оценки с помощью chainer.serializers.load_npz(args.load, model) (где модель - это стандартный реснет) Я получаю следующую ошибку: KeyError: "rpn/loc/b не является файлом в архиве".

Я думаю, проблема в том, что файлы в модели не имеют префикса "updater/optimizer/fasten /extractor".

Как я могу изменить имя файлов в полученном npz, чтобы удалить префикс или что еще мне сделать, чтобы исправить проблему?

Спасибо!

1 ответ

Когда вы загружаете снимок, сгенерированный расширением Snapshot Extension, вам нужно сделать это из трейнера.

chainer.serializers.load_npz(args.load, trainer) Тренер автоматически загрузит состояние программы обновления, оптимизатора и модели.

Вы также можете загрузить только модель вручную, обратившись к соответствующему полю в снимке и передав его в качестве аргумента в model.serialize функция

npz_data = numpy.load(args.load)
snap = chainer.serializers.NpzDeserializer(npz_data)
model.serialize(snap['updater']['model:main'])

Это должно загрузить только веса модели.

Другие вопросы по тегам