Mirror of https://github.com/qurator-spk/eynollah.git, synced 2026-03-24 08:02:45 +01:00.
predictor: fix spawn vs fork / parent vs child contexts
This commit is contained in:
parent
64281768a9
commit
800c55b826
2 changed files with 31 additions and 21 deletions
|
|
@@ -190,8 +190,9 @@ class Eynollah:
|
|||
|
||||
self.model_zoo.load_models(*loadable)
|
||||
for model in loadable:
|
||||
# cache and retrieve output shapes
|
||||
self.model_zoo.get(model).output_shape
|
||||
# retrieve and cache output shapes
|
||||
self.logger.debug("model %s has output shape %s", model,
|
||||
self.model_zoo.get(model).output_shape)
|
||||
|
||||
def __del__(self):
|
||||
if executor := getattr(self, 'executor', None):
|
||||
|
|
|
|||
|
|
@@ -36,19 +36,19 @@ class Predictor(mp.context.SpawnProcess):
|
|||
|
||||
def __init__(self, logger, model_zoo):
|
||||
self.logger = logger
|
||||
self.loglevel = logger.level
|
||||
self.model_zoo = model_zoo
|
||||
ctxt = mp.get_context('spawn')
|
||||
self.jobid = ctxt.Value('i', 0)
|
||||
self.taskq = ctxt.Queue(maxsize=QSIZE)
|
||||
self.resultq = ctxt.Queue(maxsize=QSIZE)
|
||||
self.logq = ctxt.Queue(maxsize=QSIZE * 100)
|
||||
log_listener = logging.handlers.QueueListener(
|
||||
self.logq, *self.logger.handlers,
|
||||
respect_handler_level=True).start()
|
||||
logging.handlers.QueueListener(
|
||||
self.logq, *(
|
||||
# as per ocrd_utils.initLogging():
|
||||
logging.root.handlers +
|
||||
# as per eynollah_cli.main():
|
||||
self.logger.handlers
|
||||
), respect_handler_level=False).start()
|
||||
self.stopped = ctxt.Event()
|
||||
ctxt = mp.get_context('fork') # ocrd.Processor will fork workers
|
||||
self.results = ctxt.Manager().dict()
|
||||
self.closable = ctxt.Manager().list()
|
||||
super().__init__(name="EynollahPredictor", daemon=True)
|
||||
|
||||
|
|
@@ -57,15 +57,20 @@ class Predictor(mp.context.SpawnProcess):
|
|||
return Predictor.SingleModelPredictor(self, model)
|
||||
|
||||
def __call__(self, model: str, data: dict):
|
||||
with self.jobid.get_lock():
|
||||
# unusable as per python/cpython#79967
|
||||
#with self.jobid.get_lock():
|
||||
# would work, but not public:
|
||||
#with self.jobid._mutex:
|
||||
with self.joblock:
|
||||
self.jobid.value += 1
|
||||
jobid = self.jobid.value
|
||||
if not len(data):
|
||||
self.taskq.put((jobid, model, data))
|
||||
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, model)
|
||||
return self.result(jobid)
|
||||
with share_ndarray(data) as shared_data:
|
||||
self.taskq.put((jobid, model, shared_data))
|
||||
#self.logger.debug("sent task '%d'", jobid)
|
||||
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, model, shared_data)
|
||||
return self.result(jobid)
|
||||
|
||||
def result(self, jobid):
|
||||
|
|
@@ -81,11 +86,11 @@ class Predictor(mp.context.SpawnProcess):
|
|||
self.closable.append(jobid)
|
||||
return result
|
||||
try:
|
||||
jobid, result = self.resultq.get(timeout=0.7)
|
||||
jobid0, result = self.resultq.get(timeout=0.7)
|
||||
except mp.queues.Empty:
|
||||
continue
|
||||
#self.logger.debug("storing results for '%d'", jobid)
|
||||
self.results[jobid] = result
|
||||
#self.logger.debug("storing results for '%d': '%s'", jobid0, result)
|
||||
self.results[jobid0] = result
|
||||
raise Exception(f"predictor terminated while waiting on results for {jobid}")
|
||||
|
||||
def run(self):
|
||||
|
|
@@ -106,13 +111,13 @@ class Predictor(mp.context.SpawnProcess):
|
|||
jobid, model, shared_data = self.taskq.get(timeout=1.1)
|
||||
except mp.queues.Empty:
|
||||
continue
|
||||
#self.logger.debug("predicting '%d'", jobid)
|
||||
try:
|
||||
model = self.model_zoo.get(model)
|
||||
if not len(shared_data):
|
||||
# non-input msg: model query
|
||||
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, model)
|
||||
result = model.output_shape
|
||||
else:
|
||||
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data)
|
||||
with ndarray_shared(shared_data) as data:
|
||||
result = model.predict(data, verbose=0)
|
||||
#self.logger.debug("sharing result array for '%d'", jobid)
|
||||
|
|
@@ -122,22 +127,26 @@ class Predictor(mp.context.SpawnProcess):
|
|||
result = stack.enter_context(share_ndarray(result))
|
||||
closing[jobid] = stack.pop_all()
|
||||
except Exception as e:
|
||||
self.logger.error("prediction failed: %s", e.__class__.__name__)
|
||||
self.logger.error("prediction '%d' failed: %s", jobid, e.__class__.__name__)
|
||||
result = e
|
||||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d'", jobid)
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
close_all()
|
||||
#self.logger.debug("predictor terminated")
|
||||
|
||||
def load_models(self, *loadable: List[str]):
|
||||
self.loadable = loadable
|
||||
self.start()
|
||||
self.start() # call run() in subprocess
|
||||
# parent context here
|
||||
del self.model_zoo
|
||||
del self.model_zoo # only in subprocess
|
||||
ctxt = mp.get_context('fork') # ocrd.Processor will fork workers
|
||||
mngr = ctxt.Manager()
|
||||
self.jobid = mngr.Value('i', 0)
|
||||
self.joblock = mngr.Lock()
|
||||
self.results = mngr.dict()
|
||||
|
||||
def setup(self):
|
||||
logging.root.handlers = [logging.handlers.QueueHandler(self.logq)]
|
||||
self.logger.setLevel(self.loglevel)
|
||||
self.model_zoo.load_models(*self.loadable)
|
||||
|
||||
def shutdown(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue