Как я могу получить доступ к весам повторяющейся ячейки в Tensorflow?
Одним из способов повышения стабильности в задачах глубокого Q-обучения является поддержание набора целевых весов для сети, которые обновляются медленно и используются для расчета целевых значений Q-значения. В результате в разное время в процедуре обучения в прямом проходе используются два разных набора весов. Для обычного DQN это не сложно реализовать, так как весовые коэффициенты являются переменными тензорного потока, которые могут быть установлены в feed_dict, т.е.
sess = tf.Session()
input = tf.placeholder(tf.float32, shape=[None, 5])
weights = tf.Variable(tf.random_normal(shape=[5,4], stddev=0.1)
bias = tf.Variable(tf.constant(0.1, shape=[4])
output = tf.matmul(input, weights) + bias
target = tf.placeholder(tf.float32, [None, 4])
loss = ...
...
#Here we explicitly set weights to be the slowly updated target weights
sess.run(output, feed_dict={input: states, weights: target_weights, bias: target_bias})
# Targets for the learning procedure are computed using this output.
....
#Now we run the learning procedure, using the most up to date weights,
#as well as the previously computed targets
sess.run(loss, feed_dict={input: states, target: targets})
Я хотел бы использовать эту целевую сетевую технику в рекуррентной версии DQN, но я не знаю, как получить доступ и установить веса, используемые внутри рекуррентной ячейки. В частности, я использую tf.nn.rnn_cell.BasicLSTMCell, но я хотел бы знать, как это сделать для любого типа рекуррентной ячейки.
2 ответа
BasicLSTMCell не предоставляет свои переменные как часть своего открытого API. Я рекомендую вам либо посмотреть, какие имена имеют эти переменные на вашем графике, и подать эти имена (эти имена вряд ли изменятся, поскольку они находятся в контрольных точках, и изменение этих имен нарушит совместимость контрольных точек).
В качестве альтернативы, вы можете сделать копию BasicLSTMCell, которая выставляет переменные. Это самый чистый подход, я думаю.
Вы можете использовать строку ниже, чтобы получить переменные на графике
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
Затем вы можете проверить эти переменные, чтобы увидеть, как они меняются