predictor: rebatch tasks to increase CUDA throughput…

- depending on model type (i.e. size), configure target
  batch sizes
- after receiving a prediction task for some model,
  look up target batch size, then try to retrieve arrays
  from follow-up tasks for the same model on the task queue;
  stop when either no tasks are immediately available or
  when the combined batch size (input batch size * number of tasks)
  reaches the target
- push back tasks for other models to the queue
- rebatch: read all shared arrays, concatenate them along axis 0,
  map respective job ids they came from
- predict on new (possibly larger) batch
- split result along axis 0 into number of jobs
- send each result along with its jobid to task queue
This commit is contained in:
Robert Sachunsky 2026-03-14 00:52:34 +01:00
parent b550725cc5
commit 2f3b622cf5

View file

@ -109,29 +109,80 @@ class Predictor(mp.context.SpawnProcess):
while not self.stopped.is_set():
close_all()
try:
jobid, model, shared_data = self.taskq.get(timeout=1.1)
TIMEOUT = 4.5 # 1.1 is too greedy - not enough time for rebatching
jobid, model, shared_data = self.taskq.get(timeout=TIMEOUT)
except mp.queues.Empty:
continue
try:
model = self.model_zoo.get(model)
# up to what batch size fits into small (8GB) VRAM?
# (notice we are not listing _resized/_patched models here,
# because their inputs/outputs will have varying shapes)
REBATCH_SIZE = {
# small models (448x448)...
"col_classifier": 4,
"page": 4,
"binarization": 5,
"enhancement": 5,
"reading_order": 5,
# medium size (672x672)...
"textline": 3,
# large models...
"table": 2,
"region_1_2": 2,
"region_fl_np": 2,
"region_fl": 2,
}.get(model, 1)
loaded_model = self.model_zoo.get(model)
if not len(shared_data):
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, model)
result = model.output_shape
result = loaded_model.output_shape
self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result)
else:
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data)
with ndarray_shared(shared_data) as data:
result = model.predict(data, verbose=0)
other_tasks = [] # other model, put back on q
model_tasks = [] # same model, for rebatching
model_tasks.append((jobid, shared_data))
batch_size = shared_data['shape'][0]
while (not self.taskq.empty() and
# climb to target batch size
batch_size * len(model_tasks) < REBATCH_SIZE):
jobid0, model0, shared_data0 = self.taskq.get()
if model0 == model and len(shared_data0):
# add to our batch
model_tasks.append((jobid0, shared_data0))
else:
other_tasks.append((jobid0, model0, shared_data0))
if len(other_tasks):
self.logger.debug("requeuing %d other tasks", len(other_tasks))
for task in other_tasks:
self.taskq.put(task)
if len(model_tasks) > 1:
self.logger.debug("rebatching %d %s tasks of batch size %d", len(model_tasks), model, batch_size)
with ExitStack() as stack:
data = []
jobs = []
for jobid, shared_data in model_tasks:
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data)
jobs.append(jobid)
data.append(stack.enter_context(ndarray_shared(shared_data)))
data = np.concatenate(data)
result = loaded_model.predict(data, verbose=0)
results = np.split(result, len(jobs))
#self.logger.debug("sharing result array for '%d'", jobid)
with ExitStack() as stack:
# we don't know when the result will be received,
# but don't want to wait either, so
result = stack.enter_context(share_ndarray(result))
closing[jobid] = stack.pop_all()
for jobid, result in zip(jobs, results):
# we don't know when the result will be received,
# but don't want to wait either, so track closing
# context per job, and wait for closable signal
# from client
result = stack.enter_context(share_ndarray(result))
closing[jobid] = stack.pop_all()
self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result)
except Exception as e:
self.logger.error("prediction '%d' failed: %s", jobid, e.__class__.__name__)
self.logger.error("prediction failed: %s", e.__class__.__name__)
result = e
self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result)
self.resultq.put((jobid, result))
close_all()
#self.logger.debug("predictor terminated")