mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
ModelZoo ONNX backend: handle multiple inputs, too
This commit is contained in:
parent
9d2412080f
commit
08946067ac
1 changed files with 7 additions and 1 deletions
|
|
@ -313,6 +313,7 @@ class EynollahModelZoo:
|
||||||
# 'arena_extend_strategy': 'kNextPowerOfTwo',
|
# 'arena_extend_strategy': 'kNextPowerOfTwo',
|
||||||
'gpu_mem_limit': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
|
'gpu_mem_limit': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
|
||||||
# 'cudnn_conv_algo_search': 'EXHAUSTIVE',
|
# 'cudnn_conv_algo_search': 'EXHAUSTIVE',
|
||||||
|
#'cudnn_conv_use_max_workspace': 0,
|
||||||
# 'do_copy_in_default_stream': True,
|
# 'do_copy_in_default_stream': True,
|
||||||
# ...
|
# ...
|
||||||
})] + providers
|
})] + providers
|
||||||
|
|
@ -351,7 +352,12 @@ class EynollahModelZoo:
|
||||||
outputs = outputs[0]
|
outputs = outputs[0]
|
||||||
return outputs
|
return outputs
|
||||||
model.predict_on_batch = predict_onnx
|
model.predict_on_batch = predict_onnx
|
||||||
model.input_shape = model.get_inputs()[0].shape
|
input_spec = model.get_inputs()
|
||||||
|
input_spec = [i.shape for i in input_spec]
|
||||||
|
if len(input_spec) > 1:
|
||||||
|
model.input_shape = tuple(input_spec)
|
||||||
|
else:
|
||||||
|
model.input_shape = input_spec[0]
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue