Невозможно понять tf.nn.raw_rnn
В официальной документации tf.nn.raw_rnn
у нас есть структура излучения в качестве третьего выхода loop_fn
когда loop_fn
запускается впервые.
Позже emit_structure используется для копирования tf.zeros_like(emit_structure)
в записи мини-пакетов, которые завершены emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
,
мое непонимание или паршивая документация со стороны Google: структура emit None
так tf.where(finished, tf.zeros_like(emit_structure), emit)
собирается бросить ValueError как tf.zeros_like(None)
делает так. Может кто-нибудь, пожалуйста, заполните то, что мне здесь не хватает?
1 ответ
Да, документ довольно запутанный в этом месте. Если вы посмотрите на внутренности tf.nn.raw_rnn
ключевой термин "в псевдокоде", поэтому пример в документе не является точным.
Точный исходный код выглядит следующим образом (может отличаться в зависимости от версии тензорного потока):
if emit_structure is not None:
flat_emit_structure = nest.flatten(emit_structure)
flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
array_ops.shape(emit) for emit in flat_emit_structure]
flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
emit_structure = cell.output_size
flat_emit_size = nest.flatten(emit_structure)
flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)
Так что это обрабатывает случай, когда emit_structure is None
и просто берет значение cell.output_size
, Вот почему ничего не сломалось.