Tensorflow, лучший способ сохранить состояние в RNNs?

В настоящее время у меня есть следующий код для серии RNN с цепочками в тензорном потоке. Я не использую MultiRNN, так как позже я должен был что-то сделать с выходом каждого слоя.

 for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        state = [tf.zeros((BATCH_SIZE, sz)) for sz in rnn_func.state_size]
        time_outputs = [None] * TIME_STEPS

        for t in range(TIME_STEPS):
            rnn_input = getTimeStep(rnn_outputs[r - 1], t)
            time_outputs[t], state = rnn_func(rnn_input, state)
            time_outputs[t] = tf.reshape(time_outputs[t], (-1, 1, RNN_SIZE))
            scope.reuse_variables()
        rnn_outputs[r] = tf.concat(1, time_outputs)

В настоящее время у меня есть фиксированное количество временных шагов. Однако я хотел бы изменить его, чтобы он имел только один временной шаг, но помнил состояние между партиями. Поэтому мне нужно создать переменную состояния для каждого слоя и назначить ей конечное состояние каждого из слоев. Что-то вроде этого.

for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        saved_state = tf.get_variable('saved_state', ...)
        rnn_outputs[r], state = rnn_func(rnn_outputs[r - 1], saved_state)
        saved_state = tf.assign(saved_state, state)

Затем для каждого из уровней мне нужно будет оценить сохраненное состояние в моей функции sess.run, а также вызвать функцию обучения. Мне нужно сделать это для каждого слоя рН. Это похоже на неприятности. Мне нужно будет отслеживать каждое сохраненное состояние и оценивать его во время выполнения. Кроме того, после запуска потребуется скопировать состояние из моего графического процессора в память хоста, что будет неэффективно и не нужно. Есть ли лучший способ сделать это?

3 ответа

Решение

Вот код для обновления начального состояния LSTM, когда state_is_tuple=True путем определения переменных состояния. Он также поддерживает несколько слоев.

Мы определили две функции - одну для получения переменных состояния с начальным нулевым состоянием и одну функцию для возврата операции, которую мы можем передать session.run чтобы обновить переменные состояния последним скрытым состоянием LSTM.

def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    return tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    return tf.tuple(update_ops)

Мы можем использовать это для обновления состояния LSTM после каждой партии. Обратите внимание, что я использую tf.nn.dynamic_rnn для раскрутки:

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell_layer = tf.contrib.rnn.GRUCell(256)
cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers)

# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)

# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)

# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})

Основное отличие этого ответа заключается в том, что state_is_tuple=True делает состояние LSTM LSTMStateTuple, содержащим две переменные (состояние ячейки и скрытое состояние) вместо одной переменной. Использование нескольких слоев делает состояние LSTM кортежем LSTMStateTuples - по одному на уровень.

Сброс на ноль

При использовании обученной модели для прогнозирования / декодирования вы можете сбросить состояние до нуля. Затем вы можете использовать эту функцию:

def get_state_reset_op(state_variables, cell, batch_size):
    # Return an operation to set each variable in a list of LSTMStateTuples to zero
    zero_states = cell.zero_state(batch_size, tf.float32)
    return get_state_update_op(state_variables, zero_states)

Например, как указано выше:

reset_state_op = get_state_reset_op(state, cell, max_batch_size)
# Reset the state to zero before feeding input
sess.run([reset_state_op])
sess.run([outputs, update_op], {data: ...})

Теперь я сохраняю состояния RNN с помощью tf.control_dependencies. Вот пример.

 saved_states = [tf.get_variable('saved_state_%d' % i, shape = (BATCH_SIZE, sz), trainable = False, initializer = tf.constant_initializer()) for i, sz in enumerate(rnn.state_size)]
        W = tf.get_variable('W', shape = (2 * RNN_SIZE, RNN_SIZE), initializer = tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE)))
        b = tf.get_variable('b', shape = (RNN_SIZE,), initializer = tf.constant_initializer())

        rnn_output, states = rnn(last_output, saved_states)
        with tf.control_dependencies([tf.assign(a, b) for a, b in zip(saved_states, states)]):
            dense_input = tf.concat(1, (last_output, rnn_output))

        dense_output = tf.tanh(tf.matmul(dense_input, W) + b)
        last_output = dense_output + last_output

Я просто проверяю, что часть моего графика зависит от сохранения состояния.

Эти две ссылки также связаны и полезны для этого вопроса:

https://github.com/tensorflow/tensorflow/issues/2695 https://github.com/tensorflow/tensorflow/issues/2838

Другие вопросы по тегам