Интерактивный прогноз Tensorflow с использованием API набора данных

Я создал модель Tensorflow, которая использует API набора данных для подачи данных в сеть.

После фазы обучения я хотел бы восстановить эту модель и время от времени делать на ней выводы.

В настоящее время я каждый раз инициализирую итератор набора данных, но мне интересно, есть ли альтернативный способ. Более того, во время обучения мой набор данных содержит данные x и y, а во время прогнозирования у меня есть только x. В качестве временного решения я предоставляю фальшивку, но опять же, это не кажется лучшим решением.

Вот псевдокод того, что я делаю:

#### NETWORK
input_x = tf.placeholder(tf.int32, [None, None], name="input_x")
input_y = tf.placeholder(tf.int32, [None, 2], name="input_y")

dataset = tf.data.Dataset.from_tensor_slices((input_x, input_y))
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

x_data, y_data = iterator.get_next()
output = tf.variable(x_data, name='output')

.....

### INFERENCE

while (true):
    x = new_input

    x_operation = session.graph.get_operation_by_name("input_x").outputs[0]
    y_operation = session.graph.get_operation_by_name("input_y").outputs[0]
    dataset_operation = session.graph.get_operation_by_name("dataset_init")
    output_operation = session.graph.get_operation_by_name("output").outputs[0]

    fake_y = np.array([[0, 0]])

    dic = {input_x: x, input_y: y}
    session.run(dataset_operation, feed_dict=dic)

    prediction = session.run(output_operation)

Спасибо за помощь

0 ответов

Другие вопросы по тегам