Predictor: handle multi-input and/or multi-output cases

This commit is contained in:
Robert Sachunsky 2026-06-02 21:18:22 +02:00
parent c79b73dcc8
commit a391ee24e6

View file

@ -1,5 +1,5 @@
from contextlib import ExitStack from contextlib import ExitStack
from typing import List, Dict from typing import List, Dict, Tuple, Union
import logging import logging
import logging.handlers import logging.handlers
import multiprocessing as mp import multiprocessing as mp
@ -8,6 +8,7 @@ import numpy as np
from .utils.shm import share_ndarray, ndarray_shared from .utils.shm import share_ndarray, ndarray_shared
QSIZE = 200 QSIZE = 200
# Payload exchanged with the predictor: either a single array, or a tuple of
# arrays of any length (multi-input / multi-output case).
# NOTE: ``Tuple[np.ndarray]`` would only describe a tuple of exactly one array;
# the trailing ellipsis makes it a variable-length homogeneous tuple.
ArrayT = Union[np.ndarray, Tuple[np.ndarray, ...]]
class Predictor(mp.context.SpawnProcess): class Predictor(mp.context.SpawnProcess):
@ -40,10 +41,10 @@ class Predictor(mp.context.SpawnProcess):
def input_shape(self): def input_shape(self):
return self({}) return self({})
def predict(self, data: dict, verbose=0): def predict(self, data: ArrayT, verbose=0) -> ArrayT:
return self(data) return self(data)
def __call__(self, data: dict): def __call__(self, data: Union[ArrayT, Dict]) -> Union[ArrayT, Tuple]:
# unusable as per python/cpython#79967 # unusable as per python/cpython#79967
#with self.jobid.get_lock(): #with self.jobid.get_lock():
# would work, but not public: # would work, but not public:
@ -55,7 +56,15 @@ class Predictor(mp.context.SpawnProcess):
self.taskq.put((jobid, data)) self.taskq.put((jobid, data))
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name) #self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name)
return self.result(jobid) return self.result(jobid)
with share_ndarray(data) as shared_data: with ExitStack() as stack:
if isinstance(data, tuple):
# multi-input
shared_data = []
for data0 in data:
shared_data.append(stack.enter_context(share_ndarray(data0)))
shared_data = tuple(shared_data)
else:
shared_data = stack.enter_context(share_ndarray(data))
self.taskq.put((jobid, shared_data)) self.taskq.put((jobid, shared_data))
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data) #self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data)
return self.result(jobid) return self.result(jobid)
@ -67,6 +76,14 @@ class Predictor(mp.context.SpawnProcess):
result = self.results.pop(jobid) result = self.results.pop(jobid)
if isinstance(result, Exception): if isinstance(result, Exception):
raise Exception(f"predictor {self.name} failed for {jobid}") from result raise Exception(f"predictor {self.name} failed for {jobid}") from result
elif isinstance(result, tuple) and isinstance(result[0], dict):
# multi-output
result1 = []
for result0 in result:
with ndarray_shared(result0) as shared_result0:
result1.append(np.copy(shared_result0))
result = result1
self.closable.append(jobid)
elif isinstance(result, dict): elif isinstance(result, dict):
with ndarray_shared(result) as shared_result: with ndarray_shared(result) as shared_result:
result = np.copy(shared_result) result = np.copy(shared_result)
@ -111,6 +128,7 @@ class Predictor(mp.context.SpawnProcess):
"binarization": 4, "binarization": 4,
"enhancement": 4, "enhancement": 4,
"reading_order": 4, "reading_order": 4,
"ocr": 8,
# medium size (672x672x3)... # medium size (672x672x3)...
"textline": 2, "textline": 2,
# large models... # large models...
@ -126,8 +144,13 @@ class Predictor(mp.context.SpawnProcess):
self.resultq.put((jobid, result)) self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result) #self.logger.debug("sent result for '%d': %s", jobid, result)
else: else:
if isinstance(shared_data, tuple):
multi_input = True
batch_size = shared_data[0]['shape'][0]
else:
multi_input = False
batch_size = shared_data['shape'][0]
tasks = [(jobid, shared_data)] tasks = [(jobid, shared_data)]
batch_size = shared_data['shape'][0]
while (not self.taskq.empty() and while (not self.taskq.empty() and
# climb to target batch size # climb to target batch size
batch_size * len(tasks) < REBATCH_SIZE): batch_size * len(tasks) < REBATCH_SIZE):
@ -136,7 +159,7 @@ class Predictor(mp.context.SpawnProcess):
# add to our batch # add to our batch
tasks.append((jobid0, shared_data0)) tasks.append((jobid0, shared_data0))
else: else:
# immediately anser # immediately answer
self.resultq.put((jobid0, self.model.input_shape)) self.resultq.put((jobid0, self.model.input_shape))
if len(tasks) > 1: if len(tasks) > 1:
self.logger.debug("rebatching %d '%s' tasks of batch size %d", self.logger.debug("rebatching %d '%s' tasks of batch size %d",
@ -147,12 +170,26 @@ class Predictor(mp.context.SpawnProcess):
for jobid, shared_data in tasks: for jobid, shared_data in tasks:
#self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data) #self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data)
jobs.append(jobid) jobs.append(jobid)
data.append(stack.enter_context(ndarray_shared(shared_data))) if multi_input:
data = np.concatenate(data) data.append(tuple(stack.enter_context(ndarray_shared(shared_data0))
for shared_data0 in shared_data))
else:
data.append(stack.enter_context(ndarray_shared(shared_data)))
if multi_input:
data = tuple(np.concatenate(data0)
for data0 in zip(*data))
else:
data = np.concatenate(data)
#result = self.model.predict(data, verbose=0) #result = self.model.predict(data, verbose=0)
# faster, less VRAM # faster, less VRAM
result = self.model.predict_on_batch(data) result = self.model.predict_on_batch(data)
results = np.split(result, len(jobs)) if isinstance(result, tuple):
multi_output = True
results = zip(*(np.split(result0, len(jobs))
for result0 in result))
else:
multi_output = False
results = np.split(result, len(jobs))
#self.logger.debug("sharing result array for '%d'", jobid) #self.logger.debug("sharing result array for '%d'", jobid)
with ExitStack() as stack: with ExitStack() as stack:
for jobid, result in zip(jobs, results): for jobid, result in zip(jobs, results):
@ -160,7 +197,11 @@ class Predictor(mp.context.SpawnProcess):
# but don't want to wait either, so track closing # but don't want to wait either, so track closing
# context per job, and wait for closable signal # context per job, and wait for closable signal
# from client # from client
result = stack.enter_context(share_ndarray(result)) if multi_output:
result = tuple(stack.enter_context(share_ndarray(result0))
for result0 in result)
else:
result = stack.enter_context(share_ndarray(result))
closing[jobid] = stack.pop_all() closing[jobid] = stack.pop_all()
self.resultq.put((jobid, result)) self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result) #self.logger.debug("sent result for '%d': %s", jobid, result)