Проблемы, с которыми сталкиваются при ручной реализации FIFOQueue в tenorflow
Я пытаюсь придумать метод, который мог бы реализовать FIFOQueue
в тензорном потоке. Таким образом, на каждой итерации цель состоит в том, чтобы назначить placeholder
определенное число, а затем сохранить его в Variable
по имени: буфер. После каждого назначения я увеличиваю индекс. Размер буфера равен [5], поэтому индекс должен находиться в диапазоне от 0 до 4. Наконец, после заполнения буфера я бы установил buffer[0:4]
быть buffer[1:5]
и затем добавьте новое значение в buffer[4]
, Так вот мой
код:
import tensorflow as tf
import numpy as np
import random
dim = 30
lst = []
for i in range(dim):
lst.append(random.randint(1, 10))
data = np.reshape(lst, [dim, 1])
print(lst)
# create a buffer:
buffer_input = tf.placeholder(tf.int32, shape=[1])
buffer = tf.Variable(tf.zeros([5], tf.int32))
index = tf.Variable(tf.constant(0))
def fillBufferBeforeFilled():
update_op1 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)
index_assign_add = tf.assign_add(index, 1)
return update_op1, index_assign_add
def fillBufferAfterFilled():
tmp = tf.slice(buffer, begin=[0], size=[4])
update_op2 = tf.scatter_update(buffer, indices=[0, 1, 2, 3], updates=tmp)
update_op3 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)
return update_op2, update_op3
cond = tf.cond(tf.equal(index, 4), lambda: fillBufferBeforeFilled(), lambda: fillBufferAfterFilled())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(dim):
cond_ = sess.run(cond, feed_dict={buffer_input: data[i]})
buf = sess.run(buffer, feed_dict={buffer_input: data[i]})
print('buf: ', buf)
Проблема: index
Переменная не увеличивается после каждого вызова, в то время как первый элемент buffer
присваивается значению, переданному заполнителю.
Я хотел бы знать, почему у меня такое поведение и каково решение этой проблемы.
Любая помощь высоко ценится!!
2 ответа
Вот решение:
import tensorflow as tf
import numpy as np
import random
dim = 30
lst = []
for i in range(dim):
lst.append(random.randint(1, 10))
data = np.reshape(lst, [dim, 1])
print(lst)
# create a buffer:
buffer_input = tf.placeholder(tf.int32, shape=[1])
buffer = tf.Variable(tf.zeros([5], tf.int32))
index = tf.Variable(-1, tf.int32)
def fillBufferBeforeFilled():
index_assign_add = tf.assign_add(index, 1)
with tf.control_dependencies([index_assign_add]):
update_op1 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)
return update_op1, index_assign_add
def fillBufferAfterFilled():
tmp = tf.slice(buffer, begin=[1], size=[4])
update_op2 = tf.scatter_update(buffer, indices=[0, 1, 2, 3], updates=tmp)
with tf.control_dependencies([update_op2]):
update_op3 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)
return update_op2, update_op3
cond = tf.cond(tf.equal(index, 4), lambda: fillBufferAfterFilled(), lambda: fillBufferBeforeFilled())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(dim):
cond_ = sess.run(cond, feed_dict={buffer_input: data[i]})
buf = sess.run(buffer, feed_dict={buffer_input: data[i]})
print('buf: ', buf)
Вы перепутали порядок условий в tf.cond
; так должно быть
cond = tf.cond(tf.equal(index, 4), lambda: fillBufferAfterFilled(), lambda: fillBufferBeforeFilled())
Я могу запустить ваш код, и он в основном работает, но обновления не совсем правильные; Я подозреваю, что вам нужно будет добавить немного tf.control_dependencies
призывает заставить вещи происходить в правильном порядке.