ONNX с пользовательскими операциями от TensorFlow в Java

чтобы использовать машинное обучение в Java, я пытаюсь обучить модель в TensorFlow, сохранить ее как файл ONNX, а затем использовать файл для вывода в Java. Хотя это прекрасно работает с простыми моделями, с использованием слоев предварительной обработки все становится сложнее, поскольку они, похоже, зависят от пользовательских операторов.

https://www.tensorflow.org/tutorials/keras/text_classification

Например, этот Colab занимается классификацией текста и использует слой TextVectorization следующим образом:

      @tf.keras.utils.register_keras_serializable()
def custom_standardization2(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, '<br />',' ')
    return tf.strings.regex_replace(stripped_html, '[%s]' % re.escape(string.punctuation), '')


vectorize_layer = layers.TextVectorization(
    standardize=custom_standardization2,
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=sequence_length
)

Он используется в качестве слоя предварительной обработки в скомпилированной модели:

      export_model = tf.keras.Sequential([
    vectorize_layer,
    model,
    layers.Activation('sigmoid')
])

export_model.compile(loss=losses.BinaryCrossentropy(from_logits=False), optimizer="adam", metrics=['accuracy'])

Чтобы создать файл ONNX, я сохраняю модель как protobuf, а затем конвертирую ее в ONNX:

      export_model.save("saved_model")

python -m tf2onnx.convert --saved-model saved_model --output saved_model.onnx --extra_opset ai.onnx.contrib:1 --opset 11

Используя onnxruntime-extensions , теперь можно зарегистрировать пользовательские операции и запустить модель в Python для вывода.

      import onnxruntime
from onnxruntime import InferenceSession
from onnxruntime_extensions import get_library_path

so = onnxruntime.SessionOptions()
so.register_custom_ops_library(get_library_path())

session = InferenceSession('saved_model.onnx', so)
res = session.run(None, { 'text_vectorization_2_input': example_new })

Это поднимает вопрос, возможно ли использовать ту же модель в Java аналогичным образом. Onnxruntime для Java имеет функцию SessionOptions#registerCustomOpLibrary , поэтому я подумал о чем-то вроде этого:

      OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.registerCustomOpLibrary(""); // reference the library
OrtSession session = env.createSession("...", options);

Есть ли у кого-нибудь идеи, возможен ли описанный вариант использования или как использовать модели со слоями предварительной обработки в Java (без использования TensorFlow Java)?

1 ответ

Решение, которое вы предлагаете в своем обновлении, правильное, вам нужно скомпилировать пакет расширения среды выполнения ONNX из исходного кода, чтобы получить dll/so/dylib, а затем вы можете загрузить его в среду выполнения ONNX на Java, используя параметры сеанса. Python whl не распространяет двоичный файл в формате, который можно загрузить вне Python, поэтому единственным вариантом является компиляция из исходного кода. Я написал Java API ONNX Runtime, поэтому, если этот подход не сработает, создайте проблему на Github, и мы ее исправим.

Другие вопросы по тегам