Использование созданной тензорной модели для прогнозирования

Я смотрю на исходный код из этой статьи Tensorflow, в которой рассказывается о том, как создать широкую и глубокую модель обучения. https://www.tensorflow.org/versions/r1.3/tutorials/wide_and_deep

Вот ссылка на исходный код Python: https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py

Цель этого состоит в том, чтобы подготовить модель, которая будет предсказывать, зарабатывает ли кто-то больше или меньше 50 тысяч долларов в год, учитывая данные переписи.

Как указано, я запускаю эту команду для выполнения:

python wide_n_deep_tutorial.py --model_type=wide_n_deep

В результате я получаю следующее:

$ python wide_n_deep.py --model_type=wide_n_deep
Training data is downloaded to /tmp/tmp_pwqo2h8
Test data is downloaded to /tmp/tmph6jcimik
2018-01-03 05:34:12.236038: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
WARNING:tensorflow:enqueue_data was called with num_epochs and num_threads > 1. num_epochs is applied per thread, so this will produce more epochs than you probably intend. If you want to limit epochs, use one thread.
WARNING:tensorflow:enqueue_data was called with shuffle=False and num_threads > 1. This will create multiple threads, all reading the array/dataframe in order. If you want examples read in order, use one thread; if you want multiple threads, enable shuffling.
WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool.
WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool.
model directory = /tmp/tmp_ab6cfsf
accuracy: 0.808673
accuracy_baseline: 0.763774
auc: 0.841373
auc_precision_recall: 0.66043
average_loss: 0.418642
global_step: 2000
label/mean: 0.236226
loss: 41.8154
prediction/mean: 0.251593

В различных статьях, которые я видел в Интернете, говорится о загрузке в .ckpt файл. Когда я смотрю в каталог моей модели, я вижу эти файлы:

$ ls /tmp/tmp_ab6cfsf
checkpoint  eval  events.out.tfevents.1514957651.ml-1  graph.pbtxt  model.ckpt-1.data-00000-of-00001  model.ckpt-1.index  model.ckpt-1.meta  model.ckpt-2000.data-00000-of-00001  model.ckpt-2000.index  model.ckpt-2000.meta

Я предполагаю, что я бы использовал model.ckpt-1.meta, это верно?

Но я также не совсем понимаю, как использовать и подавать данные этой модели. Я посмотрел эту статью на веб-сайте Tensorflow: https://www.tensorflow.org/versions/r1.3/programmers_guide/saved_model

Что говорит: "Обратите внимание, что Оценщики автоматически сохраняют и восстанавливают переменные (в model_dir)". (не уверен, что это значит в этом контексте)

Как я могу генерировать информацию в формате данных переписи, кроме зарплаты, поскольку это то, что мы должны прогнозировать? Для меня не очевидно, как использовать две статьи Tensorflow, чтобы иметь возможность использовать обученную модель для прогнозирования.

1 ответ

Решение

Вы можете посмотреть официальные сообщения в блоге ( часть 1 и часть 3) от команды TensorFlow, которые хорошо объясняют, как использовать оценщик.

В частности, они объясняют, как делать прогнозы, используя пользовательский ввод. Это использует встроенный predict Метод оценки:

estimator = tf.estimator.Estimator(model_fn, ...)

predict_input_fn = ...  # define this using tf.data

predict_results = estimator.predict(predict_input_fn)
for idx, prediction in enumerate(predict_results):
    print(idx)
    for key in prediction:
        print("...{}: {}".format(key, prediction[key]))

Для вашего примера мы можем создать функцию прогнозирования ввода, используя дополнительный файл CSV. Давайте предположим, что у нас есть CSV-файл "predict.csv" содержащий три примера (могут быть первые три строки "test.csv" например без ярлыков). Это дало бы:

predict.csv:

... пропустить эту строку...
25, Частный, 226802, 11, 7, Никогда не был женат, Machine-op-insct, Собственный ребенок, Черный, Мужской, 0, 0, 40, Соединенные Штаты
38, Частное лицо, 89814, HS-grad, 9, Супружеская супруга, Сельское хозяйство-рыбалка, Муж, Белый, Мужской, 0, 0, 50, Соединенные Штаты
28, Local-gov, 336951, Assoc-acdm, 12, Супружеская супруга, Защитник, Муж, Белый, Мужской, 0, 0, 40, Соединенные Штаты

estimator = build_estimator(FLAGS.model_dir, FLAGS.model_type)

def predict_input_fn(data_file):
    """Input builder function."""
    df_data = pd.read_csv(
        tf.gfile.Open(data_file),
        names=CSV_COLUMNS[:-1],  # remove the last name "income_bracket" that corresponds to the label
        skipinitialspace=True,
        engine="python",
        skiprows=1)
    # remove NaN elements
    df_data = df_data.dropna(how="any", axis=0)
    return tf.estimator.inputs.pandas_input_fn(x=df_data, y=None, shuffle=False)

predict_file_name = "wide_n_deep/predict.csv"
predict_results = estimator.predict(input_fn=predict_input_fn(predict_file_name))
for idx, prediction in enumerate(predict_results):
    print(idx)
    for key in prediction:
        print("...{}: {}".format(key, prediction[key]))
Другие вопросы по тегам