Почему параметры, сохраненные в КПП, отличаются от параметров в объединенной модели?
Я тренировал QAT
(Обучение с учетом квантования) на основе модели в Pytorch
, обучение прошло гладко. Однако когда я попытался загрузить веса в объединенную модель и запустить тест на наборе данных с более широким кругом, я столкнулся с множеством ошибок:
(base) marian@u04-2:/mnt/s3user/Pytorch_Retinaface_quantized# python test_widerface.py --trained_model ./weights/mobilenet0.25_Final_quantized.pth --network mobile0.25layers:
Loading pretrained model from ./weights/mobilenet0.25_Final_quantized.pth
remove prefix 'module.'
Missing keys:235
Unused checkpoint keys:171
Used keys:65
Traceback (most recent call last):
File "/root/.vscode/extensions/ms-python.python-2020.1.58038/pythonFiles/ptvsd_launcher.py", line 43, in <module>
main(ptvsdArgs)
File "/root/.vscode/extensions/ms-python.python-2020.1.58038/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 432, in main
run()
File "/root/.vscode/extensions/ms-python.python-2020.1.58038/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 316, in run_file
runpy.run_path(target, run_name='__main__')
File "/root/anaconda3/lib/python3.7/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "/root/anaconda3/lib/python3.7/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/root/anaconda3/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/mnt/f3user/Pytorch_Retinaface_quantized/test_widerface.py", line 114, in <module>
net = load_model(net, args.trained_model, args.cpu)
File "/mnt/f3user/Pytorch_Retinaface_quantized/test_widerface.py", line 95, in load_model
model.load_state_dict(pretrained_dict, strict=False)
File "/root/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for RetinaFace:
While copying the parameter named "ssh1.conv3X3.0.weight", whose dimensions in the model are torch.Size([32, 64, 3, 3]) and whose dimensions in the checkpoint are torch.Size([32, 64, 3, 3]).
While copying the parameter named "ssh1.conv5X5_2.0.weight", whose dimensions in the model are torch.Size([16, 16, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 16, 3, 3]).
While copying the parameter named "ssh1.conv7x7_3.0.weight", whose dimensions in the model are torch.Size([16, 16, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 16, 3, 3]).
While copying the parameter named "ssh2.conv3X3.0.weight", whose dimensions in the model are torch.Size([32, 64, 3, 3]) and whose dimensions in the checkpoint are torch.Size([32, 64, 3, 3]).
While copying the parameter named "ssh2.conv5X5_2.0.weight", whose dimensions in the model are torch.Size([16, 16, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 16, 3, 3]).
.....
Полный список можно найти здесь.
в основном веса не найти. плюс масштаб и zero_point, которые отсутствуют в объединенной модели.
в случае, если это имеет значение, следующий фрагмент представляет собой фактический цикл обучения, который использовался для обучения и сохранения модели:
if __name__ == '__main__':
# train()
...
net = RetinaFace(cfg=cfg)
print("Printing net...")
print(net)
net.fuse_model()
...
net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(net, inplace=True)
print(f'quantization preparation done.')
...
quantized_model = net
for i in range(max_epoch):
net = net.to(device)
train_one_epoch(net, data_loader, optimizer, criterion, cfg, gamma, i, step_index, device)
if i in stepvalues:
step_index += 1
if i > 3 :
net.apply(torch.quantization.disable_observer)
if i > 2 :
net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
net=net.cpu()
quantized_model = torch.quantization.convert(net.eval(), inplace=False)
quantized_model.eval()
# evaluate on test set ?!
torch.save(net.state_dict(), save_folder + cfg['name'] + '_Final.pth')
torch.save(quantized_model.state_dict(), save_folder + cfg['name'] + '_Final_quantized.pth')
#torch.jit.save(torch.jit.script(quantized_model), save_folder + cfg['name'] + '_Final_quantized_jit.pth')
для тестирования test_widerface.py
используется, доступ к которому можно получить здесь
Вы можете просмотреть ключи здесь
Почему это произошло? Как об этом позаботиться?
Обновить
Я проверил имя, создал новый словарь state_dict и вставил 112 ключей, которые были как в контрольной точке, так и в модели, используя фрагмент ниже:
new_state_dict = {}
checkpoint_state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
for (ck, cp) in checkpoint_state_dict.items():
for (mk, mp) in model.state_dict().items():
kname,kext = os.path.splitext(ck)
mname,mext = os.path.splitext(mk)
# check the two parameter and see if they are the same
# then use models key naming scheme and use checkpoints weights
if kname+kext == mname+mext or kname+'.0'+kext == mname+mext:
new_state_dict[mname+mext] = cp
else:
if kext in ('.scale','.zero_point'):
new_state_dict[ck] = cp
а затем используйте этот новый state_dict! но я получаю те же самые точные ошибки! что означает такие ошибки:
RuntimeError: Error(s) in loading state_dict for RetinaFace:
While copying the parameter named "ssh1.conv3X3.0.weight", whose dimensions in the model are torch.Size([32, 64, 3, 3]) and whose dimensions in the checkpoint are torch.Size([32, 64, 3, 3]).
Это действительно неприятно, и документации по этому поводу нет! Я здесь совершенно ничего не понимаю.
1 ответ
Я наконец выяснил причину. Сообщения об ошибках в виде:
При копировании параметра с именем "xxx.weight", размеры которого в модели - torch.Size([yyy]), а размеры в контрольной точке - torch.Size([yyy]).
фактически являются общими сообщениями, возвращаемыми только в том случае, если при копировании рассматриваемых параметров возникло исключение.
Разработчики Pytorch могут легко добавить фактические аргументы исключения в это ложное, но бесполезное сообщение, так что оно действительно может помочь лучше отладить проблему. Во всяком случае, глядя на исключение, которое было между прочим:
"copy_" not implemented for \'QInt8'
Теперь вы будете знать, в чем заключается реальная проблема!