ONNX с пользовательскими операциями от TensorFlow в Java
чтобы использовать машинное обучение в Java, я пытаюсь обучить модель в TensorFlow, сохранить ее как файл ONNX, а затем использовать файл для вывода в Java. Хотя это прекрасно работает с простыми моделями, с использованием слоев предварительной обработки все становится сложнее, поскольку они, похоже, зависят от пользовательских операторов.
https://www.tensorflow.org/tutorials/keras/text_classification
Например, этот Colab занимается классификацией текста и использует слой TextVectorization следующим образом:
@tf.keras.utils.register_keras_serializable()
def custom_standardization2(input_data):
lowercase = tf.strings.lower(input_data)
stripped_html = tf.strings.regex_replace(lowercase, '<br />',' ')
return tf.strings.regex_replace(stripped_html, '[%s]' % re.escape(string.punctuation), '')
vectorize_layer = layers.TextVectorization(
standardize=custom_standardization2,
max_tokens=max_features,
output_mode='int',
output_sequence_length=sequence_length
)
Он используется в качестве слоя предварительной обработки в скомпилированной модели:
export_model = tf.keras.Sequential([
vectorize_layer,
model,
layers.Activation('sigmoid')
])
export_model.compile(loss=losses.BinaryCrossentropy(from_logits=False), optimizer="adam", metrics=['accuracy'])
Чтобы создать файл ONNX, я сохраняю модель как protobuf, а затем конвертирую ее в ONNX:
export_model.save("saved_model")
python -m tf2onnx.convert --saved-model saved_model --output saved_model.onnx --extra_opset ai.onnx.contrib:1 --opset 11
Используя onnxruntime-extensions , теперь можно зарегистрировать пользовательские операции и запустить модель в Python для вывода.
import onnxruntime
from onnxruntime import InferenceSession
from onnxruntime_extensions import get_library_path
so = onnxruntime.SessionOptions()
so.register_custom_ops_library(get_library_path())
session = InferenceSession('saved_model.onnx', so)
res = session.run(None, { 'text_vectorization_2_input': example_new })
Это поднимает вопрос, возможно ли использовать ту же модель в Java аналогичным образом. Onnxruntime для Java имеет функцию SessionOptions#registerCustomOpLibrary , поэтому я подумал о чем-то вроде этого:
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.registerCustomOpLibrary(""); // reference the library
OrtSession session = env.createSession("...", options);
Есть ли у кого-нибудь идеи, возможен ли описанный вариант использования или как использовать модели со слоями предварительной обработки в Java (без использования TensorFlow Java)?
1 ответ
Решение, которое вы предлагаете в своем обновлении, правильное, вам нужно скомпилировать пакет расширения среды выполнения ONNX из исходного кода, чтобы получить dll/so/dylib, а затем вы можете загрузить его в среду выполнения ONNX на Java, используя параметры сеанса. Python whl не распространяет двоичный файл в формате, который можно загрузить вне Python, поэтому единственным вариантом является компиляция из исходного кода. Я написал Java API ONNX Runtime, поэтому, если этот подход не сработает, создайте проблему на Github, и мы ее исправим.