Использование обучения, сделанного с использованием Python API, в качестве входных данных для модуля LabelImage в API Java?

У меня есть проблема с Java API Tenorsflow. Я запустил обучение с использованием API-интерфейса Тензор потока Python, сгенерировав файлы output_graph.pb и output_labels.txt. Теперь по какой-то причине я хочу использовать эти файлы в качестве входных данных для модуля LabelImage в API-интерфейсе Java. Я думал, что все будет работать нормально, так как этот модуль хочет ровно один.pb и один.txt. Тем не менее, когда я запускаю модуль, я получаю эту ошибку:

2017-04-26 10:12:56.711402: W tensorflow/core/framework/op_def_util.cc:332] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph
at org.tensorflow.Session$Runner.operationByName(Session.java:343)
at org.tensorflow.Session$Runner.feed(Session.java:137)
at org.tensorflow.Session$Runner.feed(Session.java:126)
at it.zero11.LabelImage.executeInceptionGraph(LabelImage.java:115)
at it.zero11.LabelImage.main(LabelImage.java:68)

Буду очень признателен, если вы поможете мне найти причину проблемы. Кроме того, я хочу спросить вас, есть ли способ запустить обучение из API JavaSpenSource, потому что это упростит задачу.

Чтобы быть более точным:

На самом деле, я не использую самописный код, по крайней мере, для соответствующих шагов. Все, что я сделал, - это тренируюсь с этим модулем, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py, снабжая его каталогом, содержащим изображения, разделенные между подкаталогами согласно их описанию. В частности, я думаю, что это строки, которые генерируют выходные данные:

output_graph_def = graph_util.convert_variables_to_constants(
    sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
  f.write(output_graph_def.SerializeToString())
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
  f.write('\n'.join(image_lists.keys()) + '\n')

Затем я даю выходные данные (один some_graph.pb и один some_labels.txt) в качестве входных данных для этого Java-модуля: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java, заменяющий ввод по умолчанию. Я получаю ошибку, о которой сообщалось выше.

2 ответа

Модель, используемая по умолчанию в LabelImage.java, отличается от модели, которая переобучается, поэтому имена входных и выходных узлов не совпадают. Обратите внимание, что модели TensorFlow представляют собой графики и аргументы feed() а также fetch() являются названиями узлов в графе. Так что вам нужно знать названия, подходящие для вашей модели.

Смотря на retrain.py, кажется, что он имеет узел, который принимает сырое содержимое файла JPEG в качестве входных данных (узел DecodeJpeg/contents) и производит набор меток в узле final_result,

Если это так, то вы должны сделать что-то вроде следующего в Java (и вам не нужен бит, который строит график для нормализации изображения, так как это кажется частью переобученной модели, поэтому замените LabelImage.java:64 с чем-то вроде:

try (Tensor image = Tensor.create(imageBytes);
     Graph g = new Graph()) {
  g.importGraphDef(graphDef);
  try (Session s = new Session(g);
    // Note the change to the name of the node and the fact
    // that it is being provided the raw imageBytes as input
    Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0)) {
    final long[] rshape = result.shape();
    if (result.numDimensions() != 2 || rshape[0] != 1) {
      throw new RuntimeException(
          String.format(
              "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
              Arrays.toString(rshape)));
    }
    int nlabels = (int) rshape[1];
    float[] probabilities = result.copyTo(new float[1][nlabels])[0];
    // At this point nlabels = number of classes in your retrained model
    DoSomethingWith(probabilities);
  }
}

Надеюсь, это поможет.

Что касается ошибки "Нет операции", я смог устранить ее, используя имена входного и выходного слоев "Mul" и "final_result" соответственно. Увидеть:

https://github.com/tensorflow/tensorflow/issues/2883

Другие вопросы по тегам