predictor: fix spawn vs fork / parent vs child contexts

This commit is contained in:
Robert Sachunsky 2026-03-13 02:42:16 +01:00
parent 64281768a9
commit 800c55b826
2 changed files with 31 additions and 21 deletions

View file

@@ -190,8 +190,9 @@ class Eynollah:
self.model_zoo.load_models(*loadable) self.model_zoo.load_models(*loadable)
for model in loadable: for model in loadable:
# cache and retrieve output shapes # retrieve and cache output shapes
self.model_zoo.get(model).output_shape self.logger.debug("model %s has output shape %s", model,
self.model_zoo.get(model).output_shape)
def __del__(self): def __del__(self):
if executor := getattr(self, 'executor', None): if executor := getattr(self, 'executor', None):

View file

@@ -36,19 +36,19 @@ class Predictor(mp.context.SpawnProcess):
def __init__(self, logger, model_zoo): def __init__(self, logger, model_zoo):
self.logger = logger self.logger = logger
self.loglevel = logger.level
self.model_zoo = model_zoo self.model_zoo = model_zoo
ctxt = mp.get_context('spawn') ctxt = mp.get_context('spawn')
self.jobid = ctxt.Value('i', 0)
self.taskq = ctxt.Queue(maxsize=QSIZE) self.taskq = ctxt.Queue(maxsize=QSIZE)
self.resultq = ctxt.Queue(maxsize=QSIZE) self.resultq = ctxt.Queue(maxsize=QSIZE)
self.logq = ctxt.Queue(maxsize=QSIZE * 100) self.logq = ctxt.Queue(maxsize=QSIZE * 100)
log_listener = logging.handlers.QueueListener( logging.handlers.QueueListener(
self.logq, *self.logger.handlers, self.logq, *(
respect_handler_level=True).start() # as per ocrd_utils.initLogging():
logging.root.handlers +
# as per eynollah_cli.main():
self.logger.handlers
), respect_handler_level=False).start()
self.stopped = ctxt.Event() self.stopped = ctxt.Event()
ctxt = mp.get_context('fork') # ocrd.Processor will fork workers
self.results = ctxt.Manager().dict()
self.closable = ctxt.Manager().list() self.closable = ctxt.Manager().list()
super().__init__(name="EynollahPredictor", daemon=True) super().__init__(name="EynollahPredictor", daemon=True)
@@ -57,15 +57,20 @@ class Predictor(mp.context.SpawnProcess):
return Predictor.SingleModelPredictor(self, model) return Predictor.SingleModelPredictor(self, model)
def __call__(self, model: str, data: dict): def __call__(self, model: str, data: dict):
with self.jobid.get_lock(): # unusable as per python/cpython#79967
#with self.jobid.get_lock():
# would work, but not public:
#with self.jobid._mutex:
with self.joblock:
self.jobid.value += 1 self.jobid.value += 1
jobid = self.jobid.value jobid = self.jobid.value
if not len(data): if not len(data):
self.taskq.put((jobid, model, data)) self.taskq.put((jobid, model, data))
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, model)
return self.result(jobid) return self.result(jobid)
with share_ndarray(data) as shared_data: with share_ndarray(data) as shared_data:
self.taskq.put((jobid, model, shared_data)) self.taskq.put((jobid, model, shared_data))
#self.logger.debug("sent task '%d'", jobid) #self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, model, shared_data)
return self.result(jobid) return self.result(jobid)
def result(self, jobid): def result(self, jobid):
@@ -81,11 +86,11 @@ class Predictor(mp.context.SpawnProcess):
self.closable.append(jobid) self.closable.append(jobid)
return result return result
try: try:
jobid, result = self.resultq.get(timeout=0.7) jobid0, result = self.resultq.get(timeout=0.7)
except mp.queues.Empty: except mp.queues.Empty:
continue continue
#self.logger.debug("storing results for '%d'", jobid) #self.logger.debug("storing results for '%d': '%s'", jobid0, result)
self.results[jobid] = result self.results[jobid0] = result
raise Exception(f"predictor terminated while waiting on results for {jobid}") raise Exception(f"predictor terminated while waiting on results for {jobid}")
def run(self): def run(self):
@@ -106,13 +111,13 @@ class Predictor(mp.context.SpawnProcess):
jobid, model, shared_data = self.taskq.get(timeout=1.1) jobid, model, shared_data = self.taskq.get(timeout=1.1)
except mp.queues.Empty: except mp.queues.Empty:
continue continue
#self.logger.debug("predicting '%d'", jobid)
try: try:
model = self.model_zoo.get(model) model = self.model_zoo.get(model)
if not len(shared_data): if not len(shared_data):
# non-input msg: model query #self.logger.debug("getting '%d' output shape of model '%s'", jobid, model)
result = model.output_shape result = model.output_shape
else: else:
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data)
with ndarray_shared(shared_data) as data: with ndarray_shared(shared_data) as data:
result = model.predict(data, verbose=0) result = model.predict(data, verbose=0)
#self.logger.debug("sharing result array for '%d'", jobid) #self.logger.debug("sharing result array for '%d'", jobid)
@@ -122,22 +127,26 @@ class Predictor(mp.context.SpawnProcess):
result = stack.enter_context(share_ndarray(result)) result = stack.enter_context(share_ndarray(result))
closing[jobid] = stack.pop_all() closing[jobid] = stack.pop_all()
except Exception as e: except Exception as e:
self.logger.error("prediction failed: %s", e.__class__.__name__) self.logger.error("prediction '%d' failed: %s", jobid, e.__class__.__name__)
result = e result = e
self.resultq.put((jobid, result)) self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d'", jobid) #self.logger.debug("sent result for '%d': %s", jobid, result)
close_all() close_all()
#self.logger.debug("predictor terminated") #self.logger.debug("predictor terminated")
def load_models(self, *loadable: List[str]): def load_models(self, *loadable: List[str]):
self.loadable = loadable self.loadable = loadable
self.start() self.start() # call run() in subprocess
# parent context here # parent context here
del self.model_zoo del self.model_zoo # only in subprocess
ctxt = mp.get_context('fork') # ocrd.Processor will fork workers
mngr = ctxt.Manager()
self.jobid = mngr.Value('i', 0)
self.joblock = mngr.Lock()
self.results = mngr.dict()
def setup(self): def setup(self):
logging.root.handlers = [logging.handlers.QueueHandler(self.logq)] logging.root.handlers = [logging.handlers.QueueHandler(self.logq)]
self.logger.setLevel(self.loglevel)
self.model_zoo.load_models(*self.loadable) self.model_zoo.load_models(*self.loadable)
def shutdown(self): def shutdown(self):