AttributeError в квантовании модели PyTorch для SequenceTagger Флера

Я работаю над наследуемым классом FlairEmbeddingиз библиотеки Flair NLP. В этом классе я хотел бы реализовать квантование модели с помощью PyTorch's torch.quantizationмодуль. Для этого мне нужно обучить модель на нескольких партиях, чтобы собрать статистику и выбрать правильные параметры квантования. Модель будет использоваться в последующем тэггере последовательности, поэтому я использую Flair's SequenceTaggerкласс с теми же параметрами, что и те, которые я использую в последующей задаче. Вот как выглядит класс:

class CustomEmbeddings(FlairEmbeddings):
    def __init__(
                 self, tag_dictionary, tag_type, corpus, mini_batch_size, train_with_dev, # Used for training
                 model, fine_tune, chars_per_chunk,with_whitespace, tokenized_lm # Base FlairEmbeddings arguments
                ):

        super().__init__(model, fine_tune, chars_per_chunk, with_whitespace, tokenized_lm)
        
        self.lm.qconfig = torch.quantization.default_config
        torch.quantization.prepare(self.lm.qconfig, inplace=True)
        
        # Small training to gather statistics
        tagger = SequenceTagger(hidden_size=256, embeddings=self, tag_dictionary=tag_dictionary, tag_type=tag_type)
        trainer = ModelTrainer(tagger, corpus)
------> trainer.train('model', mini_batch_size=mini_batch_size, max_epochs=10, train_with_dev=train_with_dev)

        torch.quantization.convert(self.lm, inplace=True)

Этот код не работает со следующей ошибкой:

  File "/home/pie3636/project/main.py", line 28, in __init__
    embeddings = CustomEmbeddings(name, **params)
  File "/home/pie3636/project/custom_embeddings.py", line 35, in __init__
    trainer.train('model', mini_batch_size=mini_batch_size, max_epochs=10, train_with_dev=train_with_dev)
  File "/usr/local/lib/python3.6/dist-packages/flair/trainers/trainer.py", line 371, in train
    loss = self.model.forward_loss(batch_step)
  File "/usr/local/lib/python3.6/dist-packages/flair/models/sequence_tagger_model.py", line 603, in forward_loss
    features = self.forward(data_points)
  File "/usr/local/lib/python3.6/dist-packages/flair/models/sequence_tagger_model.py", line 608, in forward
    self.embeddings.embed(sentences)
  File "/usr/local/lib/python3.6/dist-packages/flair/embeddings/base.py", line 60, in embed
    self._add_embeddings_internal(sentences)
  File "/usr/local/lib/python3.6/dist-packages/flair/embeddings/token.py", line 610, in _add_embeddings_internal
    text_sentences, start_marker, end_marker, self.chars_per_chunk
  File "/usr/local/lib/python3.6/dist-packages/flair/models/language_model.py", line 157, in get_representation
    _, rnn_output, hidden = self.forward(batch, hidden)
  File "/usr/local/lib/python3.6/dist-packages/flair/models/language_model.py", line 80, in forward
    output, hidden = self.rnn(emb, hidden)
  File "/home/m.meloux/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
    hook_result = hook(self, input, result)
  File "/home/m.meloux/.local/lib/python3.6/site-packages/torch/quantization/quantize.py", line 74, in _observer_forward_hook
    return self.activation_post_process(output)
  File "/home/m.meloux/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/m.meloux/.local/lib/python3.6/site-packages/torch/quantization/observer.py", line 276, in forward
    x = x_orig.detach()  # avoid keeping autograd tape
AttributeError: 'tuple' object has no attribute 'detach'

Я не уверен, связана ли проблема с моим кодом или о чем нужно сообщить в систему отслеживания проблем PyTorch или Flair. Трассировка стека заставляет меня думать, что это взаимодействие между этими двумя библиотеками дает сбой, а не мой код, тем более что модуль квантования PyTorch все еще находится в стадии бета-тестирования, но я могу ошибаться. Любой ввод относительно того, что может быть ошибка, будет оценен.

0 ответов

Другие вопросы по тегам