Почему я не могу использовать классы для аннотации типов аргументов функции в декораторе `torch.jit.script`?
Этот код отлично компилируется:
import torch
import torch.nn as nn
class Foo(nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.x = 0
def forward(self, X):
X *= self.x
self.x += 1
return X
# @torch.jit.script
def bar(f: Foo):
return f.x
Но если я раскомментирую # @torch.jit.script
строка, я получаю эту ошибку:
Traceback (most recent call last):
File "test1.py", line 18, in <module>
def bar(f: Foo):
File "/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/jit/__init__.py", line 1103, in script
fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
RuntimeError:
Unknown type name 'Foo':
at test1.py:18:12
@torch.jit.script
def bar(f: Foo):
~~~ <--- HERE
return f.x
Если я изменю аннотацию типа на int
:
@torch.jit.script
# def bar(f: Foo):
# return f.x
def bar(f: int):
return f
затем компиляция снова работает.
Кто-нибудь знает, что мне нужно сделать, чтобы разрешить использование моих пользовательских определений классов в аннотациях типов к аргументам функций, которые находятся под torch.jit.script
декоратор?
1 ответ
Решение
В качестве аргументов функций можно использовать только список типов в документации:
https://pytorch.org/docs/stable/jit_language_reference.html
nn.Module
В TorchScript есть специальная обработка, чтобы заставить их работать, но в настоящее время они не поддерживаются в качестве аргументов.