Ошибка при прогнозировании с помощью python onnxruntime

Я создал очень простое дерево решений, используя sklearnбиблиотека. Это дерево обучается на основе 4 функций:

feat1 INT
feat2 INT
feat3 FLOAT
feat4 FLOAT

А метка / целевая функция - это логическое значение (0 или 1).

Я превратил дерево в ONNX формат, и теперь я хочу использовать onnxruntime pythonбиблиотека, чтобы сделать прогноз. Я нашел в Интернете пример кода для этого. Проблема в том, что я не понимаю, что именно происходит во всех частях этого кода, функций и параметров. Это приводит к тому, что я получаю сообщение об ошибке. Я искал документацию, но не могу ее найти.

В приведенном ниже коде я преобразовываю модель дерева в ONNXформат. Это успешно, но части кода я не понимаю. вinitial_typeпеременная, что мне нужно ввести здесь на основе 4 столбцов функций и метки / целевой функции, которую я использовал ранее? Теперь я вошелFloatTensorType([None, 4] потому что у меня есть 4 столбца с характеристиками и None я понятия не имею.

##Convert to ONNX format

initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(treeModel, initial_types=initial_type)
with open("path", "wb") as f:
    f.write(onx.SerializeToString())

В приведенном ниже коде я хочу сделать прогноз, используя onnxruntime библиотека, но я получаю эту ошибку:

RuntimeError: Either type_proto was null or it was not of sequence type

Это потому, что я не понимаю последнюю строку кода ниже. Я вошел в это{input_name: [4, 8, 77.8, 143.45]потому что это четыре значения для столбцов функций. Что я здесь делаю не так?

sess = rt.InferenceSession("pathToONNXModel")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: [4, 8, 77.8, 143.45]})[0]

1 ответ

Решение

Ты пробовал {input_name: numpy.array([4, 8, 77.8, 143.45], dtype=numpy.float32)}? onnxruntime требует в качестве входных данных несколько массивов.

Другие вопросы по тегам