ModelZoo ONNX backend: handle multiple inputs, too

This commit is contained in:
Robert Sachunsky 2026-06-12 14:54:51 +02:00
parent 9d2412080f
commit 08946067ac

View file

@ -313,6 +313,7 @@ class EynollahModelZoo:
# 'arena_extend_strategy': 'kNextPowerOfTwo', # 'arena_extend_strategy': 'kNextPowerOfTwo',
'gpu_mem_limit': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024, 'gpu_mem_limit': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
# 'cudnn_conv_algo_search': 'EXHAUSTIVE', # 'cudnn_conv_algo_search': 'EXHAUSTIVE',
#'cudnn_conv_use_max_workspace': 0,
# 'do_copy_in_default_stream': True, # 'do_copy_in_default_stream': True,
# ... # ...
})] + providers })] + providers
@ -351,7 +352,12 @@ class EynollahModelZoo:
outputs = outputs[0] outputs = outputs[0]
return outputs return outputs
model.predict_on_batch = predict_onnx model.predict_on_batch = predict_onnx
model.input_shape = model.get_inputs()[0].shape input_spec = model.get_inputs()
input_spec = [i.shape for i in input_spec]
if len(input_spec) > 1:
model.input_shape = tuple(input_spec)
else:
model.input_shape = input_spec[0]
return model return model