Использование классов TorchScript в качестве членов в модулях pytorch

Я пытаюсь сделать так, чтобы некоторые существующие модели pytorch поддерживали jit-компилятор TorchScript, но у меня возникают проблемы с членами непримитивных типов.

Этот небольшой пример иллюстрирует проблему:

import torch

@torch.jit.script
class Factory(object):
    def __init__(self):
        pass

    def create(self, x: float) -> torch.Tensor:
        return torch.tensor([x])

class Foo(torch.nn.Module):
    def __init__(self):
        super(Foo, self).__init__()
        self.factory: Factory = Factory()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.factory.create(0)

mod = torch.jit.script(Foo())

При запуске компилятор jit выдает ошибку

RuntimeError:
module has no attribute 'factory':
at example.py:17:15
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.factory.create(0)
               ~~~~~~~~~~~~ <--- HERE

Я проверил, что Factory класс доступен для jit внутри forwardметод, но он не подтверждает его, когда я сохраняю его как член. Почему это? И есть ли способ заставить компилятор jit сохранять такие элементы в скомпилированном модуле?

0 ответов

Это была ошибка в PyTorch, которая была устранена вскоре после того, как вы разместили свой вопрос: https://discuss.pytorch.org/t/jit-scripted-attributes-inside-module/60645, https://github.com/pytorch/pytorch/issues/27495.

Обновление PyTorch должно исправить это.

Другие вопросы по тегам