Как переключиться между обучением и проверкой набора данных с помощью tf.MonitoredTrainingSession?

Я хочу использовать feedable дизайн итератора в API-интерфейсе TensorFlow Dataset, поэтому я могу перейти к проверке данных после некоторых этапов обучения. Но если я переключился на данные проверки, это завершит весь сеанс.

Следующий код демонстрирует, что я хочу сделать:

import tensorflow as tf


graph = tf.Graph()
with graph.as_default():
    training_ds = tf.data.Dataset.range(32).batch(4)
    validation_ds = tf.data.Dataset.range(8).batch(4)

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_ds.output_types, training_ds.output_shapes)
    next_element = iterator.get_next()

    training_iterator = training_ds.make_initializable_iterator()
    validation_iterator = validation_ds.make_initializable_iterator()


with graph.as_default():

    with tf.train.MonitoredTrainingSession() as sess:
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        sess.run(training_iterator.initializer)
        count_training = 0
        while not sess.should_stop():
            x = sess.run(next_element, feed_dict={handle: training_handle})
            count_training += 1
            print('{} [training] {}'.format(count_training, x.shape))
            # print(x)

            # we do periodic validation
            if count_training % 4 == 0:
                sess.run(validation_iterator.initializer)
                count_validation = 0
                while not sess.should_stop():
                    y = sess.run(next_element, feed_dict={handle: validation_handle})
                    count_validation += 1
                    print('  {} [validation] {}'.format(count_validation, y.shape))
                    # print(y)

Учебные данные содержат 32 элемента, по 4 из них, поэтому по 8 пакетов мы проводим проверку каждые 4 шага, поэтому я ожидаю:

#  1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
#      1 [validation]
#      2 [validation]
# 5 [training]
# 6 [training]
# 7 [training]
# 8 [training]
#      1 [validation]
#      2 [validation]

но он останавливается, когда первая проверка сделана:

# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
#      1 [validation]
#      2 [validation]

Итак, как использовать это feedable итератор в tf.MonitoredTrainingSession?

1 ответ

Решение

Я бы предложил поймать tf.errors.OutOfRangeError поднят в конце набора данных проверки (вы также можете проверить раздел обработки нескольких эпох в официальном API для другого решения, используя repeat набор данных):

while not sess.should_stop():
    x = sess.run(next_element, feed_dict={handle: training_handle})
    count_training += 1
    print('{} [training] {}'.format(count_training, x.shape))

    # we do periodic validation
    if count_training % 4 == 0:
        sess.run(validation_iterator.initializer)
        count_validation = 0
        while True:
            try:
                y = sess.run(next_element, feed_dict={handle: validation_handle})
                count_validation += 1
                print('  {} [validation] {}'.format(count_validation, y.shape))
            except tf.errors.OutOfRangeError:
                break

Этот кусок кода печатает:

1 [training] (4,)  
2 [training] (4,)  
3 [training] (4,)  
4 [training] (4,)  
  1 [validation] (4,)  
  2 [validation] (4,)  
5 [training] (4,)
6 [training] (4,)
7 [training] (4,)
8 [training] (4,)
  1 [validation] (4,)
  2 [validation] (4,)
Другие вопросы по тегам