python - чтение тензорного потока
Я следую официальному руководству TenorFlow Cifar, чтобы создать входной конвейер для моего набора данных изображений. После обучения модели я решил использовать для теста 50000 изображений. Однако после тестирования я обнаружил, что многие изображения тестируются несколько раз, в то время как некоторые изображения вообще не тестируются. Например, одно изображение с именем '1000_left' тестируется несколько раз, в то время как некоторые изображения, такие как '100_left', вообще не тестируются. Может ли кто-нибудь помочь мне определить, что происходит? Спасибо!
Вот как я загружаю набор данных tfrecord и сгенерированные пакеты изображений для теста:
def _generate_image_and_name_batch(image, name, min_queue_examples, batch_size):
num_preprocess_threads = 16
images, names = tf.train.batch([image, name], batch_size = batch_size, num_threads = num_preprocess_threads,
capacity = min_queue_examples + 3 * batch_size)
tf.summary.image('images', images)
return images, tf.reshape(names, [batch_size])
def read_test(test_dir, batch_size):
#constructs inputs for test
if not tf.gfile.Exists(test_dir):
raise ValueError("Failed to find file: " + test_dir)
#restore features of test record
features = {"test/image": tf.FixedLenFeature([], tf.string),
"test/name": tf.FixedLenFeature([], tf.string)}
with tf.name_scope("test_input"):
filename_queue = tf.train.string_input_producer(string_tensor = [test_dir])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features = features)
image = tf.decode_raw(features['test/image'], tf.float32)
image = tf.reshape(image, [224,224,3])
name = tf.cast(features['test/name'], tf.string)
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TEST * min_fraction_of_examples_in_queue)
return _generate_image_and_name_batch(image, name, min_queue_examples, batch_size)
Вот как я загружаю данные для теста:
#start queue runners
coord = tf.train.Coordinator()
try:
threads = []
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
threads.extend(qr.create_threads(sess, coord = coord, daemon=True, start= True))
num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
print("Num iterations for total:", num_iter)
step = 0
image_names = []
all_predictions = []
while step < num_iter and not coord.should_stop():
predictions = sess.run([top_1_op])[0]
img_name = sess.run(name)
all_predictions = np.concatenate([all_predictions, predictions])
image_names = np.concatenate([image_names, img_name])
step += 1
if step % 100 == 0 or step + 1 == num_iter:
print("Test step {} has finished".format(step))
except Exception as e:
coord.request_stop(e)
coord.request_stop()
coord.join(threads, stop_grace_period_secs= 10)
Я в основном следовал тем же шагам и коду, что и в примерах cifar. Пожалуйста помоги! Спасибо!