Predictor: handle multi-input and/or multi-output cases

This commit is contained in:
Robert Sachunsky 2026-06-02 21:18:22 +02:00
parent c79b73dcc8
commit a391ee24e6

View file

@ -1,5 +1,5 @@
from contextlib import ExitStack
from typing import List, Dict
from typing import List, Dict, Tuple, Union
import logging
import logging.handlers
import multiprocessing as mp
@ -8,6 +8,7 @@ import numpy as np
from .utils.shm import share_ndarray, ndarray_shared
QSIZE = 200
# Payload type for predictions: either one array (single-input / single-output
# model) or a tuple holding any number of arrays (multi-input / multi-output).
# The ellipsis matters: ``Tuple[np.ndarray]`` would mean "exactly one element".
ArrayT = Union[np.ndarray, Tuple[np.ndarray, ...]]
class Predictor(mp.context.SpawnProcess):
@ -40,10 +41,10 @@ class Predictor(mp.context.SpawnProcess):
def input_shape(self):
return self({})
def predict(self, data: dict, verbose=0):
def predict(self, data: ArrayT, verbose=0) -> ArrayT:
return self(data)
def __call__(self, data: dict):
def __call__(self, data: Union[ArrayT, Dict]) -> Union[ArrayT, Tuple]:
# unusable as per python/cpython#79967
#with self.jobid.get_lock():
# would work, but not public:
@ -55,7 +56,15 @@ class Predictor(mp.context.SpawnProcess):
self.taskq.put((jobid, data))
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name)
return self.result(jobid)
with share_ndarray(data) as shared_data:
with ExitStack() as stack:
if isinstance(data, tuple):
# multi-input
shared_data = []
for data0 in data:
shared_data.append(stack.enter_context(share_ndarray(data0)))
shared_data = tuple(shared_data)
else:
shared_data = stack.enter_context(share_ndarray(data))
self.taskq.put((jobid, shared_data))
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data)
return self.result(jobid)
@ -67,6 +76,14 @@ class Predictor(mp.context.SpawnProcess):
result = self.results.pop(jobid)
if isinstance(result, Exception):
raise Exception(f"predictor {self.name} failed for {jobid}") from result
elif isinstance(result, tuple) and isinstance(result[0], dict):
# multi-output
result1 = []
for result0 in result:
with ndarray_shared(result0) as shared_result0:
result1.append(np.copy(shared_result0))
result = result1
self.closable.append(jobid)
elif isinstance(result, dict):
with ndarray_shared(result) as shared_result:
result = np.copy(shared_result)
@ -111,6 +128,7 @@ class Predictor(mp.context.SpawnProcess):
"binarization": 4,
"enhancement": 4,
"reading_order": 4,
"ocr": 8,
# medium size (672x672x3)...
"textline": 2,
# large models...
@ -126,8 +144,13 @@ class Predictor(mp.context.SpawnProcess):
self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result)
else:
tasks = [(jobid, shared_data)]
if isinstance(shared_data, tuple):
multi_input = True
batch_size = shared_data[0]['shape'][0]
else:
multi_input = False
batch_size = shared_data['shape'][0]
tasks = [(jobid, shared_data)]
while (not self.taskq.empty() and
# climb to target batch size
batch_size * len(tasks) < REBATCH_SIZE):
@ -136,7 +159,7 @@ class Predictor(mp.context.SpawnProcess):
# add to our batch
tasks.append((jobid0, shared_data0))
else:
# immediately anser
# immediately answer
self.resultq.put((jobid0, self.model.input_shape))
if len(tasks) > 1:
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
@ -147,11 +170,25 @@ class Predictor(mp.context.SpawnProcess):
for jobid, shared_data in tasks:
#self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data)
jobs.append(jobid)
if multi_input:
data.append(tuple(stack.enter_context(ndarray_shared(shared_data0))
for shared_data0 in shared_data))
else:
data.append(stack.enter_context(ndarray_shared(shared_data)))
if multi_input:
data = tuple(np.concatenate(data0)
for data0 in zip(*data))
else:
data = np.concatenate(data)
#result = self.model.predict(data, verbose=0)
# faster, less VRAM
result = self.model.predict_on_batch(data)
if isinstance(result, tuple):
multi_output = True
results = zip(*(np.split(result0, len(jobs))
for result0 in result))
else:
multi_output = False
results = np.split(result, len(jobs))
#self.logger.debug("sharing result array for '%d'", jobid)
with ExitStack() as stack:
@ -160,6 +197,10 @@ class Predictor(mp.context.SpawnProcess):
# but don't want to wait either, so track closing
# context per job, and wait for closable signal
# from client
if multi_output:
result = tuple(stack.enter_context(share_ndarray(result0))
for result0 in result)
else:
result = stack.enter_context(share_ndarray(result))
closing[jobid] = stack.pop_all()
self.resultq.put((jobid, result))