Использование классов TorchScript в качестве членов в модулях pytorch
Я пытаюсь сделать так, чтобы некоторые существующие модели pytorch поддерживали jit-компилятор TorchScript, но у меня возникают проблемы с членами непримитивных типов.
Этот небольшой пример иллюстрирует проблему:
import torch
@torch.jit.script
class Factory(object):
def __init__(self):
pass
def create(self, x: float) -> torch.Tensor:
return torch.tensor([x])
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.factory: Factory = Factory()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.factory.create(0)
mod = torch.jit.script(Foo())
При запуске компилятор jit выдает ошибку
RuntimeError:
module has no attribute 'factory':
at example.py:17:15
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.factory.create(0)
~~~~~~~~~~~~ <--- HERE
Я проверил, что Factory
класс доступен для jit внутри forward
метод, но он не подтверждает его, когда я сохраняю его как член. Почему это? И есть ли способ заставить компилятор jit сохранять такие элементы в скомпилированном модуле?
0 ответов
Это была ошибка в PyTorch, которая была устранена вскоре после того, как вы разместили свой вопрос: https://discuss.pytorch.org/t/jit-scripted-attributes-inside-module/60645, https://github.com/pytorch/pytorch/issues/27495.
Обновление PyTorch должно исправить это.