mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
Predictor: handle multi-input and/or multi-output cases
This commit is contained in:
parent
c79b73dcc8
commit
a391ee24e6
1 changed file with 51 additions and 10 deletions
|
|
@@ -1,5 +1,5 @@
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Tuple, Union
|
||||||
import logging
|
import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
|
@@ -8,6 +8,7 @@ import numpy as np
|
||||||
from .utils.shm import share_ndarray, ndarray_shared
|
from .utils.shm import share_ndarray, ndarray_shared
|
||||||
|
|
||||||
QSIZE = 200
|
QSIZE = 200
|
||||||
|
ArrayT = Union[np.ndarray, Tuple[np.ndarray]]
|
||||||
|
|
||||||
|
|
||||||
class Predictor(mp.context.SpawnProcess):
|
class Predictor(mp.context.SpawnProcess):
|
||||||
|
|
@@ -40,10 +41,10 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
def input_shape(self):
|
def input_shape(self):
|
||||||
return self({})
|
return self({})
|
||||||
|
|
||||||
def predict(self, data: dict, verbose=0):
|
def predict(self, data: ArrayT, verbose=0) -> ArrayT:
|
||||||
return self(data)
|
return self(data)
|
||||||
|
|
||||||
def __call__(self, data: dict):
|
def __call__(self, data: Union[ArrayT, Dict]) -> Union[ArrayT, Tuple]:
|
||||||
# unusable as per python/cpython#79967
|
# unusable as per python/cpython#79967
|
||||||
#with self.jobid.get_lock():
|
#with self.jobid.get_lock():
|
||||||
# would work, but not public:
|
# would work, but not public:
|
||||||
|
|
@@ -55,7 +56,15 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
self.taskq.put((jobid, data))
|
self.taskq.put((jobid, data))
|
||||||
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name)
|
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name)
|
||||||
return self.result(jobid)
|
return self.result(jobid)
|
||||||
with share_ndarray(data) as shared_data:
|
with ExitStack() as stack:
|
||||||
|
if isinstance(data, tuple):
|
||||||
|
# multi-input
|
||||||
|
shared_data = []
|
||||||
|
for data0 in data:
|
||||||
|
shared_data.append(stack.enter_context(share_ndarray(data0)))
|
||||||
|
shared_data = tuple(shared_data)
|
||||||
|
else:
|
||||||
|
shared_data = stack.enter_context(share_ndarray(data))
|
||||||
self.taskq.put((jobid, shared_data))
|
self.taskq.put((jobid, shared_data))
|
||||||
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data)
|
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data)
|
||||||
return self.result(jobid)
|
return self.result(jobid)
|
||||||
|
|
@@ -67,6 +76,14 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
result = self.results.pop(jobid)
|
result = self.results.pop(jobid)
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
raise Exception(f"predictor {self.name} failed for {jobid}") from result
|
raise Exception(f"predictor {self.name} failed for {jobid}") from result
|
||||||
|
elif isinstance(result, tuple) and isinstance(result[0], dict):
|
||||||
|
# multi-output
|
||||||
|
result1 = []
|
||||||
|
for result0 in result:
|
||||||
|
with ndarray_shared(result0) as shared_result0:
|
||||||
|
result1.append(np.copy(shared_result0))
|
||||||
|
result = result1
|
||||||
|
self.closable.append(jobid)
|
||||||
elif isinstance(result, dict):
|
elif isinstance(result, dict):
|
||||||
with ndarray_shared(result) as shared_result:
|
with ndarray_shared(result) as shared_result:
|
||||||
result = np.copy(shared_result)
|
result = np.copy(shared_result)
|
||||||
|
|
@@ -111,6 +128,7 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
"binarization": 4,
|
"binarization": 4,
|
||||||
"enhancement": 4,
|
"enhancement": 4,
|
||||||
"reading_order": 4,
|
"reading_order": 4,
|
||||||
|
"ocr": 8,
|
||||||
# medium size (672x672x3)...
|
# medium size (672x672x3)...
|
||||||
"textline": 2,
|
"textline": 2,
|
||||||
# large models...
|
# large models...
|
||||||
|
|
@@ -126,8 +144,13 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
self.resultq.put((jobid, result))
|
self.resultq.put((jobid, result))
|
||||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||||
else:
|
else:
|
||||||
|
if isinstance(shared_data, tuple):
|
||||||
|
multi_input = True
|
||||||
|
batch_size = shared_data[0]['shape'][0]
|
||||||
|
else:
|
||||||
|
multi_input = False
|
||||||
|
batch_size = shared_data['shape'][0]
|
||||||
tasks = [(jobid, shared_data)]
|
tasks = [(jobid, shared_data)]
|
||||||
batch_size = shared_data['shape'][0]
|
|
||||||
while (not self.taskq.empty() and
|
while (not self.taskq.empty() and
|
||||||
# climb to target batch size
|
# climb to target batch size
|
||||||
batch_size * len(tasks) < REBATCH_SIZE):
|
batch_size * len(tasks) < REBATCH_SIZE):
|
||||||
|
|
@@ -136,7 +159,7 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
# add to our batch
|
# add to our batch
|
||||||
tasks.append((jobid0, shared_data0))
|
tasks.append((jobid0, shared_data0))
|
||||||
else:
|
else:
|
||||||
# immediately anser
|
# immediately answer
|
||||||
self.resultq.put((jobid0, self.model.input_shape))
|
self.resultq.put((jobid0, self.model.input_shape))
|
||||||
if len(tasks) > 1:
|
if len(tasks) > 1:
|
||||||
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
|
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
|
||||||
|
|
@@ -147,12 +170,26 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
for jobid, shared_data in tasks:
|
for jobid, shared_data in tasks:
|
||||||
#self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data)
|
#self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data)
|
||||||
jobs.append(jobid)
|
jobs.append(jobid)
|
||||||
data.append(stack.enter_context(ndarray_shared(shared_data)))
|
if multi_input:
|
||||||
data = np.concatenate(data)
|
data.append(tuple(stack.enter_context(ndarray_shared(shared_data0))
|
||||||
|
for shared_data0 in shared_data))
|
||||||
|
else:
|
||||||
|
data.append(stack.enter_context(ndarray_shared(shared_data)))
|
||||||
|
if multi_input:
|
||||||
|
data = tuple(np.concatenate(data0)
|
||||||
|
for data0 in zip(*data))
|
||||||
|
else:
|
||||||
|
data = np.concatenate(data)
|
||||||
#result = self.model.predict(data, verbose=0)
|
#result = self.model.predict(data, verbose=0)
|
||||||
# faster, less VRAM
|
# faster, less VRAM
|
||||||
result = self.model.predict_on_batch(data)
|
result = self.model.predict_on_batch(data)
|
||||||
results = np.split(result, len(jobs))
|
if isinstance(result, tuple):
|
||||||
|
multi_output = True
|
||||||
|
results = zip(*(np.split(result0, len(jobs))
|
||||||
|
for result0 in result))
|
||||||
|
else:
|
||||||
|
multi_output = False
|
||||||
|
results = np.split(result, len(jobs))
|
||||||
#self.logger.debug("sharing result array for '%d'", jobid)
|
#self.logger.debug("sharing result array for '%d'", jobid)
|
||||||
with ExitStack() as stack:
|
with ExitStack() as stack:
|
||||||
for jobid, result in zip(jobs, results):
|
for jobid, result in zip(jobs, results):
|
||||||
|
|
@@ -160,7 +197,11 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
# but don't want to wait either, so track closing
|
# but don't want to wait either, so track closing
|
||||||
# context per job, and wait for closable signal
|
# context per job, and wait for closable signal
|
||||||
# from client
|
# from client
|
||||||
result = stack.enter_context(share_ndarray(result))
|
if multi_output:
|
||||||
|
result = tuple(stack.enter_context(share_ndarray(result0))
|
||||||
|
for result0 in result)
|
||||||
|
else:
|
||||||
|
result = stack.enter_context(share_ndarray(result))
|
||||||
closing[jobid] = stack.pop_all()
|
closing[jobid] = stack.pop_all()
|
||||||
self.resultq.put((jobid, result))
|
self.resultq.put((jobid, result))
|
||||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue