Tensorflow: правильная структура очередей / группирования с использованием набора обучения и проверки
Я пытаюсь повторить структуру, использованную в примере TensorBoard MNIST из недавнего саммита разработчиков 2017 года (код найден здесь). В нем feed_dict используются для чередования обучающих и проверочных наборов; однако они используют очень непрозрачный mnist.train.next_batch, что делает действительно трудным повторение ваших собственных действий.
По общему признанию, это может также быть, потому что я изо всех сил пытаюсь понять реализацию очереди в Tensorflow, и явных примеров, кажется, не хватает, особенно для TF > v1.0.
Я сделал свою собственную попытку классифицировать изображения CNN на основе различных примеров, на которые я наткнулся. Первоначально у меня была работа только с обучающими данными путем хранения данных в предварительно загруженных переменных (это небольшой набор данных). Я предположил, что было бы легче заставить поезд / действующий своп работать через подачу данных из имен файлов, поэтому я попытался изменить его на это.
Между изменением формата и попыткой реализовать структуру feed /dict train / valid я получаю следующее:
Ошибка: "Вы должны передать значение для тензора-заполнителя input/Placeholder_2 со строкой dtype".
Любые советы о том, как заставить это работать или дальнейшие объяснения о том, как на самом деле работают slicer/train.batch/QueueRunner, будут очень полезны, так как я обнаружил, что в учебнике Tensorflow отсутствует объяснение основного рабочего процесса между их.
У меня такое ощущение, что train.batch находится в совершенно неправильном месте, и что он, вероятно, должен быть в файле feed_dict def, но в противном случае я понятия не имею. Спасибо!
import tensorflow as tf
from tensorflow.python.framework import dtypes
# Input - 216x216x1 images; ~900 training images, ~350 validation
# Want to do batches of 5 for training, 20 for validation
learn_rate = .0001
drop_keep = 0.9
train_batch = 5
test_batch = 20
epochs = 1
iterations = int((885/train_batch) * epochs)
#
#
# A BUNCH OF (graph-building) HELPER DEFINITIONS EXCLUDED FOR BREVITY
#
#
#x_init will be fed a list of .jpg filenames (ex: [/file0.jpg, /file1.jpg, ...])
#y_init will be fed an array of one-hot classes (ex: [[0,1,0], [1,0,0], ...])
sess = tf.InteractiveSession()
with tf.name_scope('input'):
batch_size = tf.placeholder(tf.int32)
keep_prob = tf.placeholder(tf.float32)
x_init = tf.placeholder(dtype=tf.string, shape=(None))
y_init = tf.placeholder(dtype=np.int32, shape=(None,3)) #3 classes
image, label = tf.train.slice_input_producer([x_init, y_init])
file = tf.read_file(image)
image = tf.image.decode_jpeg(file, channels=1)
image = tf.cast(image, tf.float32)
image.set_shape([216,216,1])
label = tf.cast(label, tf.int32)
images, labels = tf.train.batch([image, label], batch_size=batch_size)
conv1 = conv_layer(images, [5,5,1], 40, 'conv1')
#
#
# skip the rest of graph defining/functions (merged,train_step)
# very similar to what is found in the MNIST example.
#
#
tf.summary.scalar('accuracy', accuracy)
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(OUTPUT_LOC + '/train',sess.graph)
test_writer = tf.summary.FileWriter(OUTPUT_LOC + '/test')
sess.run(tf.global_variables_initializer())
#xTrain, yTrain, xTest, yTest are the train/valid images/labels lists
def feed_dict(train=True):
if train:
batch = train_batch
keep = drop_keep
xval = xTrain
yval = yTrain
else:
batch = test_batch
keep = 1
xval = xTest
yval = yTest
return({x_init:xval, y_init:yval, batch_size:batch, keep_prob:keep})
#If I run "threads", I get the error. It works up until here.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
#Don't know what works here or what doesn't.
for i in range(iterations):
if i % 10 == 0:
summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
test_writer.add_summary(summary, i)
print('Accuracy at step %s: %s' % (i, acc))
else:
if i % 100 == 99:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True), options=run_options, run_metadata=run_metadata)
train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
train_writer.add_summary(summary, i)
print('Adding run metadata for', i)
else: # Record a summary
summary, _ = sess.run([merged, train_step],feed_dict=feed_dict(True))
train_writer.add_summary(summary, i)
coord.request_stop()
train_writer.close()
test_writer.close()
sess.close()