Двунаправленный GRU TensorFlow возвращает ValueError из-за <предположительно> неправильной формы

Я реализую двунаправленную сеть маркировки GRU (1 уровень вперед, 1 уровень назад), используя версию 0.9 TensorFlow. После инициализации модели TensorFlow инициализирует все переменные, создает ячейки GRU и корректно применяет все регулярные преобразования, пока не наступит время для запуска tf.nn.bidirectional_rnn функция, в которой выдается ошибка ValueError, связанная с неправильной операцией слияния Tensor. Вот код:

# Create the cells
with tf.variable_scope('forward'):
    self.char_gru_cell_fw = tf.nn.rnn_cell.GRUCell(char_hidden_size)
with tf.variable_scope('backward'):
    self.char_gru_cell_bw = tf.nn.rnn_cell.GRUCell(char_hidden_size)

# Set initial state of the cells to be zero
self._char_initial_state_fw = \
    self.char_gru_cell_fw.zero_state(batch_size, tf.float32)
self._char_initial_state_bw = \
    self.char_gru_cell_bw.zero_state(batch_size, tf.float32)

#         Size before: batch-chrpad-chrvocabsize
#          Size after: batch-chrvocabsize
chargruinput = [tf.squeeze(input_, [1]) \
    for input_ in tf.split(1, char_num_steps, chargruinput)]

# Run the bidirectional rnn and get the corner results
_, output_state_fw, output_state_bw = \
   tf.nn.bidirectional_rnn(self.char_gru_cell_fw, 
                    self.char_gru_cell_bw, 
                    chargruinput, 
                    sequence_length=char_num_steps,
                    initial_state_fw=self._char_initial_state_fw,
                    initial_state_bw=self._char_initial_state_bw)

Когда я запускаю это, я получаю следующую ошибку:

Traceback (most recent call last):
  File "frontbackgru.py", line 409, in <module>
    main()
  File "frontbackgru.py", line 226, in main
    config=my_config)
  File "/home/xG/Code/4-RNN/1-simple-cnn-input-classifier/gru_model.py", line 265, in __init__
    initial_state_bw=self._char_initial_state_bw)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 453, in bidirectional_rnn
    sequence_length, scope=fw_scope)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 156, in rnn
    state_size=cell.state_size)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 343, in _rnn_step
    _maybe_copy_some_through)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1331, in cond
    _, res_f = context_f.BuildCondBranch(fn2)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1230, in BuildCondBranch
    r = fn()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 317, in _maybe_copy_some_through
    lambda: _copy_some_through(new_output, new_state))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1331, in cond
    _, res_f = context_f.BuildCondBranch(fn2)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1230, in BuildCondBranch
    r = fn()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 317, in <lambda>
    lambda: _copy_some_through(new_output, new_state))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 298, in _copy_some_through
    return ([math_ops.select(copy_cond, zero_output, new_output)]
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_math_ops.py", line 1769, in select
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 704, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2262, in create_op
    set_shapes_for_outputs(ret)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1702, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/math_ops.py", line 1578, in _SelectShape
    t_e_shape = t_e_shape.merge_with(c_shape)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_shape.py", line 570, in merge_with
    (self, other))
ValueError: Shapes (32, 50) and () are not compatible

Теперь входы в bidirectional_rnn функции являются:

self.char_gru_cell_fw: Это экземпляр GRUCell, инициализированный с целочисленным значением char_hidden_size50 в этом случае

self.char_gru_cell_bw: Это экземпляр GRUCell, инициализированный с целочисленным значением char_hidden_size50 в этом случае

chargruinput: Это список длиной 30, содержащий тензоры формы [batch_size,charvocab], что в данном случае составляет [32,256]

sequence_length: целое число, представляющее количество развернутых ячеек, char_num_steps, который составляет 30 в этом случае.

initial_state_fw: заполненный нулями тензор той же формы, что и состояние ГРУ, [32,50] в этом случае

initial_state_bw: заполненный нулями тензор той же формы, что и состояние ГРУ, [32,50] в этом случае

Я попытался просмотреть модули, которые привели к возникновению исключения ValueError, но там происходит много вещей низкого уровня, которые, скорее всего, работают нормально, видя, как работает CNN, над которой я работал на прошлой неделе, без проблем. Это заставляет меня думать, что до низкоуровневых методов в rnn или же rnn_cell Библиотека, которую я не использовал раньше.

Это также кажется странным, поскольку ошибка связана с пустой формой (я предполагаю, что она связана со скаляром вместо тензора), но единственное, что я могу изменить, это скаляр в bidirectional_rnn аргументы функции является sequence_length аргумент. Я попытался опустить его и использовать только начальные состояния, и наоборот, но появляется та же ошибка.

У кого-нибудь была похожая проблема? Вся моя система повреждена этим, хотелось бы получить обратную связь. заранее спасибо

1 ответ

Разобрался что не так - аргумент sequence_length на самом деле должен быть список целых чисел длины batch_size для каждой партии, а не целое число. Легко исправить, спасибо за игру:)

Другие вопросы по тегам