как получить информацию о выравнивании или внимании для переводов, выполненных моделью втулки резака?
Torch Hub предоставляет предварительно обученные модели, например: https://pytorch.org/hub/pytorch_fairseq_translation/
Эти модели можно использовать в Python или интерактивно с помощью интерфейса командной строки. С помощью CLI можно получить согласование, с--print-alignment
флаг. Следующий код работает в терминале после установки fairseq (и pytorch)
curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
MODEL_DIR=wmt14.en-fr.fconv-py
fairseq-interactive \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5 --source-lang en --target-lang fr \
--tokenizer moses \
--bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes \
--print-alignment
В python можно указать ключевое слово args verbose
а также print_alignment
:
import torch
en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt')
fr = en2fr.translate('Hello world!', beam=5, verbose=True, print_alignment=True)
Однако это будет выводить выравнивание только в виде сообщения журнала. А для fairseq 0.9 он, кажется, не работает, что приводит к сообщению об ошибке (проблеме).
Есть ли способ получить доступ к информации о выравнивании (или, возможно, даже к матрице полного внимания) из кода Python?
1 ответ
Я просмотрел кодовую базу fairseq и нашел хитрый способ вывода информации о выравнивании. Поскольку это требует редактирования самого исходного кода fairseq, я не думаю, что это приемлемое решение. Но, возможно, это кому-то поможет (мне все еще очень интересно ответить, как это сделать правильно).
Отредактируйте функцию sample() и перепишите оператор возврата. Вот вся функция (чтобы помочь вам найти ее лучше в коде), но следует изменить только последнюю строку:
def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
if isinstance(sentences, str):
return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
tokenized_sentences = [self.encode(sentence) for sentence in sentences]
batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
return list(zip([self.decode(hypos[0]['tokens']) for hypos in batched_hypos], [hypos[0]['alignment'] for hypos in batched_hypos]))