diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index db3d9d2..0cad343 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -44,7 +44,6 @@ except ImportError: plt = None from .model_zoo import EynollahModelZoo -from .predictor import Predictor from .utils.contour import ( filter_contours_area_of_image, filter_contours_area_of_image_tables, @@ -142,7 +141,7 @@ class Eynollah: logger : Optional[logging.Logger] = None, ): self.logger = logger or logging.getLogger('eynollah') - self.model_zoo = Predictor(self.logger, model_zoo) + self.model_zoo = model_zoo self.plotter = None self.reading_order_machine_based = reading_order_machine_based @@ -174,26 +173,34 @@ class Eynollah: # load models, depending on modes # (note: loading too many models can cause OOM on GPU/CUDA, # thus, we try set up the minimal configuration for the current mode) + # autosized variants: _resized or _patched (which one may depend on num_cols) + # (but _resized for full page images is too slow - better resize on CPU in numpy) loadable = [ "col_classifier", - "binarization", - #"enhancement", + #"enhancement", # todo: enhancement_patched "page", #"region" ] - loadable.append(("textline")) + if self.input_binary: + loadable.append("binarization") # todo: binarization_patched + loadable.append("textline_patched") # textline loadable.append("region_1_2") + loadable.append("region_1_2_patched") if self.full_layout: loadable.append("region_fl_np") - #loadable.append("region_fl") + #loadable.append("region_fl_patched") if self.reading_order_machine_based: - loadable.append("reading_order") + loadable.append("reading_order") # todo: reading_order_patched if self.tables: loadable.append("table") self.model_zoo.load_models(*loadable) for model in loadable: # retrieve and cache output shapes + if model.endswith(('_resized', '_patched')): + # autosized models do not have a predefined output_shape + # (and don't need one) + continue self.logger.debug("model %s has output 
shape %s", model, self.model_zoo.get(model).output_shape) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 8060006..92fedfb 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -5,21 +5,9 @@ from copy import deepcopy from pathlib import Path from typing import Dict, List, Optional, Tuple, Type, Union -os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 -from ocrd_utils import tf_disable_interactive_logs -tf_disable_interactive_logs() - -from tensorflow.keras.layers import StringLookup -from tensorflow.keras.models import Model as KerasModel -from tensorflow.keras.models import load_model from tabulate import tabulate -from ..patch_encoder import ( - PatchEncoder, - Patches, - wrap_layout_model_patched, - wrap_layout_model_resized, -) +from ..predictor import Predictor from .specs import EynollahModelSpecSet from .default_specs import DEFAULT_MODEL_SPECS from .types import AnyModel, T @@ -46,7 +34,7 @@ class EynollahModelZoo: self._overrides = [] if model_overrides: self.override_models(*model_overrides) - self._loaded: Dict[str, AnyModel] = {} + self._loaded: Dict[str, Predictor] = {} @property def model_overrides(self): @@ -90,34 +78,60 @@ class EynollahModelZoo: """ Load all models by calling load_model and return a dictionary mapping model_category to loaded model """ - import tensorflow as tf - cuda = False - try: - for device in tf.config.list_physical_devices('GPU'): - tf.config.experimental.set_memory_growth(device, True) - cuda = True - self.logger.info("using GPU %s", device.name) - except RuntimeError: - self.logger.exception("cannot configure GPU devices") - if not cuda: - self.logger.warning("no GPU device available") - ret = {} + ret = {} # cannot use self._loaded here, yet – first spawn all predictors for load_args in all_load_args: if isinstance(load_args, str): - ret[load_args] = self.load_model(load_args) + model_category = load_args + load_args = 
[model_category] else: - ret[load_args[0]] = self.load_model(*load_args) - return ret + model_category = load_args[0] + load_kwargs = {} + if model_category.endswith('_resized'): + load_args[0] = model_category[:-8] + load_kwargs["resized"] = True + elif model_category.endswith('_patched'): + load_args[0] = model_category[:-8] + load_kwargs["patched"] = True + ret[model_category] = Predictor(self.logger, self) + ret[model_category].load_model(*load_args, **load_kwargs) + self._loaded.update(ret) + return self._loaded def load_model( self, model_category: str, model_variant: str = '', model_path_override: Optional[str] = None, + patched: bool = False, + resized: bool = False, ) -> AnyModel: """ Load any model """ + os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 + from ocrd_utils import tf_disable_interactive_logs + tf_disable_interactive_logs() + + import tensorflow as tf + from tensorflow.keras.models import load_model + + from ..patch_encoder import ( + PatchEncoder, + Patches, + wrap_layout_model_patched, + wrap_layout_model_resized, + ) + cuda = False + try: + device = tf.config.list_physical_devices('GPU')[0] + tf.config.experimental.set_memory_growth(device, True) + cuda = True + self.logger.info("using GPU %s", device.name) + except RuntimeError: + self.logger.exception("cannot configure GPU devices") + if not cuda: + self.logger.warning("no GPU device available") + if model_path_override: self.override_models((model_category, model_variant, model_path_override)) model_path = self.model_path(model_category, model_variant) @@ -142,26 +156,26 @@ class EynollahModelZoo: model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches} ) model._name = model_category - self._loaded[model_category] = model - # autosized for full page images is too slow (better resize on CPU in numpy): - # if model_category in ['region_1_2', 'table', 'region_fl_np']: - # self._loaded[model_category + '_resized'] = 
wrap_layout_model_resized(model) - if model_category in ['region_1_2', 'textline']: - self._loaded[model_category + '_patched'] = wrap_layout_model_patched(model) - return model # type: ignore + if resized: + model = wrap_layout_model_resized(model) + model._name = model_category + '_resized' + elif patched: + model = wrap_layout_model_patched(model) + model._name = model_category + '_patched' + return model - def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T: + def get(self, model_category: str) -> Predictor: if model_category not in self._loaded: - raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"') - ret = self._loaded[model_category] - if model_type: - assert isinstance(ret, model_type) - return ret # type: ignore # FIXME: convince typing that we're returning generic type + raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"') + return self._loaded[model_category] def _load_ocr_model(self, variant: str) -> AnyModel: """ Load OCR model """ + from tensorflow.keras.models import Model as KerasModel + from tensorflow.keras.models import load_model + ocr_model_dir = self.model_path('ocr', variant) if variant == 'tr': from transformers import VisionEncoderDecoderModel @@ -183,10 +197,12 @@ class EynollahModelZoo: with open(self.model_path('num_to_char'), "r") as config_file: return json.load(config_file) - def _load_num_to_char(self) -> StringLookup: + def _load_num_to_char(self) -> 'StringLookup': """ Load decoder for OCR """ + from tensorflow.keras.layers import StringLookup + characters = self._load_characters() # Mapping characters to integers. 
char_to_num = StringLookup(vocabulary=characters, mask_token=None) @@ -225,4 +241,5 @@ class EynollahModelZoo: """ if hasattr(self, '_loaded') and getattr(self, '_loaded'): for needle in list(self._loaded.keys()): + self._loaded[needle].shutdown() del self._loaded[needle] diff --git a/src/eynollah/predictor.py b/src/eynollah/predictor.py index 8b46250..d6e149c 100644 --- a/src/eynollah/predictor.py +++ b/src/eynollah/predictor.py @@ -1,7 +1,5 @@ -import threading from contextlib import ExitStack -from functools import lru_cache -from typing import List +from typing import List, Dict import logging import logging.handlers import multiprocessing as mp @@ -16,27 +14,12 @@ class Predictor(mp.context.SpawnProcess): """ singleton subprocess solely responsible for prediction with TensorFlow, communicates with any number of worker processes, - acts as a shallow replacement for EynollahModelZoo + acting as a shallow replacement for various model types in EynollahModelZoo's + _loaded dict for each single model """ - class SingleModelPredictor: - """ - acts as a shallow replacement for EynollahModelZoo - """ - def __init__(self, predictor: 'Predictor', model: str): - self.predictor = predictor - self.model = model - @property - def name(self): - return self.model - @property - def output_shape(self): - return self.predictor(self.model, {}) - def predict(self, data: dict, verbose=0): - return self.predictor(self.model, data) - def __init__(self, logger, model_zoo): self.logger = logger - self.loglevel = logger.level + self.loglevel = logger.parent.level self.model_zoo = model_zoo ctxt = mp.get_context('spawn') self.taskq = ctxt.Queue(maxsize=QSIZE) @@ -47,17 +30,20 @@ class Predictor(mp.context.SpawnProcess): # as per ocrd_utils.initLogging(): logging.root.handlers + # as per eynollah_cli.main(): - self.logger.handlers + self.logger.parent.handlers ), respect_handler_level=False).start() self.stopped = ctxt.Event() self.closable = ctxt.Manager().list() 
super().__init__(name="EynollahPredictor", daemon=True) - @lru_cache - def get(self, model: str): - return Predictor.SingleModelPredictor(self, model) + @property + def output_shape(self): + return self({}) - def __call__(self, model: str, data: dict): + def predict(self, data: dict, verbose=0): + return self(data) + + def __call__(self, data: dict): # unusable as per python/cpython#79967 #with self.jobid.get_lock(): # would work, but not public: @@ -66,12 +52,12 @@ class Predictor(mp.context.SpawnProcess): self.jobid.value += 1 jobid = self.jobid.value if not len(data): - self.taskq.put((jobid, model, data)) - #self.logger.debug("sent shape query task '%d' for model '%s'", jobid, model) + self.taskq.put((jobid, data)) + #self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name) return self.result(jobid) with share_ndarray(data) as shared_data: - self.taskq.put((jobid, model, shared_data)) - #self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, model, shared_data) + self.taskq.put((jobid, shared_data)) + #self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data) return self.result(jobid) def result(self, jobid): @@ -80,7 +66,7 @@ class Predictor(mp.context.SpawnProcess): #self.logger.debug("received result for '%d'", jobid) result = self.results.pop(jobid) if isinstance(result, Exception): - raise Exception(f"predictor failed for {jobid}") from result + raise Exception(f"predictor {self.name} failed for {jobid}") from result elif isinstance(result, dict): with ndarray_shared(result) as shared_result: result = np.copy(shared_result) @@ -92,7 +78,7 @@ class Predictor(mp.context.SpawnProcess): continue #self.logger.debug("storing results for '%d': '%s'", jobid0, result) self.results[jobid0] = result - raise Exception(f"predictor terminated while waiting on results for {jobid}") + raise Exception(f"predictor {self.name} terminated while waiting on results for {jobid}") def run(self): try: 
@@ -100,6 +86,7 @@ class Predictor(mp.context.SpawnProcess): except Exception as e: self.logger.exception("setup failed") self.stopped.set() + return closing = {} def close_all(): for jobid in list(self.closable): @@ -110,63 +97,58 @@ class Predictor(mp.context.SpawnProcess): close_all() try: TIMEOUT = 4.5 # 1.1 too is greedy - not enough for rebatching - jobid, model, shared_data = self.taskq.get(timeout=TIMEOUT) + jobid, shared_data = self.taskq.get(timeout=TIMEOUT) except mp.queues.Empty: continue try: # up to what batch size fits into small (8GB) VRAM? # (notice we are not listing _resized/_patched models here, - # because here inputs/outputs will have varying shapes) + # because its inputs/outputs will have varying shapes) REBATCH_SIZE = { # small models (448x448)... - "col_classifier": 4, - "page": 4, - "binarization": 5, - "enhancement": 5, - "reading_order": 5, + "col_classifier": 2, + "page": 2, + "binarization": 4, + "enhancement": 4, + "reading_order": 4, # medium size (672x672)... - "textline": 3, + "textline": 2, # large models... 
- "table": 2, - "region_1_2": 2, - "region_fl_np": 2, - "region_fl": 2, - }.get(model, 1) - loaded_model = self.model_zoo.get(model) + "table": 1, + "region_1_2": 1, + "region_fl_np": 1, + "region_fl": 1, + }.get(self.name, 1) if not len(shared_data): - #self.logger.debug("getting '%d' output shape of model '%s'", jobid, model) - result = loaded_model.output_shape + #self.logger.debug("getting '%d' output shape of model '%s'", jobid, self.name) + result = self.model.output_shape self.resultq.put((jobid, result)) #self.logger.debug("sent result for '%d': %s", jobid, result) else: - other_tasks = [] # other model, put back on q - model_tasks = [] # same model, for rebatching - model_tasks.append((jobid, shared_data)) + tasks = [(jobid, shared_data)] batch_size = shared_data['shape'][0] while (not self.taskq.empty() and # climb to target batch size - batch_size * len(model_tasks) < REBATCH_SIZE): - jobid0, model0, shared_data0 = self.taskq.get() - if model0 == model and len(shared_data0): + batch_size * len(tasks) < REBATCH_SIZE): + jobid0, shared_data0 = self.taskq.get() + if len(shared_data0): # add to our batch - model_tasks.append((jobid0, shared_data0)) + tasks.append((jobid0, shared_data0)) else: - other_tasks.append((jobid0, model0, shared_data0)) - if len(other_tasks): - self.logger.debug("requeuing %d other tasks", len(other_tasks)) - for task in other_tasks: - self.taskq.put(task) - if len(model_tasks) > 1: - self.logger.debug("rebatching %d %s tasks of batch size %d", len(model_tasks), model, batch_size) + # immediately answer + self.resultq.put((jobid0, self.model.output_shape)) + if len(tasks) > 1: + self.logger.debug("rebatching %d '%s' tasks of batch size %d", + len(tasks), self.name, batch_size) with ExitStack() as stack: data = [] jobs = [] - for jobid, shared_data in model_tasks: - #self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data) + for jobid, shared_data in tasks: + #self.logger.debug("predicting '%d' with model '%s': 
%s", jobid, self.name, shared_data) jobs.append(jobid) data.append(stack.enter_context(ndarray_shared(shared_data))) data = np.concatenate(data) - result = loaded_model.predict(data, verbose=0) + result = self.model.predict(data, verbose=0) results = np.split(result, len(jobs)) #self.logger.debug("sharing result array for '%d'", jobid) with ExitStack() as stack: @@ -180,14 +162,17 @@ class Predictor(mp.context.SpawnProcess): self.resultq.put((jobid, result)) #self.logger.debug("sent result for '%d': %s", jobid, result) except Exception as e: - self.logger.error("prediction failed: %s", e.__class__.__name__) + self.logger.error("prediction for %s failed: %s", self.name, e.__class__.__name__) result = e self.resultq.put((jobid, result)) close_all() #self.logger.debug("predictor terminated") - def load_models(self, *loadable: List[str]): - self.loadable = loadable + def load_model(self, *load_args, **load_kwargs): + assert len(load_args) + self.name = '_'.join(list(load_args[:1]) + list(load_kwargs)) + self.load_args = load_args + self.load_kwargs = load_kwargs self.start() # call run() in subprocess # parent context here del self.model_zoo # only in subprocess @@ -200,20 +185,20 @@ class Predictor(mp.context.SpawnProcess): def setup(self): logging.root.handlers = [logging.handlers.QueueHandler(self.logq)] self.logger.setLevel(self.loglevel) - self.model_zoo.load_models(*self.loadable) + self.model = self.model_zoo.load_model(*self.load_args, **self.load_kwargs) def shutdown(self): # do not terminate from forked processor instances if mp.parent_process() is None: self.stopped.set() - self.terminate() - self.logq.close() self.taskq.close() self.taskq.cancel_join_thread() self.resultq.close() self.resultq.cancel_join_thread() + self.logq.close() + self.terminate() else: - self.model_zoo.shutdown() + del self.model def __del__(self): #self.logger.debug(f"deinit of {self} in {mp.current_process().name}")