mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
Predictor: handle multi-input and/or multi-output cases
This commit is contained in:
parent
c79b73dcc8
commit
a391ee24e6
1 changed file with 51 additions and 10 deletions
|
|
@@ -1,5 +1,5 @@
|
|||
from contextlib import ExitStack
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple, Union
|
||||
import logging
|
||||
import logging.handlers
|
||||
import multiprocessing as mp
|
||||
|
|
@@ -8,6 +8,7 @@ import numpy as np
|
|||
from .utils.shm import share_ndarray, ndarray_shared
|
||||
|
||||
QSIZE = 200
|
||||
ArrayT = Union[np.ndarray, Tuple[np.ndarray]]
|
||||
|
||||
|
||||
class Predictor(mp.context.SpawnProcess):
|
||||
|
|
@@ -40,10 +41,10 @@ class Predictor(mp.context.SpawnProcess):
|
|||
def input_shape(self):
|
||||
return self({})
|
||||
|
||||
def predict(self, data: dict, verbose=0):
|
||||
def predict(self, data: ArrayT, verbose=0) -> ArrayT:
|
||||
return self(data)
|
||||
|
||||
def __call__(self, data: dict):
|
||||
def __call__(self, data: Union[ArrayT, Dict]) -> Union[ArrayT, Tuple]:
|
||||
# unusable as per python/cpython#79967
|
||||
#with self.jobid.get_lock():
|
||||
# would work, but not public:
|
||||
|
|
@@ -55,7 +56,15 @@ class Predictor(mp.context.SpawnProcess):
|
|||
self.taskq.put((jobid, data))
|
||||
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name)
|
||||
return self.result(jobid)
|
||||
with share_ndarray(data) as shared_data:
|
||||
with ExitStack() as stack:
|
||||
if isinstance(data, tuple):
|
||||
# multi-input
|
||||
shared_data = []
|
||||
for data0 in data:
|
||||
shared_data.append(stack.enter_context(share_ndarray(data0)))
|
||||
shared_data = tuple(shared_data)
|
||||
else:
|
||||
shared_data = stack.enter_context(share_ndarray(data))
|
||||
self.taskq.put((jobid, shared_data))
|
||||
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data)
|
||||
return self.result(jobid)
|
||||
|
|
@@ -67,6 +76,14 @@ class Predictor(mp.context.SpawnProcess):
|
|||
result = self.results.pop(jobid)
|
||||
if isinstance(result, Exception):
|
||||
raise Exception(f"predictor {self.name} failed for {jobid}") from result
|
||||
elif isinstance(result, tuple) and isinstance(result[0], dict):
|
||||
# multi-output
|
||||
result1 = []
|
||||
for result0 in result:
|
||||
with ndarray_shared(result0) as shared_result0:
|
||||
result1.append(np.copy(shared_result0))
|
||||
result = result1
|
||||
self.closable.append(jobid)
|
||||
elif isinstance(result, dict):
|
||||
with ndarray_shared(result) as shared_result:
|
||||
result = np.copy(shared_result)
|
||||
|
|
@@ -111,6 +128,7 @@ class Predictor(mp.context.SpawnProcess):
|
|||
"binarization": 4,
|
||||
"enhancement": 4,
|
||||
"reading_order": 4,
|
||||
"ocr": 8,
|
||||
# medium size (672x672x3)...
|
||||
"textline": 2,
|
||||
# large models...
|
||||
|
|
@@ -126,8 +144,13 @@ class Predictor(mp.context.SpawnProcess):
|
|||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
else:
|
||||
if isinstance(shared_data, tuple):
|
||||
multi_input = True
|
||||
batch_size = shared_data[0]['shape'][0]
|
||||
else:
|
||||
multi_input = False
|
||||
batch_size = shared_data['shape'][0]
|
||||
tasks = [(jobid, shared_data)]
|
||||
batch_size = shared_data['shape'][0]
|
||||
while (not self.taskq.empty() and
|
||||
# climb to target batch size
|
||||
batch_size * len(tasks) < REBATCH_SIZE):
|
||||
|
|
@@ -136,7 +159,7 @@ class Predictor(mp.context.SpawnProcess):
|
|||
# add to our batch
|
||||
tasks.append((jobid0, shared_data0))
|
||||
else:
|
||||
# immediately anser
|
||||
# immediately answer
|
||||
self.resultq.put((jobid0, self.model.input_shape))
|
||||
if len(tasks) > 1:
|
||||
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
|
||||
|
|
@@ -147,12 +170,26 @@ class Predictor(mp.context.SpawnProcess):
|
|||
for jobid, shared_data in tasks:
|
||||
#self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data)
|
||||
jobs.append(jobid)
|
||||
data.append(stack.enter_context(ndarray_shared(shared_data)))
|
||||
data = np.concatenate(data)
|
||||
if multi_input:
|
||||
data.append(tuple(stack.enter_context(ndarray_shared(shared_data0))
|
||||
for shared_data0 in shared_data))
|
||||
else:
|
||||
data.append(stack.enter_context(ndarray_shared(shared_data)))
|
||||
if multi_input:
|
||||
data = tuple(np.concatenate(data0)
|
||||
for data0 in zip(*data))
|
||||
else:
|
||||
data = np.concatenate(data)
|
||||
#result = self.model.predict(data, verbose=0)
|
||||
# faster, less VRAM
|
||||
result = self.model.predict_on_batch(data)
|
||||
results = np.split(result, len(jobs))
|
||||
if isinstance(result, tuple):
|
||||
multi_output = True
|
||||
results = zip(*(np.split(result0, len(jobs))
|
||||
for result0 in result))
|
||||
else:
|
||||
multi_output = False
|
||||
results = np.split(result, len(jobs))
|
||||
#self.logger.debug("sharing result array for '%d'", jobid)
|
||||
with ExitStack() as stack:
|
||||
for jobid, result in zip(jobs, results):
|
||||
|
|
@@ -160,7 +197,11 @@ class Predictor(mp.context.SpawnProcess):
|
|||
# but don't want to wait either, so track closing
|
||||
# context per job, and wait for closable signal
|
||||
# from client
|
||||
result = stack.enter_context(share_ndarray(result))
|
||||
if multi_output:
|
||||
result = tuple(stack.enter_context(share_ndarray(result0))
|
||||
for result0 in result)
|
||||
else:
|
||||
result = stack.enter_context(share_ndarray(result))
|
||||
closing[jobid] = stack.pop_all()
|
||||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue