Tensorflow, лучший способ сохранить состояние в RNNs?
В настоящее время у меня есть следующий код для серии RNN с цепочками в тензорном потоке. Я не использую MultiRNN, так как позже я должен был что-то сделать с выходом каждого слоя.
for r in range(RNNS):
with tf.variable_scope('recurent_%d' % r) as scope:
state = [tf.zeros((BATCH_SIZE, sz)) for sz in rnn_func.state_size]
time_outputs = [None] * TIME_STEPS
for t in range(TIME_STEPS):
rnn_input = getTimeStep(rnn_outputs[r - 1], t)
time_outputs[t], state = rnn_func(rnn_input, state)
time_outputs[t] = tf.reshape(time_outputs[t], (-1, 1, RNN_SIZE))
scope.reuse_variables()
rnn_outputs[r] = tf.concat(1, time_outputs)
В настоящее время у меня есть фиксированное количество временных шагов. Однако я хотел бы изменить его, чтобы он имел только один временной шаг, но помнил состояние между партиями. Поэтому мне нужно создать переменную состояния для каждого слоя и назначить ей конечное состояние каждого из слоев. Что-то вроде этого.
for r in range(RNNS):
with tf.variable_scope('recurent_%d' % r) as scope:
saved_state = tf.get_variable('saved_state', ...)
rnn_outputs[r], state = rnn_func(rnn_outputs[r - 1], saved_state)
saved_state = tf.assign(saved_state, state)
Затем для каждого из уровней мне нужно будет оценить сохраненное состояние в моей функции sess.run, а также вызвать функцию обучения. Мне нужно сделать это для каждого слоя рН. Это похоже на неприятности. Мне нужно будет отслеживать каждое сохраненное состояние и оценивать его во время выполнения. Кроме того, после запуска потребуется скопировать состояние из моего графического процессора в память хоста, что будет неэффективно и не нужно. Есть ли лучший способ сделать это?
3 ответа
Вот код для обновления начального состояния LSTM, когда state_is_tuple=True
путем определения переменных состояния. Он также поддерживает несколько слоев.
Мы определили две функции - одну для получения переменных состояния с начальным нулевым состоянием и одну функцию для возврата операции, которую мы можем передать session.run
чтобы обновить переменные состояния последним скрытым состоянием LSTM.
def get_state_variables(batch_size, cell):
# For each layer, get the initial state and make a variable out of it
# to enable updating its value.
state_variables = []
for state_c, state_h in cell.zero_state(batch_size, tf.float32):
state_variables.append(tf.contrib.rnn.LSTMStateTuple(
tf.Variable(state_c, trainable=False),
tf.Variable(state_h, trainable=False)))
# Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
return tuple(state_variables)
def get_state_update_op(state_variables, new_states):
# Add an operation to update the train states with the last state tensors
update_ops = []
for state_variable, new_state in zip(state_variables, new_states):
# Assign the new state to the state variables on this layer
update_ops.extend([state_variable[0].assign(new_state[0]),
state_variable[1].assign(new_state[1])])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
return tf.tuple(update_ops)
Мы можем использовать это для обновления состояния LSTM после каждой партии. Обратите внимание, что я использую tf.nn.dynamic_rnn
для раскрутки:
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell_layer = tf.contrib.rnn.GRUCell(256)
cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers)
# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)
# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)
# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})
Основное отличие этого ответа заключается в том, что state_is_tuple=True
делает состояние LSTM LSTMStateTuple, содержащим две переменные (состояние ячейки и скрытое состояние) вместо одной переменной. Использование нескольких слоев делает состояние LSTM кортежем LSTMStateTuples - по одному на уровень.
Сброс на ноль
При использовании обученной модели для прогнозирования / декодирования вы можете сбросить состояние до нуля. Затем вы можете использовать эту функцию:
def get_state_reset_op(state_variables, cell, batch_size):
# Return an operation to set each variable in a list of LSTMStateTuples to zero
zero_states = cell.zero_state(batch_size, tf.float32)
return get_state_update_op(state_variables, zero_states)
Например, как указано выше:
reset_state_op = get_state_reset_op(state, cell, max_batch_size)
# Reset the state to zero before feeding input
sess.run([reset_state_op])
sess.run([outputs, update_op], {data: ...})
Теперь я сохраняю состояния RNN с помощью tf.control_dependencies. Вот пример.
saved_states = [tf.get_variable('saved_state_%d' % i, shape = (BATCH_SIZE, sz), trainable = False, initializer = tf.constant_initializer()) for i, sz in enumerate(rnn.state_size)]
W = tf.get_variable('W', shape = (2 * RNN_SIZE, RNN_SIZE), initializer = tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE)))
b = tf.get_variable('b', shape = (RNN_SIZE,), initializer = tf.constant_initializer())
rnn_output, states = rnn(last_output, saved_states)
with tf.control_dependencies([tf.assign(a, b) for a, b in zip(saved_states, states)]):
dense_input = tf.concat(1, (last_output, rnn_output))
dense_output = tf.tanh(tf.matmul(dense_input, W) + b)
last_output = dense_output + last_output
Я просто проверяю, что часть моего графика зависит от сохранения состояния.
Эти две ссылки также связаны и полезны для этого вопроса:
https://github.com/tensorflow/tensorflow/issues/2695 https://github.com/tensorflow/tensorflow/issues/2838