pytorch torch.jit.trace возвращает функцию вместо torch.jit.ScriptModule

Мне нужно запустить в C++ предварительно обученную модель Pytorch nn (обученную Python), чтобы делать прогнозы.

Для этого я следую инструкциям по загрузке модели pytorch в C++, приведенной здесь: https://pytorch.org/tutorials/advanced/cpp_export.html

Но когда я пытаюсь получить torch.jit.ScriptModule через трассировку, как указано в первом шаге урока:

    traced_script_module =
        torch.jit.trace(model, (input_tensor_1, input_tensor_2))

Вместо того чтобы возвращать torch.jit.ScriptModule, он возвращает функцию:

    print(type(traced_script_module))
    <type 'function'>

Который, когда я бегу:

    traced_script_module.save("model.pt")

затем приводит к следующей ошибке:

Traceback (most recent call last):
  File "serialize_model.py", line 60, in <module>
    traced_script_module.save("model.pt")
AttributeError: 'function' object has no attribute 'save'

Есть идеи, что я делаю не так?

1 ответ

Решение

Спасибо, что спросили Jatentaki. Я использовал PyTorch 0.4 в Python, и когда я обновился до 1.0, он работал.

Другие вопросы по тегам