Как сбросить состояние ГРУ в тензорном потоке после каждой эпохи

Я использую ячейку tenorflow GRU для реализации RNN. Я использую вышеупомянутое с видео, которое длится максимум 5 минут. Поэтому, поскольку следующее состояние автоматически подается в GRU, как я могу вручную сбросить состояние RNN после каждой эпохи. Другими словами, я хочу, чтобы начальное состояние в начале обучения всегда было 0. Вот фрагмент моего кода:

with tf.variable_scope('GRU'):
    latent_var = tf.reshape(latent_var, shape=[batch_size, time_steps, latent_dim])

    cell = tf.nn.rnn_cell.GRUCell(cell_size)   
    H, C = tf.nn.dynamic_rnn(cell, latent_var, dtype=tf.float32)  
    H = tf.reshape(H, [batch_size, cell_size]) 
....

Любая помощь высоко ценится!

1 ответ

Решение

Использование initial_state аргумент tf.nn.dynamic_rnn:

initial_state: (необязательно) Начальное состояние для RNN. Если cell.state_size является целым числом, это должен быть Тензор соответствующего типа и формы [batch_size, cell.state_size], Если cell.state_siz е - это кортеж, это должен быть кортеж из тензоров, имеющий форму [batch_size, s] for s in cell.state_size,

Адаптированный пример из документации:

# create a GRUCell
cell = tf.nn.rnn_cell.GRUCell(cell_size)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

# defining initial state
initial_state = cell.zero_state(batch_size, dtype=tf.float32)

# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)

Также обратите внимание, что несмотря на initial_state не будучи местозаполнителем, вы также можете передать ему значение. Поэтому, если вы хотите сохранить состояние в пределах эпохи, но начать с нуля в начале эпохи, вы можете сделать это следующим образом:

# Compute the zero state array of the right shape once
zero_state = sess.run(initial_state)

# Start with a zero vector and update it 
cur_state = zero_state
for batch in get_batches():
  cur_state, _ = sess.run([state, ...], feed_dict={initial_state=cur_state, ...})