From 2f3b622cf5e23004b2ba991f1c1a19d6d8bd4adb Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Sat, 14 Mar 2026 00:52:34 +0100 Subject: [PATCH] =?UTF-8?q?predictor:=20rebatch=20tasks=20to=20increase=20?= =?UTF-8?q?CUDA=20throughput=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - depending on model type (i.e. size), configure target batch sizes - after receiving a prediction task for some model, look up target batch size, then try to retrieve arrays from follow-up tasks for the same model on the task queue; stop when either no tasks are immediately available or when the combined batch size (input batch size * number of tasks) reaches the target - push back tasks for other models to the queue - rebatch: read all shared arrays, concatenate them along axis 0, map respective job ids they came from - predict on new (possibly larger) batch - split result along axis 0 into number of jobs - send each result along with its jobid to result queue --- src/eynollah/predictor.py | 77 ++++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 13 deletions(-) diff --git a/src/eynollah/predictor.py b/src/eynollah/predictor.py index 9faa1be..8b46250 100644 --- a/src/eynollah/predictor.py +++ b/src/eynollah/predictor.py @@ -109,29 +109,80 @@ class Predictor(mp.context.SpawnProcess): while not self.stopped.is_set(): close_all() try: - jobid, model, shared_data = self.taskq.get(timeout=1.1) + TIMEOUT = 4.5 # 1.1 is too greedy - not enough for rebatching + jobid, model, shared_data = self.taskq.get(timeout=TIMEOUT) except mp.queues.Empty: continue try: - model = self.model_zoo.get(model) + # up to what batch size fits into small (8GB) VRAM? + # (notice we are not listing _resized/_patched models here, + # because here inputs/outputs will have varying shapes) + REBATCH_SIZE = { + # small models (448x448)... 
+ "col_classifier": 4, + "page": 4, + "binarization": 5, + "enhancement": 5, + "reading_order": 5, + # medium size (672x672)... + "textline": 3, + # large models... + "table": 2, + "region_1_2": 2, + "region_fl_np": 2, + "region_fl": 2, + }.get(model, 1) + loaded_model = self.model_zoo.get(model) if not len(shared_data): #self.logger.debug("getting '%d' output shape of model '%s'", jobid, model) - result = model.output_shape + result = loaded_model.output_shape + self.resultq.put((jobid, result)) + #self.logger.debug("sent result for '%d': %s", jobid, result) else: - #self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data) - with ndarray_shared(shared_data) as data: - result = model.predict(data, verbose=0) + other_tasks = [] # other model, put back on q + model_tasks = [] # same model, for rebatching + model_tasks.append((jobid, shared_data)) + batch_size = shared_data['shape'][0] + while (not self.taskq.empty() and + # climb to target batch size + batch_size * len(model_tasks) < REBATCH_SIZE): + jobid0, model0, shared_data0 = self.taskq.get() + if model0 == model and len(shared_data0): + # add to our batch + model_tasks.append((jobid0, shared_data0)) + else: + other_tasks.append((jobid0, model0, shared_data0)) + if len(other_tasks): + self.logger.debug("requeuing %d other tasks", len(other_tasks)) + for task in other_tasks: + self.taskq.put(task) + if len(model_tasks) > 1: + self.logger.debug("rebatching %d %s tasks of batch size %d", len(model_tasks), model, batch_size) + with ExitStack() as stack: + data = [] + jobs = [] + for jobid, shared_data in model_tasks: + #self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data) + jobs.append(jobid) + data.append(stack.enter_context(ndarray_shared(shared_data))) + data = np.concatenate(data) + result = loaded_model.predict(data, verbose=0) + results = np.split(result, len(jobs)) #self.logger.debug("sharing result array for '%d'", jobid) with ExitStack() as stack: - # we 
don't know when the result will be received, - # but don't want to wait either, so - result = stack.enter_context(share_ndarray(result)) - closing[jobid] = stack.pop_all() + for jobid, result in zip(jobs, results): + # we don't know when the result will be received, + # but don't want to wait either, so track closing + # context per job, and wait for closable signal + # from client + result = stack.enter_context(share_ndarray(result)) + closing[jobid] = stack.pop_all() + self.resultq.put((jobid, result)) + #self.logger.debug("sent result for '%d': %s", jobid, result) except Exception as e: - self.logger.error("prediction '%d' failed: %s", jobid, e.__class__.__name__) + self.logger.error("prediction failed: %s", e.__class__.__name__) result = e - self.resultq.put((jobid, result)) - #self.logger.debug("sent result for '%d': %s", jobid, result) + self.resultq.put((jobid, result)) close_all() #self.logger.debug("predictor terminated")