Как выполнить пакетный логический вывод с квантованной моделью RoBERTa ONNX?
Я преобразовал модель RoBERTa PyTorch в модель ONNX и проанализировал ее. Я могу получить оценки из модели ONNX для одной точки входных данных (каждое предложение). Я хочу понять, как получить пакетные прогнозы с помощью сеанса вывода ONNX Runtime, передав несколько входных данных в сеанс. Ниже приведен пример сценария.
Модель: roberta-Quant.onnx, которая представляет собой квантованную версию ONNX модели RoBERTa PyTorch.
Код, используемый для преобразования RoBERTa в ONNX:
torch.onnx.export(model,
args=tuple(inputs.values()), # model input
f=export_model_path, # where to save the model
opset_version=11, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input_ids', # the model's input names
'attention_mask'],
output_names=['output_0'], # the model's output names
dynamic_axes={'input_ids': symbolic_names, # variable length axes
'attention_mask' : symbolic_names,
'output_0' : {0: 'batch_size'}})
Пример ввода для сеанса логического вывода ONNXRuntime:
{
'input_ids': array([[ 0, 510, 35, 21071, ....., 1, 1, 1, 1, 1, 1]]),
'attention_mask': array([[1, 1, 1, 1, ......., 0, 0, 0, 0, 0, 0]])
}
Запуск модели ONNX для 400 образцов данных (предложений) с использованием сеанса логического вывода ONNXRuntime:
session = onnxruntime.InferenceSession("roberta_quantized.onnx", providers=['CPUExecutionProvider'])
for i in range(400):
ort_inputs = {
'input_ids': input_ids[i].cpu().reshape(1, max_seq_length).numpy(), # max_seq_length=128 here
'input_mask': attention_masks[i].cpu().reshape(1, max_seq_length).numpy()
}
ort_outputs = session.run(None, ort_inputs)
В приведенном выше коде я последовательно просматриваю 400 предложений, чтобы получить оценки "". Пожалуйста, помогите мне понять, как я могу выполнить пакетную обработку здесь, используя модель ONNX, куда я могу отправить
inputs_ids
а также
attention_masks
для нескольких предложений и получить оценки за все предложения в
ort_outputs
.
Заранее спасибо!