torch.jit.script (модуль) против декоратора @ torch.jit.script
Почему добавление декоратора "@torch.jit.script" приводит к ошибке, а я могу вызвать torch.jit.script в этом модуле, например, это не удается:
import torch
@torch.jit.script
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
"C:\Users\Administrator\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\jit\__init__.py", line 1262, in script
raise RuntimeError("Type '{}' cannot be compiled since it inherits"
RuntimeError: Type '<class '__main__.MyCell'>' cannot be compiled since it inherits from nn.Module, pass an instance instead
Хотя следующий код работает хорошо:
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
Этот вопрос также обсуждается на форумах PyTorch.
1 ответ
Причина вашей ошибки здесь, именно этот пункт:
Нет поддержки наследования или любой другой стратегии полиморфизма, за исключением наследования от объекта для определения класса нового стиля.
Также, как указано вверху:
Поддержка классов TorchScript является экспериментальной. В настоящее время он лучше всего подходит для простых типов записей (подумайте о NamedTuple с прикрепленными методами).
В настоящее время он предназначен для простых классов Python (см. Другие пункты в предоставленной мной ссылке) и функций, см. Ссылку, которую я предоставил для получения дополнительной информации.
Вы также можете проверить torch.jit.script
исходный код, чтобы лучше понять, как это работает.
Кажется, когда вы проходите экземпляр, все attributes
которые следует сохранить, анализируются рекурсивно ( источник). Вы можете следить за этой функцией (довольно прокомментировано, но слишком долго для ответа, см. Здесь), хотя точная причина, почему это так (и почему она была разработана таким образом), находится вне моего понимания (так что, надеюсь, кто-то с опытомtorch.jit
об этом подробнее расскажет внутреннее устройство).