Получение выравнивания / внимания при переводе в OpenNMT-py
Кто-нибудь знает, как получить вес выравнивания при переводе в Opennmt-py? Обычно единственным выходом являются полученные предложения, и я пытался найти флаг отладки или аналогичный для весов внимания. До сих пор у меня не получилось.
1 ответ
Вы можете получить матрицы внимания. Обратите внимание, что это не то же самое, что выравнивание, которое является термином из статистического (не нейронного) машинного перевода.
На github есть ветка, обсуждающая это. Вот фрагмент из обсуждения. Когда вы получаете переводы из режима, внимание обращено на attn
поле.
import onmt
import onmt.io
import onmt.translate
import onmt.ModelConstructor
from collections import namedtuple
# Load the model.
Opt = namedtuple('Opt', ['model', 'data_type', 'reuse_copy_attn', "gpu"])
opt = Opt("PATH_TO_SAVED_MODEL", "text", False, 0)
fields, model, model_opt = onmt.ModelConstructor.load_test_model(
opt, {"reuse_copy_attn" : False})
# Test data
data = onmt.io.build_dataset(
fields, "text", "PATH_TO_DATA", None, use_filter_pred=False)
data_iter = onmt.io.OrderedIterator(
dataset=data, device=0,
batch_size=1, train=False, sort=False,
sort_within_batch=True, shuffle=False)
# Translator
translator = onmt.translate.Translator(
model, fields, beam_size=5, n_best=1,
global_scorer=None, cuda=True)
builder = onmt.translate.TranslationBuilder(
data, translator.fields, 1, False, None)
batch = next(data_iter)
batch_data = translator.translate_batch(batch, data)
translations = builder.from_batch(batch_data)
translations[0].attn # <--- here are the attentions
Я не уверен, что это новая функция, так как я не сталкивался с этим при поиске выравниваний несколько месяцев назад, но похоже, что onmt добавил флаг -report_align
для вывода выравнивания слов вместе с переводом.
https://opennmt.net/OpenNMT-py/FAQ.html#raw-alignments-from-averaging-transformer-attention-heads
Выдержка из opennnmt.net -
В настоящее время мы поддерживаем выравнивание слов при переводе моделей на основе Transformer. Использование -report_align при вызове translate.py выведет предполагаемые выравнивания в формате Pharaoh. Эти выравнивания вычисляются из argmax в среднем для заголовков внимания со второго до последнего уровня декодера.