Как я могу получить доступ к весам повторяющейся ячейки в Tensorflow?

Одним из способов повышения стабильности в задачах глубокого Q-обучения является поддержание набора целевых весов для сети, которые обновляются медленно и используются для расчета целевых значений Q-значения. В результате в разное время в процедуре обучения в прямом проходе используются два разных набора весов. Для обычного DQN это не сложно реализовать, так как весовые коэффициенты являются переменными тензорного потока, которые могут быть установлены в feed_dict, т.е.

sess = tf.Session()
input = tf.placeholder(tf.float32, shape=[None, 5])
weights = tf.Variable(tf.random_normal(shape=[5,4], stddev=0.1)
bias = tf.Variable(tf.constant(0.1, shape=[4])
output = tf.matmul(input, weights) + bias
target = tf.placeholder(tf.float32, [None, 4])
loss = ...

...

#Here we explicitly set weights to be the slowly updated target weights
sess.run(output, feed_dict={input: states, weights: target_weights, bias: target_bias})

# Targets for the learning procedure are computed using this output.

....

#Now we run the learning procedure, using the most up to date weights,
#as well as the previously computed targets
sess.run(loss, feed_dict={input: states, target: targets})

Я хотел бы использовать эту целевую сетевую технику в рекуррентной версии DQN, но я не знаю, как получить доступ и установить веса, используемые внутри рекуррентной ячейки. В частности, я использую tf.nn.rnn_cell.BasicLSTMCell, но я хотел бы знать, как это сделать для любого типа рекуррентной ячейки.

2 ответа

Решение

BasicLSTMCell не предоставляет свои переменные как часть своего открытого API. Я рекомендую вам либо посмотреть, какие имена имеют эти переменные на вашем графике, и подать эти имена (эти имена вряд ли изменятся, поскольку они находятся в контрольных точках, и изменение этих имен нарушит совместимость контрольных точек).

В качестве альтернативы, вы можете сделать копию BasicLSTMCell, которая выставляет переменные. Это самый чистый подход, я думаю.

Вы можете использовать строку ниже, чтобы получить переменные на графике

variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

Затем вы можете проверить эти переменные, чтобы увидеть, как они меняются

Другие вопросы по тегам