diff --git a/src/eynollah/predictor.py b/src/eynollah/predictor.py index 3c6890e..141d3f0 100644 --- a/src/eynollah/predictor.py +++ b/src/eynollah/predictor.py @@ -1,5 +1,5 @@ from contextlib import ExitStack -from typing import List, Dict +from typing import List, Dict, Tuple, Union import logging import logging.handlers import multiprocessing as mp @@ -8,6 +8,7 @@ import numpy as np from .utils.shm import share_ndarray, ndarray_shared QSIZE = 200 +ArrayT = Union[np.ndarray, Tuple[np.ndarray]] class Predictor(mp.context.SpawnProcess): @@ -40,10 +41,10 @@ class Predictor(mp.context.SpawnProcess): def input_shape(self): return self({}) - def predict(self, data: dict, verbose=0): + def predict(self, data: ArrayT, verbose=0) -> ArrayT: return self(data) - def __call__(self, data: dict): + def __call__(self, data: Union[ArrayT, Dict]) -> Union[ArrayT, Tuple]: # unusable as per python/cpython#79967 #with self.jobid.get_lock(): # would work, but not public: @@ -55,7 +56,15 @@ class Predictor(mp.context.SpawnProcess): self.taskq.put((jobid, data)) #self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name) return self.result(jobid) - with share_ndarray(data) as shared_data: + with ExitStack() as stack: + if isinstance(data, tuple): + # multi-input + shared_data = [] + for data0 in data: + shared_data.append(stack.enter_context(share_ndarray(data0))) + shared_data = tuple(shared_data) + else: + shared_data = stack.enter_context(share_ndarray(data)) self.taskq.put((jobid, shared_data)) #self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data) return self.result(jobid) @@ -67,6 +76,14 @@ class Predictor(mp.context.SpawnProcess): result = self.results.pop(jobid) if isinstance(result, Exception): raise Exception(f"predictor {self.name} failed for {jobid}") from result + elif isinstance(result, tuple) and isinstance(result[0], dict): + # multi-output + result1 = [] + for result0 in result: + with ndarray_shared(result0) as shared_result0: + result1.append(np.copy(shared_result0)) + result = result1 + self.closable.append(jobid) elif isinstance(result, dict): with ndarray_shared(result) as shared_result: result = np.copy(shared_result) @@ -111,6 +128,7 @@ class Predictor(mp.context.SpawnProcess): "binarization": 4, "enhancement": 4, "reading_order": 4, + "ocr": 8, # medium size (672x672x3)... "textline": 2, # large models... @@ -126,8 +144,13 @@ class Predictor(mp.context.SpawnProcess): self.resultq.put((jobid, result)) #self.logger.debug("sent result for '%d': %s", jobid, result) else: + if isinstance(shared_data, tuple): + multi_input = True + batch_size = shared_data[0]['shape'][0] + else: + multi_input = False + batch_size = shared_data['shape'][0] tasks = [(jobid, shared_data)] - batch_size = shared_data['shape'][0] while (not self.taskq.empty() and # climb to target batch size batch_size * len(tasks) < REBATCH_SIZE): @@ -136,7 +159,7 @@ class Predictor(mp.context.SpawnProcess): # add to our batch tasks.append((jobid0, shared_data0)) else: - # immediately anser + # immediately answer self.resultq.put((jobid0, self.model.input_shape)) if len(tasks) > 1: self.logger.debug("rebatching %d '%s' tasks of batch size %d", @@ -147,12 +170,26 @@ class Predictor(mp.context.SpawnProcess): for jobid, shared_data in tasks: #self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data) jobs.append(jobid) - data.append(stack.enter_context(ndarray_shared(shared_data))) - data = np.concatenate(data) + if multi_input: + data.append(tuple(stack.enter_context(ndarray_shared(shared_data0)) + for shared_data0 in shared_data)) + else: + data.append(stack.enter_context(ndarray_shared(shared_data))) + if multi_input: + data = tuple(np.concatenate(data0) + for data0 in zip(*data)) + else: + data = np.concatenate(data) #result = self.model.predict(data, verbose=0) # faster, less VRAM result = self.model.predict_on_batch(data) - results = np.split(result, len(jobs)) + if isinstance(result, tuple): + multi_output = True + results = zip(*(np.split(result0, len(jobs)) + for result0 in result)) + else: + multi_output = False + results = np.split(result, len(jobs)) #self.logger.debug("sharing result array for '%d'", jobid) with ExitStack() as stack: for jobid, result in zip(jobs, results): @@ -160,7 +197,11 @@ class Predictor(mp.context.SpawnProcess): # but don't want to wait either, so track closing # context per job, and wait for closable signal # from client - result = stack.enter_context(share_ndarray(result)) + if multi_output: + result = tuple(stack.enter_context(share_ndarray(result0)) + for result0 in result) + else: + result = stack.enter_context(share_ndarray(result)) closing[jobid] = stack.pop_all() self.resultq.put((jobid, result)) #self.logger.debug("sent result for '%d': %s", jobid, result)