Невозможно понять tf.nn.raw_rnn

В официальной документации tf.nn.raw_rnn у нас есть структура излучения в качестве третьего выхода loop_fn когда loop_fn запускается впервые.

Позже emit_structure используется для копирования tf.zeros_like(emit_structure) в записи мини-пакетов, которые завершены emit = tf.where(finished, tf.zeros_like(emit_structure), emit),

мое непонимание или паршивая документация со стороны Google: структура emit None так tf.where(finished, tf.zeros_like(emit_structure), emit) собирается бросить ValueError как tf.zeros_like(None) делает так. Может кто-нибудь, пожалуйста, заполните то, что мне здесь не хватает?

1 ответ

Решение

Да, документ довольно запутанный в этом месте. Если вы посмотрите на внутренности tf.nn.raw_rnnключевой термин "в псевдокоде", поэтому пример в документе не является точным.

Точный исходный код выглядит следующим образом (может отличаться в зависимости от версии тензорного потока):

if emit_structure is not None:
  flat_emit_structure = nest.flatten(emit_structure)
  flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                    array_ops.shape(emit) for emit in flat_emit_structure]
  flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
  emit_structure = cell.output_size
  flat_emit_size = nest.flatten(emit_structure)
  flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

Так что это обрабатывает случай, когда emit_structure is None и просто берет значение cell.output_size, Вот почему ничего не сломалось.

Другие вопросы по тегам