torch.jit.script (модуль) против декоратора @ torch.jit.script

Почему добавление декоратора "@torch.jit.script" приводит к ошибке, а я могу вызвать torch.jit.script в этом модуле, например, это не удается:

import torch

@torch.jit.script
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

"C:\Users\Administrator\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\jit\__init__.py", line 1262, in script
    raise RuntimeError("Type '{}' cannot be compiled since it inherits"
RuntimeError: Type '<class '__main__.MyCell'>' cannot be compiled since it inherits from nn.Module, pass an instance instead

Хотя следующий код работает хорошо:

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

Этот вопрос также обсуждается на форумах PyTorch.

1 ответ

Причина вашей ошибки здесь, именно этот пункт:

Нет поддержки наследования или любой другой стратегии полиморфизма, за исключением наследования от объекта для определения класса нового стиля.

Также, как указано вверху:

Поддержка классов TorchScript является экспериментальной. В настоящее время он лучше всего подходит для простых типов записей (подумайте о NamedTuple с прикрепленными методами).

В настоящее время он предназначен для простых классов Python (см. Другие пункты в предоставленной мной ссылке) и функций, см. Ссылку, которую я предоставил для получения дополнительной информации.

Вы также можете проверить torch.jit.scriptисходный код, чтобы лучше понять, как это работает.

Кажется, когда вы проходите экземпляр, все attributesкоторые следует сохранить, анализируются рекурсивно ( источник). Вы можете следить за этой функцией (довольно прокомментировано, но слишком долго для ответа, см. Здесь), хотя точная причина, почему это так (и почему она была разработана таким образом), находится вне моего понимания (так что, надеюсь, кто-то с опытомtorch.jitоб этом подробнее расскажет внутреннее устройство).

Другие вопросы по тегам