Как декодировать вывод seq2seq?

Код из примера Tensorflow translate.py меня смутил. Скопированный код:

  # This is a greedy decoder - outputs are just argmaxes of output_logits.
  outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]

Почему argmax Работа?

output_logits форма [bucket_length,batch_size,embedding_size]

1 ответ

Решение

Для каждого логита (или: активации для каждого слова) они берут индекс, где активация имеет самое высокое значение из всех.

Для argmax: взгляните на примеры на этой странице: https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html

a = array([[0, 1, 2],
       [3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])

Итак, что делает вывод:

  • Для каждого слова (длина bucket_length)
    • получить максимальную активацию embedding_size

Вы должны посмотреть на форму полученного массива выходных данных. Вы увидите, что, поскольку batch_size равен 1, все работает!

Дайте мне знать, если это поможет вам!

Другие вопросы по тегам