Ошибка при сохранении и использовании модели TensorForestEstimator для Android

Я использую оценщик randomforest, реализованный в тензорном потоке, чтобы предсказать, является ли текст английским или нет. Я сохранил свою модель (набор данных с 2 тыс. Образцов и 2 метками классов 0/1 (не на английском / английском языке)), используя следующий код (функция train_input_fn возвращает функции и метки классов):

model_path='test/'
TensorForestEstimator(params, model_dir='model/')
estimator.fit(input_fn=train_input_fn, max_steps=1)

После запуска приведенного выше кода graph.pbtxt и контрольные точки сохраняются в папке модели. Теперь я хочу использовать его на Android. У меня 2 проблемы:

  1. В качестве первого шага мне нужно заморозить график и контрольные точки в файл.pb, чтобы использовать его на Android. Я попробовал freeze_graph (я использовал код здесь: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). Когда я вызываю freeze_graph в моем режиме, я получаю следующую ошибку, и код не может создать окончательный граф.pb:

    Файл "/Users/XXXXXXX/freeze_graph.py", строка 105, в freeze_graph _ = tf.import_graph_def(input_graph_def, name="") Файл "/anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", строка 258, в import_graph_def op_def = op_dict[node.op] KeyError: u'CountExtremelyRandomStats'

вот как я называю freeze_graph:

def save_model_android():
    checkpoint_state_name = "model.ckpt-1"
    input_graph_name = "graph.pbtxt"
    output_graph_name = "output_graph.pb"
    checkpoint_path = os.path.join(model_path, checkpoint_state_name)

    input_graph_path = os.path.join(model_path, input_graph_name)
    input_saver_def_path = None
    input_binary = False
    output_node_names = "output"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(model_path, output_graph_name)
    clear_devices = True

    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path,
                              output_node_names, restore_op_name,
                              filename_tensor_name, output_graph_path,
                              clear_devices, "")

Я также попытался заморозить набор данных радужной оболочки в "tf.contrib.learn.datasets.load_iris". Я получаю ту же ошибку. Поэтому я считаю, что это не связано с набором данных.

  1. В качестве второго шага мне нужно использовать файл.pb на телефоне, чтобы предсказать текст. Я нашел демонстрационный пример камеры от Google, и он содержит много кода. Интересно, есть ли пошаговое руководство по использованию модели Tensorflow на Android, передавая вектор объектов и получая метку класса.

Заранее спасибо!

ОБНОВИТЬ

Используя последнюю версию tenorflow (0.12), проблема решена. Однако теперь проблема в том, что я должен передать output_node_names??? Как я могу получить, каковы выходные узлы на графике?

2 ответа

В отношении (1) похоже, что вы запускаете freeze_graph для сборки tenorflow, у которой нет доступа к contrib ops. Может быть, попробуйте явно импортировать тензорный лес перед вызовом freeze_graph?

Re (2) я не знаю более простой пример.

CountExtremelyRandomStats является одним из пользовательских операций TensorForest и существует в tenorflow/contrib. Как было указано, TF в какой-то момент переключился на включение операций по умолчанию. Я не думаю, что есть простой способ включить пользовательские операции contrib в глобальный реестр в предыдущих выпусках, потому что TensorForest использует метод создания файла.so, который включен как файл данных, который загружается во время выполнения (метод это был стандарт при создании TensorForest, но, возможно, он больше не будет). Таким образом, нет легко включаемых правил сборки Python, которые будут правильно ссылаться в пользовательских операциях C++. Вы можете попробовать включить tensflow/contrib/tenor_forest:ops_lib в качестве dep в ваше правило сборки, но я не думаю, что это сработает.

В любом случае вы можете попробовать установить ночную сборку tenorflow. Альтернатива включает изменение способа создания пользовательских тензорных лесов, что довольно неприятно.

Другие вопросы по тегам