diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 1f7d585..f3229d9 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -35,22 +35,15 @@ import numpy as np import shapely.affinity from scipy.signal import find_peaks from scipy.ndimage import gaussian_filter1d -from ocrd_utils import tf_disable_interactive_logs import statistics -tf_disable_interactive_logs() - -import tensorflow as tf -try: - import torch -except ImportError: - torch = None try: import matplotlib.pyplot as plt except ImportError: plt = None from .model_zoo import EynollahModelZoo +from .predictor import Predictor from .utils.contour import ( filter_contours_area_of_image, filter_contours_area_of_image_tables, @@ -129,7 +122,7 @@ class Eynollah: logger : Optional[logging.Logger] = None, ): self.logger = logger or logging.getLogger('eynollah') - self.model_zoo = model_zoo + self.model_zoo = Predictor(self.logger, model_zoo) self.plotter = None self.reading_order_machine_based = reading_order_machine_based @@ -169,12 +162,6 @@ class Eynollah: t_start = time.time() - try: - for device in tf.config.list_physical_devices('GPU'): - tf.config.experimental.set_memory_growth(device, True) - except: - self.logger.warning("no GPU device available") - self.logger.info("Loading models...") self.setup_models() self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)") @@ -199,26 +186,21 @@ class Eynollah: if self.reading_order_machine_based: loadable.append("reading_order") if self.tables: - loadable.append(("table")) + loadable.append("table") self.model_zoo.load_models(*loadable) + for model in loadable: + # cache and retrieve output shapes + self.model_zoo.get(model).output_shape def __del__(self): - if hasattr(self, 'executor') and getattr(self, 'executor'): - assert self.executor - self.executor.shutdown() - self.executor = None - self.model_zoo.shutdown() - - @property - def device(self): - # TODO why here and why only for tr? 
- assert torch - if torch.cuda.is_available(): - self.logger.info("Using GPU acceleration") - return torch.device("cuda:0") - self.logger.info("Using CPU processing") - return torch.device("cpu") + if executor := getattr(self, 'executor', None): + executor.shutdown() + del self.executor + if model_zoo := getattr(self, 'model_zoo', None): + if shutdown := getattr(model_zoo, 'shutdown', None): + shutdown() + del self.model_zoo def cache_images(self, image_filename=None, image_pil=None, dpi=None): ret = {} @@ -535,8 +517,7 @@ class Eynollah: ): self.logger.debug("enter do_prediction (patches=%d)", patches) - img_height_model = model.layers[-1].output_shape[1] - img_width_model = model.layers[-1].output_shape[2] + _, img_height_model, img_width_model, _ = model.output_shape img_h_page = img.shape[0] img_w_page = img.shape[1] @@ -736,8 +717,7 @@ class Eynollah: ): self.logger.debug("enter do_prediction_new_concept (patches=%d)", patches) - img_height_model = model.layers[-1].output_shape[1] - img_width_model = model.layers[-1].output_shape[2] + _, img_height_model, img_width_model, _ = model.output_shape img = img / 255.0 img = img.astype(np.float16) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index ca5de05..8638b65 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -90,6 +90,17 @@ class EynollahModelZoo: """ Load all models by calling load_model and return a dictionary mapping model_category to loaded model """ + import tensorflow as tf + cuda = False + try: + for device in tf.config.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(device, True) + cuda = True + self.logger.info("using GPU %s", device.name) + except RuntimeError: + self.logger.exception("cannot configure GPU devices") + if not cuda: + self.logger.warning("no GPU device available") ret = {} for load_args in all_load_args: if isinstance(load_args, str): diff --git a/src/eynollah/predictor.py 
# src/eynollah/predictor.py
import threading
from contextlib import ExitStack
from functools import lru_cache
from queue import Empty
from typing import List
import logging
import logging.handlers
import multiprocessing as mp
import numpy as np

from .utils.shm import share_ndarray, ndarray_shared

# capacity of the task and result queues (the log queue holds 100x as much)
QSIZE = 200


class Predictor(mp.context.SpawnProcess):
    """
    singleton subprocess solely responsible for prediction with TensorFlow,
    communicates with any number of worker processes,
    acts as a shallow replacement for EynollahModelZoo

    Callers submit a job with ``predictor(model_name, data)``; the job is put
    on ``taskq``, processed in ``run()`` inside the subprocess, and the result
    comes back through ``resultq`` / ``results``, keyed by a job id.

    NOTE: the ``terminate`` attribute (an Event) shadows the inherited
    ``Process.terminate()`` method on instances of this class.
    """
    class SingleModelPredictor:
        """
        stand-in for a single loaded model: exposes ``name``,
        ``output_shape`` and ``predict()``, all of which are delegated
        to the owning Predictor under the given model name
        """
        def __init__(self, predictor: 'Predictor', model: str):
            self.predictor = predictor
            self.model = model

        @property
        def name(self):
            return self.model

        @property
        def output_shape(self):
            # an empty payload is the "query the output shape" message
            return self.predictor(self.model, {})

        def predict(self, data: dict, verbose=0):
            # verbose is accepted for call compatibility but not forwarded
            return self.predictor(self.model, data)

    def __init__(self, logger, model_zoo):
        """
        Create the queues and shared state; the subprocess itself is only
        started by ``load_models()``.

        Args:
            logger: logger whose level and handlers are reused; records
                arriving on ``logq`` are dispatched to its handlers.
            model_zoo: object providing ``load_models()``, ``get()`` and
                ``shutdown()``; it is used inside the subprocess.
        """
        self.logger = logger
        self.loglevel = logger.level
        self.model_zoo = model_zoo
        ctxt = mp.get_context('spawn')
        # Created from the spawn context (not the platform default) because
        # this object is pickled into the spawned subprocess, and locks made
        # by a fork context cannot be passed to a spawn process.
        self.jobid = ctxt.Value('i', 0)
        self.taskq = ctxt.Queue(maxsize=QSIZE)
        self.resultq = ctxt.Queue(maxsize=QSIZE)
        self.logq = ctxt.Queue(maxsize=QSIZE * 100)
        # QueueListener.start() returns None, so construct first, then start.
        # The listener stays a local (not an attribute): it owns a thread,
        # and everything on self gets pickled when the process is started.
        log_listener = logging.handlers.QueueListener(
            self.logq, *self.logger.handlers,
            respect_handler_level=True)
        log_listener.start()
        self.terminate = ctxt.Event()
        ctxt = mp.get_context('fork') # ocrd.Processor will fork workers
        # one manager process serves both shared containers
        manager = ctxt.Manager()
        self.results = manager.dict()
        self.closable = manager.list()
        super().__init__(name="EynollahPredictor", daemon=True)

    @lru_cache
    def get(self, model: str):
        """Return the (cached) SingleModelPredictor for the named model."""
        return Predictor.SingleModelPredictor(self, model)

    def __call__(self, model: str, data: dict):
        """
        Submit one job for ``model`` and block until its result is available.

        An empty ``data`` is a query for the model's output shape; anything
        else is shared via ``share_ndarray`` for the duration of the call.

        Raises:
            Exception: if the prediction failed or the predictor terminated.
        """
        with self.jobid.get_lock():
            self.jobid.value += 1
            jobid = self.jobid.value
        # len() rather than truthiness: presumably data can be a numpy array
        # (it is handed to share_ndarray) -- TODO confirm against callers
        if not len(data):
            self.taskq.put((jobid, model, data))
            return self.result(jobid)
        with share_ndarray(data) as shared_data:
            self.taskq.put((jobid, model, shared_data))
            # wait inside the context so the shared input outlives the job
            return self.result(jobid)

    def result(self, jobid):
        """
        Wait for and return the result of job ``jobid``.

        Results dequeued on behalf of other jobs are parked in the shared
        ``results`` dict so that their own callers can pick them up.

        Raises:
            Exception: chained from the predictor's exception if the job
                failed, or if the predictor terminated while waiting.
        """
        while not self.terminate.is_set():
            if jobid in self.results:
                result = self.results.pop(jobid)
                if isinstance(result, Exception):
                    raise Exception(f"predictor failed for {jobid}") from result
                elif isinstance(result, dict):
                    with ndarray_shared(result) as shared_result:
                        result = np.copy(shared_result)
                    # tell the predictor it may release this job's shared result
                    self.closable.append(jobid)
                return result
            try:
                # Must not rebind `jobid` here: whatever arrives may belong
                # to a different caller and must only be stored, not returned.
                other_jobid, other_result = self.resultq.get(timeout=0.7)
            except Empty:
                continue
            self.results[other_jobid] = other_result
        raise Exception(f"predictor terminated while waiting on results for {jobid}")

    def run(self):
        """
        Main loop of the subprocess: load the models, then serve tasks from
        ``taskq`` until ``terminate`` is set, answering on ``resultq``.
        """
        try:
            self.setup() # fill model_zoo etc
        except Exception:
            self.logger.exception("setup failed")
            self.terminate.set()
        # jobid -> ExitStack keeping that job's shared result alive
        closing = {}
        while not self.terminate.is_set():
            for jobid in list(self.closable):
                self.closable.remove(jobid)
                # tolerate unknown ids instead of dying with a KeyError
                stack = closing.pop(jobid, None)
                if stack is not None:
                    stack.close()
            try:
                jobid, model_name, shared_data = self.taskq.get(timeout=1.1)
            except Empty:
                continue
            try:
                model = self.model_zoo.get(model_name)
                if not len(shared_data):
                    # non-input msg: model query
                    result = model.output_shape
                else:
                    with ndarray_shared(shared_data) as data:
                        result = model.predict(data, verbose=0)
                    with ExitStack() as stack:
                        # we don't know when the result will be received,
                        # but don't want to wait either, so keep the shared
                        # array open until the receiver reports via closable
                        result = stack.enter_context(share_ndarray(result))
                        closing[jobid] = stack.pop_all()
            except Exception as e:
                self.logger.error("prediction failed: %s", e.__class__.__name__)
                # hand the exception to the caller, who re-raises it
                result = e
            self.resultq.put((jobid, result))
        self.resultq.close()
        self.resultq.cancel_join_thread()

    def load_models(self, *loadable: str):
        """Remember which models to load and start the subprocess."""
        self.loadable = loadable
        self.start()

    def setup(self):
        """Route logging through ``logq`` and load the requested models."""
        logging.root.handlers = [logging.handlers.QueueHandler(self.logq)]
        self.logger.setLevel(self.loglevel)
        self.model_zoo.load_models(*self.loadable)

    def shutdown(self):
        """Signal termination (from the top-level process only) and shut down the model zoo."""
        # do not terminate from forked processor instances
        if mp.parent_process() is None:
            self.terminate.set()
            self.taskq.close()
            self.taskq.cancel_join_thread()
        self.model_zoo.shutdown()

    def __del__(self):
        self.shutdown()