mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-24 16:12:03 +01:00
predictor: fix spawn vs fork / parent vs child contexts
This commit is contained in:
parent
64281768a9
commit
800c55b826
2 changed files with 31 additions and 21 deletions
|
|
@ -190,8 +190,9 @@ class Eynollah:
|
||||||
|
|
||||||
self.model_zoo.load_models(*loadable)
|
self.model_zoo.load_models(*loadable)
|
||||||
for model in loadable:
|
for model in loadable:
|
||||||
# cache and retrieve output shapes
|
# retrieve and cache output shapes
|
||||||
self.model_zoo.get(model).output_shape
|
self.logger.debug("model %s has output shape %s", model,
|
||||||
|
self.model_zoo.get(model).output_shape)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if executor := getattr(self, 'executor', None):
|
if executor := getattr(self, 'executor', None):
|
||||||
|
|
|
||||||
|
|
@ -36,19 +36,19 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
|
|
||||||
def __init__(self, logger, model_zoo):
|
def __init__(self, logger, model_zoo):
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.loglevel = logger.level
|
|
||||||
self.model_zoo = model_zoo
|
self.model_zoo = model_zoo
|
||||||
ctxt = mp.get_context('spawn')
|
ctxt = mp.get_context('spawn')
|
||||||
self.jobid = ctxt.Value('i', 0)
|
|
||||||
self.taskq = ctxt.Queue(maxsize=QSIZE)
|
self.taskq = ctxt.Queue(maxsize=QSIZE)
|
||||||
self.resultq = ctxt.Queue(maxsize=QSIZE)
|
self.resultq = ctxt.Queue(maxsize=QSIZE)
|
||||||
self.logq = ctxt.Queue(maxsize=QSIZE * 100)
|
self.logq = ctxt.Queue(maxsize=QSIZE * 100)
|
||||||
log_listener = logging.handlers.QueueListener(
|
logging.handlers.QueueListener(
|
||||||
self.logq, *self.logger.handlers,
|
self.logq, *(
|
||||||
respect_handler_level=True).start()
|
# as per ocrd_utils.initLogging():
|
||||||
|
logging.root.handlers +
|
||||||
|
# as per eynollah_cli.main():
|
||||||
|
self.logger.handlers
|
||||||
|
), respect_handler_level=False).start()
|
||||||
self.stopped = ctxt.Event()
|
self.stopped = ctxt.Event()
|
||||||
ctxt = mp.get_context('fork') # ocrd.Processor will fork workers
|
|
||||||
self.results = ctxt.Manager().dict()
|
|
||||||
self.closable = ctxt.Manager().list()
|
self.closable = ctxt.Manager().list()
|
||||||
super().__init__(name="EynollahPredictor", daemon=True)
|
super().__init__(name="EynollahPredictor", daemon=True)
|
||||||
|
|
||||||
|
|
@ -57,15 +57,20 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
return Predictor.SingleModelPredictor(self, model)
|
return Predictor.SingleModelPredictor(self, model)
|
||||||
|
|
||||||
def __call__(self, model: str, data: dict):
|
def __call__(self, model: str, data: dict):
|
||||||
with self.jobid.get_lock():
|
# unusable as per python/cpython#79967
|
||||||
|
#with self.jobid.get_lock():
|
||||||
|
# would work, but not public:
|
||||||
|
#with self.jobid._mutex:
|
||||||
|
with self.joblock:
|
||||||
self.jobid.value += 1
|
self.jobid.value += 1
|
||||||
jobid = self.jobid.value
|
jobid = self.jobid.value
|
||||||
if not len(data):
|
if not len(data):
|
||||||
self.taskq.put((jobid, model, data))
|
self.taskq.put((jobid, model, data))
|
||||||
|
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, model)
|
||||||
return self.result(jobid)
|
return self.result(jobid)
|
||||||
with share_ndarray(data) as shared_data:
|
with share_ndarray(data) as shared_data:
|
||||||
self.taskq.put((jobid, model, shared_data))
|
self.taskq.put((jobid, model, shared_data))
|
||||||
#self.logger.debug("sent task '%d'", jobid)
|
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, model, shared_data)
|
||||||
return self.result(jobid)
|
return self.result(jobid)
|
||||||
|
|
||||||
def result(self, jobid):
|
def result(self, jobid):
|
||||||
|
|
@ -81,11 +86,11 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
self.closable.append(jobid)
|
self.closable.append(jobid)
|
||||||
return result
|
return result
|
||||||
try:
|
try:
|
||||||
jobid, result = self.resultq.get(timeout=0.7)
|
jobid0, result = self.resultq.get(timeout=0.7)
|
||||||
except mp.queues.Empty:
|
except mp.queues.Empty:
|
||||||
continue
|
continue
|
||||||
#self.logger.debug("storing results for '%d'", jobid)
|
#self.logger.debug("storing results for '%d': '%s'", jobid0, result)
|
||||||
self.results[jobid] = result
|
self.results[jobid0] = result
|
||||||
raise Exception(f"predictor terminated while waiting on results for {jobid}")
|
raise Exception(f"predictor terminated while waiting on results for {jobid}")
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
@ -106,13 +111,13 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
jobid, model, shared_data = self.taskq.get(timeout=1.1)
|
jobid, model, shared_data = self.taskq.get(timeout=1.1)
|
||||||
except mp.queues.Empty:
|
except mp.queues.Empty:
|
||||||
continue
|
continue
|
||||||
#self.logger.debug("predicting '%d'", jobid)
|
|
||||||
try:
|
try:
|
||||||
model = self.model_zoo.get(model)
|
model = self.model_zoo.get(model)
|
||||||
if not len(shared_data):
|
if not len(shared_data):
|
||||||
# non-input msg: model query
|
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, model)
|
||||||
result = model.output_shape
|
result = model.output_shape
|
||||||
else:
|
else:
|
||||||
|
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data)
|
||||||
with ndarray_shared(shared_data) as data:
|
with ndarray_shared(shared_data) as data:
|
||||||
result = model.predict(data, verbose=0)
|
result = model.predict(data, verbose=0)
|
||||||
#self.logger.debug("sharing result array for '%d'", jobid)
|
#self.logger.debug("sharing result array for '%d'", jobid)
|
||||||
|
|
@ -122,22 +127,26 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
result = stack.enter_context(share_ndarray(result))
|
result = stack.enter_context(share_ndarray(result))
|
||||||
closing[jobid] = stack.pop_all()
|
closing[jobid] = stack.pop_all()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error("prediction failed: %s", e.__class__.__name__)
|
self.logger.error("prediction '%d' failed: %s", jobid, e.__class__.__name__)
|
||||||
result = e
|
result = e
|
||||||
self.resultq.put((jobid, result))
|
self.resultq.put((jobid, result))
|
||||||
#self.logger.debug("sent result for '%d'", jobid)
|
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||||
close_all()
|
close_all()
|
||||||
#self.logger.debug("predictor terminated")
|
#self.logger.debug("predictor terminated")
|
||||||
|
|
||||||
def load_models(self, *loadable: List[str]):
|
def load_models(self, *loadable: List[str]):
|
||||||
self.loadable = loadable
|
self.loadable = loadable
|
||||||
self.start()
|
self.start() # call run() in subprocess
|
||||||
# parent context here
|
# parent context here
|
||||||
del self.model_zoo
|
del self.model_zoo # only in subprocess
|
||||||
|
ctxt = mp.get_context('fork') # ocrd.Processor will fork workers
|
||||||
|
mngr = ctxt.Manager()
|
||||||
|
self.jobid = mngr.Value('i', 0)
|
||||||
|
self.joblock = mngr.Lock()
|
||||||
|
self.results = mngr.dict()
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logging.root.handlers = [logging.handlers.QueueHandler(self.logq)]
|
logging.root.handlers = [logging.handlers.QueueHandler(self.logq)]
|
||||||
self.logger.setLevel(self.loglevel)
|
|
||||||
self.model_zoo.load_models(*self.loadable)
|
self.model_zoo.load_models(*self.loadable)
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue