Как сделать для каждого набора элементов состояния lstm с помощью упаковщика внимания (собирать TensorArray)

Поскольку в AttentionWrapper.py из contrib/seq2seq, AttentionWrapperState имеет историю выравнивания, которая имеет тип TensorArray. Я хочу собирать при выполнении поиска луча, чтобы изменить порядок состояний, в настоящее время я не могу сделать это, когда у него есть TensorArray..

#TODO not support tf.TensorArray right now, can not use alignment_history in attention_wrapper
def try_gather(x, indices):
  #if isinstance(x, tf.Tensor) and x.shape.ndims >= 2:
  assert isinstance(x, tf.Tensor)
  if x.shape.ndims >= 2:
    return tf.gather(x, indices)
  else:
    return x

#must reoder state! including attention
state = nest.map_structure(lambda x: try_gather(x, flattened_past_indices), state)

0 ответов

Другие вопросы по тегам