From 27ca9733db22d9a406b25d69f20654f4b9f44743 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 3 Jun 2026 20:57:02 +0200 Subject: [PATCH] ModelZoo ONNX backend for inference: support multi-input or -output --- src/eynollah/model_zoo/model_zoo.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 49ed8e1..51ce909 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -331,14 +331,25 @@ class EynollahModelZoo: model_path, providers=providers) # FIXME: notify about selected provider/device - input_name = model.get_inputs()[0].name - output_name = model.get_outputs()[0].name + model_inputs = [model_input.name + for model_input in model.get_inputs()] + model_outputs = [model_output.name + for model_output in model.get_outputs()] def predict_onnx(inputs): - # models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)' - # FIXME: do this dynamically (but how to convert .type to np.dtype?) - inputs = inputs.astype(np.float32) - return model.run( - [output_name], {input_name: inputs})[0] + if len(model_inputs) == 1: + inputs = [inputs] + outputs = model.run(model_outputs, { + model_input: + input_data.astype( + # models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)' + # FIXME: do this dynamically (but how to convert .type to np.dtype?) + np.float32 if input_data.dtype in [np.float16, np.float64] else + input_data.dtype) + for model_input, input_data in zip(model_inputs, inputs) + }) + if len(model_outputs) == 1: + outputs = outputs[0] + return outputs model.predict_on_batch = predict_onnx model.input_shape = model.get_inputs()[0].shape