как получить информацию о выравнивании или внимании для переводов, выполненных моделью втулки резака?

Torch Hub предоставляет предварительно обученные модели, например: https://pytorch.org/hub/pytorch_fairseq_translation/

Эти модели можно использовать в Python или интерактивно с помощью интерфейса командной строки. С помощью CLI можно получить согласование, с--print-alignmentфлаг. Следующий код работает в терминале после установки fairseq (и pytorch)

curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
MODEL_DIR=wmt14.en-fr.fconv-py
fairseq-interactive \
    --path $MODEL_DIR/model.pt $MODEL_DIR \
    --beam 5 --source-lang en --target-lang fr \
    --tokenizer moses \
    --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes \ 
    --print-alignment

В python можно указать ключевое слово args verbose а также print_alignment:

import torch

en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt')

fr = en2fr.translate('Hello world!', beam=5, verbose=True, print_alignment=True)

Однако это будет выводить выравнивание только в виде сообщения журнала. А для fairseq 0.9 он, кажется, не работает, что приводит к сообщению об ошибке (проблеме).

Есть ли способ получить доступ к информации о выравнивании (или, возможно, даже к матрице полного внимания) из кода Python?

1 ответ

Я просмотрел кодовую базу fairseq и нашел хитрый способ вывода информации о выравнивании. Поскольку это требует редактирования самого исходного кода fairseq, я не думаю, что это приемлемое решение. Но, возможно, это кому-то поможет (мне все еще очень интересно ответить, как это сделать правильно).

Отредактируйте функцию sample() и перепишите оператор возврата. Вот вся функция (чтобы помочь вам найти ее лучше в коде), но следует изменить только последнюю строку:

def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
    if isinstance(sentences, str):
        return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
    tokenized_sentences = [self.encode(sentence) for sentence in sentences]
    batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
    return list(zip([self.decode(hypos[0]['tokens']) for hypos in batched_hypos], [hypos[0]['alignment'] for hypos in batched_hypos]))
Другие вопросы по тегам