Ошибка tf.Estimator.predict() при использовании модуля Tensorflow Hub в качестве основы пользовательского tf.Estimator
Я пытаюсь создать пользовательский тензор потока tf.Estimator. В model_fn, переданном в tf.Estimator, я импортирую модуль Inception_V3 из Tensorflow Hub.
Проблема: После тонкой настройки модели (с использованием tf.Estimator.train) результаты, полученные с использованием tf.Estimator.predict, не так хороши, как ожидалось, на основании tf.Estimator.evaluate (это для проблемы регрессии.)
Я новичок в Tensorflow и Tensorflow Hub, так что я мог делать много ошибок новичка.
Когда я запускаю tf.Estimator.evaluate () для моих данных проверки, заявленная потеря находится в том же парке шаров, что и потеря после того, как tf.Estimator.train () был использован для обучения модели. Проблема возникает, когда я пытаюсь использовать tf.Estimator.predict () для тех же данных проверки.
tf.Estimator.predict () возвращает прогнозы, которые я затем использую для вычисления той же метрики потерь (mean_squared_error), которая вычисляется с помощью tf.Estimator.evaluate (). Я использую тот же набор данных для подачи в функцию прогнозирования, что и функция оценки. Но я не получаю тот же результат для mean_squared_error - не удаленно близко! (Mse, который я вычисляю из предсказания, намного хуже.)
Вот что я сделал (отредактировал некоторые детали)... Определите model_fn с модулем Tensorflow Hub. Затем вызовите функции tf.Estimator для обучения, оценки и прогнозирования.
def my_model_fun(features, labels, mode, params):
# Load InceptionV3 Module from Tensorflow Hub
iv3_module =hub.Module("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1",trainable=True, tags={'train'})
# Gather the variables for fine-tuning
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='CustomeLayer')
var_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='module/InceptionV3/Mixed_5b'))
predictions = {"the_prediction" : final_output}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Define loss, optimizer, and evaluation metrics
loss = tf.losses.mean_squared_error(labels=labels, predictions=final_output)
optimizer =tf.train.AdadeltaOptimizer(learning_rate=learn_rate).minimize(loss,
var_list=var_list, global_step=tf.train.get_global_step())
rms_error = tf.metrics.root_mean_squared_error(labels=labels,predictions=predictions["the_prediction"])
eval_metric_ops = {"rms_error": rms_error}
if mode == tf.estimator.ModeKeys.TRAIN:
return tf.estimator.EstimatorSpec(mode=mode, loss=loss,train_op=optimizer)
if mode == tf.estimator.ModeKeys.EVAL:
tf.summary.scalar('rms_error', rms_error)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss,eval_metric_ops=eval_metric_ops)
iv3_estimator = tf.estimator.Estimator(model_fn=iv3_model_fn)
iv3_estimator.train(input_fn=train_input_fn, steps=TRAIN_STEPS)
iv3_estimator.evaluate(input_fn=val_input_fn)
ii =0
for ans in iv3_estimator.predict(input_fn=test_input_fn):
sqErr = np.square(label[ii] - ans['the_prediction'][0])
totalSqErr += sqErr
ii += 1
mse = totalSqErr/ii
Я ожидаю, что потеря mse, сообщаемая tf.Estimator.evaluate (), должна быть такой же, как когда я вычисляю mse по известным меткам и выводу tf.Estimator.predict ()
Нужно ли импортировать модель Tensorflow Hub по-разному, когда я использую прогнозирование? (используйте trainable=False в вызове hub.Module()?
Используются ли веса, полученные в результате обучения, при запуске tf.Estimator.evaluate (), но не при запуске tf.Estimator.predict ()?
Другой?
1 ответ
Есть несколько вещей, которые, кажется, отсутствуют в фрагменте кода. Как final_output
вычислено из iv3_module
? Кроме того, среднеквадратическая ошибка является необычным выбором функции потерь для задачи классификации; общий подход заключается в передаче элементов изображения из модуля в линейный выходной слой с оценками для каждого класса ("логиты") и "кросс-энтропийной потерей softmax". Для объяснения этих терминов вы можете ознакомиться с онлайн-учебниками, такими как https://developers.google.com/machine-learning/crash-course/ (вплоть до многоклассовых нейронных сетей).
Относительно технических характеристик TF-Hub:
- Переменные модуля-концентратора автоматически добавляются в коллекции GLOBAL_VARIABLES и TRAINABLE_VARIABLES (если
trainable=True
, как вы уже делаете). Никакого ручного расширения этих коллекций не требуется. hub.Module(..., tags=...)
должен быть установлен в{"train"}
заmode==TRAIN
и установить вNone
или пустой набор в противном случае.
В общем случае полезно получить решение, работающее сквозным образом для вашей проблемы без точной настройки в качестве базовой линии, а затем добавить точную настройку.