Чтение TFRecords с помощью API tf.data.Dataset увеличивает время вычислений

Мои данные в tfrecords файл. Этот простой код повторяет и объединяет изображения с tf.data.Dataset апи. Тем не менее, время вычислений на 100 пакетов увеличивается. Почему это так и как это исправить?

import tensorflow as tf
import time
sess = tf.Session()
dataset = tf.data.TFRecordDataset('/tmp/data/train.tfrecords')
dataset = dataset.repeat()
dataset = dataset.batch(3)
iterator = dataset.make_one_shot_iterator()

prev_step = time.time()
for step in range(10000):
    tensors = iterator.get_next()
    fetches = sess.run(tensors)
    if step % 200 == 0:
        print("Step %6i time since last %7.5f" % (step, time.time() - prev_step))
        prev_step = time.time()

Это выводит следующее время:

Step      0 time since last 0.01432
Step    200 time since last 1.85303
Step    400 time since last 2.15448
Step    600 time since last 2.65473
Step    800 time since last 3.15646
Step   1000 time since last 3.72434
Step   1200 time since last 4.34447
Step   1400 time since last 5.11210
Step   1600 time since last 5.87102
Step   1800 time since last 6.61459
Step   2000 time since last 7.57238
Step   2200 time since last 8.33060
Step   2400 time since last 9.37795      

Файл tfrecords содержит изображения MNIST, написанные с помощью этого HowTo из документа Tensorflow.

Чтобы сузить масштабы проблемы, я воспроизвел код для чтения необработанных изображений с диска. В этом случае время на 200 партий остается постоянным, как и ожидалось.

Теперь мой вопрос:

  • Какая часть кода увеличивает время вычислений?
  • Должен ли я подать это как ошибку в GitHub Tensorflow?

Решено!

Ответ на мой вопрос: переезд get_next() вне петли

1 ответ

Решение

Решено: Переместить get_next() вне петли

Другие вопросы по тегам