Как загрузить контрольную точку обучения Chainer из npz?
Я использую Chainer для обучения (точной настройки) модели Resnet, а затем использую контрольную точку для оценки. Контрольная точка представляет собой файл npz со следующей структурой:
Когда я загружаю модель для оценки с помощью chainer.serializers.load_npz(args.load, model)
(где модель - это стандартный реснет) Я получаю следующую ошибку: KeyError: "rpn/loc/b не является файлом в архиве".
Я думаю, проблема в том, что файлы в модели не имеют префикса "updater/optimizer/fasten /extractor".
Как я могу изменить имя файлов в полученном npz, чтобы удалить префикс или что еще мне сделать, чтобы исправить проблему?
Спасибо!
1 ответ
Когда вы загружаете снимок, сгенерированный расширением Snapshot Extension, вам нужно сделать это из трейнера.
chainer.serializers.load_npz(args.load, trainer)
Тренер автоматически загрузит состояние программы обновления, оптимизатора и модели.
Вы также можете загрузить только модель вручную, обратившись к соответствующему полю в снимке и передав его в качестве аргумента в model.serialize
функция
npz_data = numpy.load(args.load)
snap = chainer.serializers.NpzDeserializer(npz_data)
model.serialize(snap['updater']['model:main'])
Это должно загрузить только веса модели.