Ошибка при выводе модели LSTM с помощью onnx-runtime. Ошибка недопустимого аргумента

Я экспортировал модель LSTM из pytorch в onnx . Модель принимает последовательности длиной 200. Она имеет размер скрытого состояния 256, количество слоев = 2. Функция forward принимает входной размер (пакеты, длина последовательности) в качестве входных данных вместе с кортежем, состоящим из скрытого состояния и состояния ячейки. Я получаю сообщение об ошибке при выводе модели со средой выполнения onnx . скрытое состояние и измерения состояния ячейки одинаковы.

ioio1 = np.random.rand(1,200)
ioio2 = np.zeros((2,1,256),dtype = np.float)
pred = runtime_session.run([output_name],{runtime_session.get_inputs()[0].name:ioio1,
                                          runtime_session.get_inputs()[1].name :ioio2,
                                          runtime_session.get_inputs()[2].name : ioio2})
InvalidArgument                           Traceback (most recent call last)
<ipython-input-204-3928823f661e> in <module>()
      1 pred = runtime_session.run([output_name],{runtime_session.get_inputs()[0].name:ioio1,
      2                                           runtime_session.get_inputs()[1].name :ioio2,
----> 3                                           runtime_session.get_inputs()[2].name : ioio2})

/usr/local/lib/python3.6/dist-packages/onnxruntime/capi/session.py in run(self, output_names, input_feed, run_options)
    109             output_names = [output.name for output in self._outputs_meta]
    110         try:
--> 111             return self._sess.run(output_names, input_feed, run_options)
    112         except C.EPFail as err:
    113             if self._enable_fallback:

InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Unexpected input data type. Actual: (N11onnxruntime17PrimitiveDataTypeIdEE) , expected: (N11onnxruntime17PrimitiveDataTypeIlEE)

1 ответ

Эта проблема аналогична: https://github.com/microsoft/onnxruntime/issues/4423 Разрешение: ioio1 = np.random.rand(1,200) is float64 (double), что не соответствует типу dtype, которого ожидает ваша модель.

Другие вопросы по тегам