pytorch torch.jit.trace возвращает функцию вместо torch.jit.ScriptModule
Мне нужно запустить в C++ предварительно обученную модель Pytorch nn (обученную Python), чтобы делать прогнозы.
Для этого я следую инструкциям по загрузке модели pytorch в C++, приведенной здесь: https://pytorch.org/tutorials/advanced/cpp_export.html
Но когда я пытаюсь получить torch.jit.ScriptModule через трассировку, как указано в первом шаге урока:
traced_script_module =
torch.jit.trace(model, (input_tensor_1, input_tensor_2))
Вместо того чтобы возвращать torch.jit.ScriptModule, он возвращает функцию:
print(type(traced_script_module))
<type 'function'>
Который, когда я бегу:
traced_script_module.save("model.pt")
затем приводит к следующей ошибке:
Traceback (most recent call last):
File "serialize_model.py", line 60, in <module>
traced_script_module.save("model.pt")
AttributeError: 'function' object has no attribute 'save'
Есть идеи, что я делаю не так?
1 ответ
Решение
Спасибо, что спросили Jatentaki. Я использовал PyTorch 0.4 в Python, и когда я обновился до 1.0, он работал.