Как использовать питаемый итератор из Tensorflow Dataset API вместе с MonitoredTrainingSession?
Руководство программиста Tensorflow рекомендует использовать подающий итератор для переключения между набором данных обучения и проверки без повторной инициализации итератора. Это главным образом требует подачи ручки, чтобы выбрать между ними.
Как использовать его вместе с tf.train.MonitoredTrainingSession?
Следующий метод завершается с ошибкой "RuntimeError: График завершен и не может быть изменен". ошибка.
with tf.train.MonitoredTrainingSession() as sess:
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
Как добиться одновременно удобства MonitoredTrainingSession и итеративных наборов данных для обучения и проверки?
3 ответа
Я получил ответ от проблемы Tensorflow GitHub - https://github.com/tensorflow/tensorflow/issues/12859
Решение состоит в том, чтобы вызвать iterator.string_handle()
перед созданием MonitoredSession
,
import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator
dataset_train = Dataset.range(10)
dataset_val = Dataset.range(90, 100)
iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()
handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(
handle, dataset_train.output_types, dataset_train.output_shapes)
next_batch = iterator.get_next()
with tf.train.MonitoredTrainingSession() as sess:
handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])
for step in range(10):
print('train', sess.run(next_batch, feed_dict={handle: handle_train}))
if step % 3 == 0:
print('val', sess.run(next_batch, feed_dict={handle: handle_val}))
Output:
('train', 0)
('val', 90)
('train', 1)
('train', 2)
('val', 91)
('train', 3)
@ Майкл Джейсон G ответ правильный. Однако это не работает, когда вы также хотите использовать определенные session_run_hooks, которые должны оценивать части графика, такие как, например, LoggingTensorHook или SummarySaverHook. Пример ниже приведет к ошибке:
import tensorflow as tf
dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)
iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, dataset_train.output_types, dataset_train.output_shapes)
feature = iterator.get_next()
pred = feature * feature
tf.summary.scalar('pred', pred)
global_step = tf.train.create_global_step()
summary_hook = tf.train.SummarySaverHook(save_steps=5,
output_dir="summaries", summary_op=tf.summary.merge_all())
with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess:
handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])
for step in range(10):
feat = sess.run(feature, feed_dict={handle: handle_train})
pred_ = sess.run(pred, feed_dict={handle: handle_train})
print('train: ', feat)
print('pred: ', pred_)
if step % 3 == 0:
print('val', sess.run(feature, feed_dict={handle: handle_val}))
Это не удастся с ошибкой:
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype string
[[Node: Placeholder = Placeholder[dtype=DT_STRING, shape=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
[[Node: cond/Switch_1/_15 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_18_cond/Switch_1", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
Причина в том, что ловушка будет пытаться вычислить график уже после первого session.run([iter_train_handle, iter_val_handle]), который, очевидно, еще не содержит дескриптор в feed_dict.
Обходное решение состоит в том, чтобы перезаписать перехватчики, вызывающие проблему, и изменить код в before_run и after_run для оценки только по вызовам session.run, содержащим дескриптор в feed_dict (вы можете получить доступ к feed_dict текущего вызова session.run через run_context аргумент before_run и after_run)
Или вы можете использовать последний мастер Tensorflow (post-1.4), который добавляет в MonitoredSession функцию run_step_fn, которая позволяет вам указать следующий шаг step_fn, который позволит избежать ошибки (за счет оценки количества выражений оператора if TrainingIstruction несколько раз...)
def step_fn(step_context):
if handle_train is None:
handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])
return step_context.run_with_hooks(fetches=..., feed_dict=...)
Существует демо для использования заполнителя в mot_session с SessionRunHook. Эта демонстрация о переключении наборов данных путем подачи diff handle_string.
Кстати, я перепробовал все решения, но только это работает.