model_zoo/predictor: use one subprocess per model…

- Eynollah: instead of one `Predictor` instance as stand-in for
  entire `ModelZoo`, keep the latter but have each model in `_loaded`
  dict become an independent predictor instance
- `ModelZoo.load_models()`: instantiate `Predictor`s for each
  `model_category` and then call `Predictor.load_model()` on them
- `Predictor.load_model()`: set args/kwargs for `ModelZoo.load_model()`,
  then spawn subprocess via `.start()`, which first enters `setup()`...
- `Predictor.setup()`: call `ModelZoo.load_model()` instead of (plural)
 `.load_models()`; save to `self.model` instead of `self.model_zoo`
- `ModelZoo.load_model()`: move _all_ CUDA configuration and
  TF/Keras-specific module initialization here (to be used only by
  predictor subprocess)
- `Predictor`: drop stand-in `SingleModelPredictor` retrieved by `get()`;
  directly provide `predict()` and `output_shape` via `self.call()`
- `Predictor`: drop `model` arg from all queues - now implicit; use
  `self.name` for model name in messages
- `Predictor`: no need for requeuing other tasks (only same model now)
- `Predictor`: reduce rebatching batch sizes due to increased VRAM footprint

- `Eynollah.setup_models()`: set up loading `_patched` / `_resized`
  here instead of during `ModelZoo.load_model()`
- `ModelZoo.load_models()`: for resized/patched models, call
  `Predictor.load_model()` with kwarg instead of resp. model name suffix
- `ModelZoo.load_model()`: expect boolean kwargs `patched/resized`
  for `wrap_layout_model_patched/resized` model wrappers, respectively
This commit is contained in:
Robert Sachunsky 2026-03-15 02:53:37 +01:00
parent c514bbc661
commit f54deff452
3 changed files with 131 additions and 122 deletions

View file

@ -44,7 +44,6 @@ except ImportError:
plt = None plt = None
from .model_zoo import EynollahModelZoo from .model_zoo import EynollahModelZoo
from .predictor import Predictor
from .utils.contour import ( from .utils.contour import (
filter_contours_area_of_image, filter_contours_area_of_image,
filter_contours_area_of_image_tables, filter_contours_area_of_image_tables,
@ -142,7 +141,7 @@ class Eynollah:
logger : Optional[logging.Logger] = None, logger : Optional[logging.Logger] = None,
): ):
self.logger = logger or logging.getLogger('eynollah') self.logger = logger or logging.getLogger('eynollah')
self.model_zoo = Predictor(self.logger, model_zoo) self.model_zoo = model_zoo
self.plotter = None self.plotter = None
self.reading_order_machine_based = reading_order_machine_based self.reading_order_machine_based = reading_order_machine_based
@ -174,26 +173,34 @@ class Eynollah:
# load models, depending on modes # load models, depending on modes
# (note: loading too many models can cause OOM on GPU/CUDA, # (note: loading too many models can cause OOM on GPU/CUDA,
# thus, we try set up the minimal configuration for the current mode) # thus, we try set up the minimal configuration for the current mode)
# autosized variants: _resized or _patched (which one may depend on num_cols)
# (but _resized for full page images is too slow - better resize on CPU in numpy)
loadable = [ loadable = [
"col_classifier", "col_classifier",
"binarization", #"enhancement", # todo: enhancement_patched
#"enhancement",
"page", "page",
#"region" #"region"
] ]
loadable.append(("textline")) if self.input_binary:
loadable.append("binarization") # todo: binarization_patched
loadable.append("textline_patched") # textline
loadable.append("region_1_2") loadable.append("region_1_2")
loadable.append("region_1_2_patched")
if self.full_layout: if self.full_layout:
loadable.append("region_fl_np") loadable.append("region_fl_np")
#loadable.append("region_fl") #loadable.append("region_fl_patched")
if self.reading_order_machine_based: if self.reading_order_machine_based:
loadable.append("reading_order") loadable.append("reading_order") # todo: reading_order_patched
if self.tables: if self.tables:
loadable.append("table") loadable.append("table")
self.model_zoo.load_models(*loadable) self.model_zoo.load_models(*loadable)
for model in loadable: for model in loadable:
# retrieve and cache output shapes # retrieve and cache output shapes
if model.endswith(('_resized', '_patched')):
# autosized models do not have a predefined output_shape
# (and don't need one)
continue
self.logger.debug("model %s has output shape %s", model, self.logger.debug("model %s has output shape %s", model,
self.model_zoo.get(model).output_shape) self.model_zoo.get(model).output_shape)

View file

@ -5,21 +5,9 @@ from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, List, Optional, Tuple, Type, Union
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
from tensorflow.keras.layers import StringLookup
from tensorflow.keras.models import Model as KerasModel
from tensorflow.keras.models import load_model
from tabulate import tabulate from tabulate import tabulate
from ..patch_encoder import ( from ..predictor import Predictor
PatchEncoder,
Patches,
wrap_layout_model_patched,
wrap_layout_model_resized,
)
from .specs import EynollahModelSpecSet from .specs import EynollahModelSpecSet
from .default_specs import DEFAULT_MODEL_SPECS from .default_specs import DEFAULT_MODEL_SPECS
from .types import AnyModel, T from .types import AnyModel, T
@ -46,7 +34,7 @@ class EynollahModelZoo:
self._overrides = [] self._overrides = []
if model_overrides: if model_overrides:
self.override_models(*model_overrides) self.override_models(*model_overrides)
self._loaded: Dict[str, AnyModel] = {} self._loaded: Dict[str, Predictor] = {}
@property @property
def model_overrides(self): def model_overrides(self):
@ -90,34 +78,60 @@ class EynollahModelZoo:
""" """
Load all models by calling load_model and return a dictionary mapping model_category to loaded model Load all models by calling load_model and return a dictionary mapping model_category to loaded model
""" """
import tensorflow as tf ret = {} # cannot use self._loaded here, yet first spawn all predictors
cuda = False
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
cuda = True
self.logger.info("using GPU %s", device.name)
except RuntimeError:
self.logger.exception("cannot configure GPU devices")
if not cuda:
self.logger.warning("no GPU device available")
ret = {}
for load_args in all_load_args: for load_args in all_load_args:
if isinstance(load_args, str): if isinstance(load_args, str):
ret[load_args] = self.load_model(load_args) model_category = load_args
load_args = [model_category]
else: else:
ret[load_args[0]] = self.load_model(*load_args) model_category = load_args[0]
return ret load_kwargs = {}
if model_category.endswith('_resized'):
load_args[0] = model_category[:-8]
load_kwargs["resized"] = True
elif model_category.endswith('_patched'):
load_args[0] = model_category[:-8]
load_kwargs["patched"] = True
ret[model_category] = Predictor(self.logger, self)
ret[model_category].load_model(*load_args, **load_kwargs)
self._loaded.update(ret)
return self._loaded
def load_model( def load_model(
self, self,
model_category: str, model_category: str,
model_variant: str = '', model_variant: str = '',
model_path_override: Optional[str] = None, model_path_override: Optional[str] = None,
patched: bool = False,
resized: bool = False,
) -> AnyModel: ) -> AnyModel:
""" """
Load any model Load any model
""" """
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
import tensorflow as tf
from tensorflow.keras.models import load_model
from ..patch_encoder import (
PatchEncoder,
Patches,
wrap_layout_model_patched,
wrap_layout_model_resized,
)
cuda = False
try:
device = tf.config.list_physical_devices('GPU')[0]
tf.config.experimental.set_memory_growth(device, True)
cuda = True
self.logger.info("using GPU %s", device.name)
except RuntimeError:
self.logger.exception("cannot configure GPU devices")
if not cuda:
self.logger.warning("no GPU device available")
if model_path_override: if model_path_override:
self.override_models((model_category, model_variant, model_path_override)) self.override_models((model_category, model_variant, model_path_override))
model_path = self.model_path(model_category, model_variant) model_path = self.model_path(model_category, model_variant)
@ -142,26 +156,26 @@ class EynollahModelZoo:
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches} model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
) )
model._name = model_category model._name = model_category
self._loaded[model_category] = model if resized:
# autosized for full page images is too slow (better resize on CPU in numpy): model = wrap_layout_model_resized(model)
# if model_category in ['region_1_2', 'table', 'region_fl_np']: model._name = model_category + '_resized'
# self._loaded[model_category + '_resized'] = wrap_layout_model_resized(model) elif patched:
if model_category in ['region_1_2', 'textline']: model = wrap_layout_model_patched(model)
self._loaded[model_category + '_patched'] = wrap_layout_model_patched(model) model._name = model_category + '_patched'
return model # type: ignore return model
def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T: def get(self, model_category: str) -> Predictor:
if model_category not in self._loaded: if model_category not in self._loaded:
raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"') raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
ret = self._loaded[model_category] return self._loaded[model_category]
if model_type:
assert isinstance(ret, model_type)
return ret # type: ignore # FIXME: convince typing that we're returning generic type
def _load_ocr_model(self, variant: str) -> AnyModel: def _load_ocr_model(self, variant: str) -> AnyModel:
""" """
Load OCR model Load OCR model
""" """
from tensorflow.keras.models import Model as KerasModel
from tensorflow.keras.models import load_model
ocr_model_dir = self.model_path('ocr', variant) ocr_model_dir = self.model_path('ocr', variant)
if variant == 'tr': if variant == 'tr':
from transformers import VisionEncoderDecoderModel from transformers import VisionEncoderDecoderModel
@ -183,10 +197,12 @@ class EynollahModelZoo:
with open(self.model_path('num_to_char'), "r") as config_file: with open(self.model_path('num_to_char'), "r") as config_file:
return json.load(config_file) return json.load(config_file)
def _load_num_to_char(self) -> StringLookup: def _load_num_to_char(self) -> 'StringLookup':
""" """
Load decoder for OCR Load decoder for OCR
""" """
from tensorflow.keras.layers import StringLookup
characters = self._load_characters() characters = self._load_characters()
# Mapping characters to integers. # Mapping characters to integers.
char_to_num = StringLookup(vocabulary=characters, mask_token=None) char_to_num = StringLookup(vocabulary=characters, mask_token=None)
@ -225,4 +241,5 @@ class EynollahModelZoo:
""" """
if hasattr(self, '_loaded') and getattr(self, '_loaded'): if hasattr(self, '_loaded') and getattr(self, '_loaded'):
for needle in list(self._loaded.keys()): for needle in list(self._loaded.keys()):
self._loaded[needle].shutdown()
del self._loaded[needle] del self._loaded[needle]

View file

@ -1,7 +1,5 @@
import threading
from contextlib import ExitStack from contextlib import ExitStack
from functools import lru_cache from typing import List, Dict
from typing import List
import logging import logging
import logging.handlers import logging.handlers
import multiprocessing as mp import multiprocessing as mp
@ -16,27 +14,12 @@ class Predictor(mp.context.SpawnProcess):
""" """
singleton subprocess solely responsible for prediction with TensorFlow, singleton subprocess solely responsible for prediction with TensorFlow,
communicates with any number of worker processes, communicates with any number of worker processes,
acts as a shallow replacement for EynollahModelZoo acting as a shallow replacement for various model types in EynollahModelZoo's
_loaded dict for each single model
""" """
class SingleModelPredictor:
"""
acts as a shallow replacement for EynollahModelZoo
"""
def __init__(self, predictor: 'Predictor', model: str):
self.predictor = predictor
self.model = model
@property
def name(self):
return self.model
@property
def output_shape(self):
return self.predictor(self.model, {})
def predict(self, data: dict, verbose=0):
return self.predictor(self.model, data)
def __init__(self, logger, model_zoo): def __init__(self, logger, model_zoo):
self.logger = logger self.logger = logger
self.loglevel = logger.level self.loglevel = logger.parent.level
self.model_zoo = model_zoo self.model_zoo = model_zoo
ctxt = mp.get_context('spawn') ctxt = mp.get_context('spawn')
self.taskq = ctxt.Queue(maxsize=QSIZE) self.taskq = ctxt.Queue(maxsize=QSIZE)
@ -47,17 +30,20 @@ class Predictor(mp.context.SpawnProcess):
# as per ocrd_utils.initLogging(): # as per ocrd_utils.initLogging():
logging.root.handlers + logging.root.handlers +
# as per eynollah_cli.main(): # as per eynollah_cli.main():
self.logger.handlers self.logger.parent.handlers
), respect_handler_level=False).start() ), respect_handler_level=False).start()
self.stopped = ctxt.Event() self.stopped = ctxt.Event()
self.closable = ctxt.Manager().list() self.closable = ctxt.Manager().list()
super().__init__(name="EynollahPredictor", daemon=True) super().__init__(name="EynollahPredictor", daemon=True)
@lru_cache @property
def get(self, model: str): def output_shape(self):
return Predictor.SingleModelPredictor(self, model) return self({})
def __call__(self, model: str, data: dict): def predict(self, data: dict, verbose=0):
return self(data)
def __call__(self, data: dict):
# unusable as per python/cpython#79967 # unusable as per python/cpython#79967
#with self.jobid.get_lock(): #with self.jobid.get_lock():
# would work, but not public: # would work, but not public:
@ -66,12 +52,12 @@ class Predictor(mp.context.SpawnProcess):
self.jobid.value += 1 self.jobid.value += 1
jobid = self.jobid.value jobid = self.jobid.value
if not len(data): if not len(data):
self.taskq.put((jobid, model, data)) self.taskq.put((jobid, data))
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, model) #self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name)
return self.result(jobid) return self.result(jobid)
with share_ndarray(data) as shared_data: with share_ndarray(data) as shared_data:
self.taskq.put((jobid, model, shared_data)) self.taskq.put((jobid, shared_data))
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, model, shared_data) #self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data)
return self.result(jobid) return self.result(jobid)
def result(self, jobid): def result(self, jobid):
@ -80,7 +66,7 @@ class Predictor(mp.context.SpawnProcess):
#self.logger.debug("received result for '%d'", jobid) #self.logger.debug("received result for '%d'", jobid)
result = self.results.pop(jobid) result = self.results.pop(jobid)
if isinstance(result, Exception): if isinstance(result, Exception):
raise Exception(f"predictor failed for {jobid}") from result raise Exception(f"predictor {self.name} failed for {jobid}") from result
elif isinstance(result, dict): elif isinstance(result, dict):
with ndarray_shared(result) as shared_result: with ndarray_shared(result) as shared_result:
result = np.copy(shared_result) result = np.copy(shared_result)
@ -92,7 +78,7 @@ class Predictor(mp.context.SpawnProcess):
continue continue
#self.logger.debug("storing results for '%d': '%s'", jobid0, result) #self.logger.debug("storing results for '%d': '%s'", jobid0, result)
self.results[jobid0] = result self.results[jobid0] = result
raise Exception(f"predictor terminated while waiting on results for {jobid}") raise Exception(f"predictor {self.name} terminated while waiting on results for {jobid}")
def run(self): def run(self):
try: try:
@ -100,6 +86,7 @@ class Predictor(mp.context.SpawnProcess):
except Exception as e: except Exception as e:
self.logger.exception("setup failed") self.logger.exception("setup failed")
self.stopped.set() self.stopped.set()
return
closing = {} closing = {}
def close_all(): def close_all():
for jobid in list(self.closable): for jobid in list(self.closable):
@ -110,63 +97,58 @@ class Predictor(mp.context.SpawnProcess):
close_all() close_all()
try: try:
TIMEOUT = 4.5 # 1.1 is too greedy - not enough for rebatching TIMEOUT = 4.5 # 1.1 is too greedy - not enough for rebatching
jobid, model, shared_data = self.taskq.get(timeout=TIMEOUT) jobid, shared_data = self.taskq.get(timeout=TIMEOUT)
except mp.queues.Empty: except mp.queues.Empty:
continue continue
try: try:
# up to what batch size fits into small (8GB) VRAM? # up to what batch size fits into small (8GB) VRAM?
# (notice we are not listing _resized/_patched models here, # (notice we are not listing _resized/_patched models here,
# because here inputs/outputs will have varying shapes) # because its inputs/outputs will have varying shapes)
REBATCH_SIZE = { REBATCH_SIZE = {
# small models (448x448)... # small models (448x448)...
"col_classifier": 4, "col_classifier": 2,
"page": 4, "page": 2,
"binarization": 5, "binarization": 4,
"enhancement": 5, "enhancement": 4,
"reading_order": 5, "reading_order": 4,
# medium size (672x672)... # medium size (672x672)...
"textline": 3, "textline": 2,
# large models... # large models...
"table": 2, "table": 1,
"region_1_2": 2, "region_1_2": 1,
"region_fl_np": 2, "region_fl_np": 1,
"region_fl": 2, "region_fl": 1,
}.get(model, 1) }.get(self.name, 1)
loaded_model = self.model_zoo.get(model)
if not len(shared_data): if not len(shared_data):
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, model) #self.logger.debug("getting '%d' output shape of model '%s'", jobid, self.name)
result = loaded_model.output_shape result = self.model.output_shape
self.resultq.put((jobid, result)) self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result) #self.logger.debug("sent result for '%d': %s", jobid, result)
else: else:
other_tasks = [] # other model, put back on q tasks = [(jobid, shared_data)]
model_tasks = [] # same model, for rebatching
model_tasks.append((jobid, shared_data))
batch_size = shared_data['shape'][0] batch_size = shared_data['shape'][0]
while (not self.taskq.empty() and while (not self.taskq.empty() and
# climb to target batch size # climb to target batch size
batch_size * len(model_tasks) < REBATCH_SIZE): batch_size * len(tasks) < REBATCH_SIZE):
jobid0, model0, shared_data0 = self.taskq.get() jobid0, shared_data0 = self.taskq.get()
if model0 == model and len(shared_data0): if len(shared_data0):
# add to our batch # add to our batch
model_tasks.append((jobid0, shared_data0)) tasks.append((jobid0, shared_data0))
else: else:
other_tasks.append((jobid0, model0, shared_data0)) # immediately answer
if len(other_tasks): self.resultq.put((jobid0, self.model.output_shape))
self.logger.debug("requeuing %d other tasks", len(other_tasks)) if len(tasks) > 1:
for task in other_tasks: self.logger.debug("rebatching %d '%s' tasks of batch size %d",
self.taskq.put(task) len(tasks), self.name, batch_size)
if len(model_tasks) > 1:
self.logger.debug("rebatching %d %s tasks of batch size %d", len(model_tasks), model, batch_size)
with ExitStack() as stack: with ExitStack() as stack:
data = [] data = []
jobs = [] jobs = []
for jobid, shared_data in model_tasks: for jobid, shared_data in tasks:
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data) #self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data)
jobs.append(jobid) jobs.append(jobid)
data.append(stack.enter_context(ndarray_shared(shared_data))) data.append(stack.enter_context(ndarray_shared(shared_data)))
data = np.concatenate(data) data = np.concatenate(data)
result = loaded_model.predict(data, verbose=0) result = self.model.predict(data, verbose=0)
results = np.split(result, len(jobs)) results = np.split(result, len(jobs))
#self.logger.debug("sharing result array for '%d'", jobid) #self.logger.debug("sharing result array for '%d'", jobid)
with ExitStack() as stack: with ExitStack() as stack:
@ -180,14 +162,17 @@ class Predictor(mp.context.SpawnProcess):
self.resultq.put((jobid, result)) self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result) #self.logger.debug("sent result for '%d': %s", jobid, result)
except Exception as e: except Exception as e:
self.logger.error("prediction failed: %s", e.__class__.__name__) self.logger.error("prediction for %s failed: %s", self.name, e.__class__.__name__)
result = e result = e
self.resultq.put((jobid, result)) self.resultq.put((jobid, result))
close_all() close_all()
#self.logger.debug("predictor terminated") #self.logger.debug("predictor terminated")
def load_models(self, *loadable: List[str]): def load_model(self, *load_args, **load_kwargs):
self.loadable = loadable assert len(load_args)
self.name = '_'.join(list(load_args[:1]) + list(load_kwargs))
self.load_args = load_args
self.load_kwargs = load_kwargs
self.start() # call run() in subprocess self.start() # call run() in subprocess
# parent context here # parent context here
del self.model_zoo # only in subprocess del self.model_zoo # only in subprocess
@ -200,20 +185,20 @@ class Predictor(mp.context.SpawnProcess):
def setup(self): def setup(self):
logging.root.handlers = [logging.handlers.QueueHandler(self.logq)] logging.root.handlers = [logging.handlers.QueueHandler(self.logq)]
self.logger.setLevel(self.loglevel) self.logger.setLevel(self.loglevel)
self.model_zoo.load_models(*self.loadable) self.model = self.model_zoo.load_model(*self.load_args, **self.load_kwargs)
def shutdown(self): def shutdown(self):
# do not terminate from forked processor instances # do not terminate from forked processor instances
if mp.parent_process() is None: if mp.parent_process() is None:
self.stopped.set() self.stopped.set()
self.terminate()
self.logq.close()
self.taskq.close() self.taskq.close()
self.taskq.cancel_join_thread() self.taskq.cancel_join_thread()
self.resultq.close() self.resultq.close()
self.resultq.cancel_join_thread() self.resultq.cancel_join_thread()
self.logq.close()
self.terminate()
else: else:
self.model_zoo.shutdown() del self.model
def __del__(self): def __del__(self):
#self.logger.debug(f"deinit of {self} in {mp.current_process().name}") #self.logger.debug(f"deinit of {self} in {mp.current_process().name}")