mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-04-14 19:31:57 +02:00
predictor: rebatch tasks to increase CUDA throughput…
- Depending on model type (i.e. size), configure target batch sizes. - After receiving a prediction task for some model, look up the target batch size, then try to retrieve arrays from follow-up tasks for the same model on the task queue; stop when either no tasks are immediately available or the combined batch size (input batch size × number of tasks) reaches the target. - Push tasks for other models back onto the queue. - Rebatch: read all shared arrays, concatenate them along axis 0, and record which job ids they came from. - Predict on the new (possibly larger) batch. - Split the result along axis 0 into the number of jobs. - Send each result, along with its job id, to the result queue.
This commit is contained in:
parent
b550725cc5
commit
2f3b622cf5
1 changed file with 64 additions and 13 deletions
|
|
@ -109,29 +109,80 @@ class Predictor(mp.context.SpawnProcess):
|
|||
while not self.stopped.is_set():
|
||||
close_all()
|
||||
try:
|
||||
jobid, model, shared_data = self.taskq.get(timeout=1.1)
|
||||
TIMEOUT = 4.5 # 1.1 too is greedy - not enough for rebatching
|
||||
jobid, model, shared_data = self.taskq.get(timeout=TIMEOUT)
|
||||
except mp.queues.Empty:
|
||||
continue
|
||||
try:
|
||||
model = self.model_zoo.get(model)
|
||||
# up to what batch size fits into small (8GB) VRAM?
|
||||
# (notice we are not listing _resized/_patched models here,
|
||||
# because here inputs/outputs will have varying shapes)
|
||||
REBATCH_SIZE = {
|
||||
# small models (448x448)...
|
||||
"col_classifier": 4,
|
||||
"page": 4,
|
||||
"binarization": 5,
|
||||
"enhancement": 5,
|
||||
"reading_order": 5,
|
||||
# medium size (672x672)...
|
||||
"textline": 3,
|
||||
# large models...
|
||||
"table": 2,
|
||||
"region_1_2": 2,
|
||||
"region_fl_np": 2,
|
||||
"region_fl": 2,
|
||||
}.get(model, 1)
|
||||
loaded_model = self.model_zoo.get(model)
|
||||
if not len(shared_data):
|
||||
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, model)
|
||||
result = model.output_shape
|
||||
result = loaded_model.output_shape
|
||||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
else:
|
||||
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data)
|
||||
with ndarray_shared(shared_data) as data:
|
||||
result = model.predict(data, verbose=0)
|
||||
other_tasks = [] # other model, put back on q
|
||||
model_tasks = [] # same model, for rebatching
|
||||
model_tasks.append((jobid, shared_data))
|
||||
batch_size = shared_data['shape'][0]
|
||||
while (not self.taskq.empty() and
|
||||
# climb to target batch size
|
||||
batch_size * len(model_tasks) < REBATCH_SIZE):
|
||||
jobid0, model0, shared_data0 = self.taskq.get()
|
||||
if model0 == model and len(shared_data0):
|
||||
# add to our batch
|
||||
model_tasks.append((jobid0, shared_data0))
|
||||
else:
|
||||
other_tasks.append((jobid0, model0, shared_data0))
|
||||
if len(other_tasks):
|
||||
self.logger.debug("requeuing %d other tasks", len(other_tasks))
|
||||
for task in other_tasks:
|
||||
self.taskq.put(task)
|
||||
if len(model_tasks) > 1:
|
||||
self.logger.debug("rebatching %d %s tasks of batch size %d", len(model_tasks), model, batch_size)
|
||||
with ExitStack() as stack:
|
||||
data = []
|
||||
jobs = []
|
||||
for jobid, shared_data in model_tasks:
|
||||
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data)
|
||||
jobs.append(jobid)
|
||||
data.append(stack.enter_context(ndarray_shared(shared_data)))
|
||||
data = np.concatenate(data)
|
||||
result = loaded_model.predict(data, verbose=0)
|
||||
results = np.split(result, len(jobs))
|
||||
#self.logger.debug("sharing result array for '%d'", jobid)
|
||||
with ExitStack() as stack:
|
||||
# we don't know when the result will be received,
|
||||
# but don't want to wait either, so
|
||||
result = stack.enter_context(share_ndarray(result))
|
||||
closing[jobid] = stack.pop_all()
|
||||
for jobid, result in zip(jobs, results):
|
||||
# we don't know when the result will be received,
|
||||
# but don't want to wait either, so track closing
|
||||
# context per job, and wait for closable signal
|
||||
# from client
|
||||
result = stack.enter_context(share_ndarray(result))
|
||||
closing[jobid] = stack.pop_all()
|
||||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
except Exception as e:
|
||||
self.logger.error("prediction '%d' failed: %s", jobid, e.__class__.__name__)
|
||||
self.logger.error("prediction failed: %s", e.__class__.__name__)
|
||||
result = e
|
||||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
self.resultq.put((jobid, result))
|
||||
close_all()
|
||||
#self.logger.debug("predictor terminated")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue