mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-04-14 19:31:57 +02:00
model_zoo/predictor: use one subprocess per model…
- Eynollah: instead of one `Predictor` instance as stand-in for entire `ModelZoo`, keep the latter but have each model in `_loaded` dict become an independent predictor instance
- `ModelZoo.load_models()`: instantiate `Predictor`s for each `model_category` and then call `Predictor.load_model()` on them
- `Predictor.load_model()`: set args/kwargs for `ModelZoo.load_model()`, then spawn subprocess via `.start()`, which first enters `setup()`...
- `Predictor.setup()`: call `ModelZoo.load_model()` instead of (plural) `.load_models()`; save to `self.model` instead of `self.model_zoo`
- `ModelZoo.load_model()`: move _all_ CUDA configuration and TF/Keras-specific module initialization here (to be used only by predictor subprocess)
- `Predictor`: drop stand-in `SingleModelPredictor` retrieved by `get()`; directly provide `predict()` and `output_shape` via `self.__call__()`
- `Predictor`: drop `model` arg from all queues - now implicit; use `self.name` for model name in messages
- `Predictor`: no need for requeuing other tasks (only same model now)
- `Predictor`: reduce rebatching batch sizes due to increased VRAM footprint
- `Eynollah.setup_models()`: set up loading `_patched` / `_resized` here instead of during `ModelZoo.load_model()`
- `ModelZoo.load_models()`: for resized/patched models, call `Predictor.load_model()` with kwarg instead of resp. model name suffix
- `ModelZoo.load_model()`: expect boolean kwargs `patched/resized` for `wrap_layout_model_patched/resized` model wrappers, respectively
This commit is contained in:
parent
c514bbc661
commit
f54deff452
3 changed files with 131 additions and 122 deletions
|
|
@ -44,7 +44,6 @@ except ImportError:
|
|||
plt = None
|
||||
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .predictor import Predictor
|
||||
from .utils.contour import (
|
||||
filter_contours_area_of_image,
|
||||
filter_contours_area_of_image_tables,
|
||||
|
|
@ -142,7 +141,7 @@ class Eynollah:
|
|||
logger : Optional[logging.Logger] = None,
|
||||
):
|
||||
self.logger = logger or logging.getLogger('eynollah')
|
||||
self.model_zoo = Predictor(self.logger, model_zoo)
|
||||
self.model_zoo = model_zoo
|
||||
self.plotter = None
|
||||
|
||||
self.reading_order_machine_based = reading_order_machine_based
|
||||
|
|
@ -174,26 +173,34 @@ class Eynollah:
|
|||
# load models, depending on modes
|
||||
# (note: loading too many models can cause OOM on GPU/CUDA,
|
||||
# thus, we try to set up the minimal configuration for the current mode)
|
||||
# autosized variants: _resized or _patched (which one may depend on num_cols)
|
||||
# (but _resized for full page images is too slow - better resize on CPU in numpy)
|
||||
loadable = [
|
||||
"col_classifier",
|
||||
"binarization",
|
||||
#"enhancement",
|
||||
#"enhancement", # todo: enhancement_patched
|
||||
"page",
|
||||
#"region"
|
||||
]
|
||||
loadable.append(("textline"))
|
||||
if self.input_binary:
|
||||
loadable.append("binarization") # todo: binarization_patched
|
||||
loadable.append("textline_patched") # textline
|
||||
loadable.append("region_1_2")
|
||||
loadable.append("region_1_2_patched")
|
||||
if self.full_layout:
|
||||
loadable.append("region_fl_np")
|
||||
#loadable.append("region_fl")
|
||||
#loadable.append("region_fl_patched")
|
||||
if self.reading_order_machine_based:
|
||||
loadable.append("reading_order")
|
||||
loadable.append("reading_order") # todo: reading_order_patched
|
||||
if self.tables:
|
||||
loadable.append("table")
|
||||
|
||||
self.model_zoo.load_models(*loadable)
|
||||
for model in loadable:
|
||||
# retrieve and cache output shapes
|
||||
if model.endswith(('_resized', '_patched')):
|
||||
# autosized models do not have a predefined output_shape
|
||||
# (and don't need one)
|
||||
continue
|
||||
self.logger.debug("model %s has output shape %s", model,
|
||||
self.model_zoo.get(model).output_shape)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,21 +5,9 @@ from copy import deepcopy
|
|||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
from tensorflow.keras.models import Model as KerasModel
|
||||
from tensorflow.keras.models import load_model
|
||||
from tabulate import tabulate
|
||||
|
||||
from ..patch_encoder import (
|
||||
PatchEncoder,
|
||||
Patches,
|
||||
wrap_layout_model_patched,
|
||||
wrap_layout_model_resized,
|
||||
)
|
||||
from ..predictor import Predictor
|
||||
from .specs import EynollahModelSpecSet
|
||||
from .default_specs import DEFAULT_MODEL_SPECS
|
||||
from .types import AnyModel, T
|
||||
|
|
@ -46,7 +34,7 @@ class EynollahModelZoo:
|
|||
self._overrides = []
|
||||
if model_overrides:
|
||||
self.override_models(*model_overrides)
|
||||
self._loaded: Dict[str, AnyModel] = {}
|
||||
self._loaded: Dict[str, Predictor] = {}
|
||||
|
||||
@property
|
||||
def model_overrides(self):
|
||||
|
|
@ -90,34 +78,60 @@ class EynollahModelZoo:
|
|||
"""
|
||||
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
|
||||
"""
|
||||
import tensorflow as tf
|
||||
cuda = False
|
||||
try:
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
cuda = True
|
||||
self.logger.info("using GPU %s", device.name)
|
||||
except RuntimeError:
|
||||
self.logger.exception("cannot configure GPU devices")
|
||||
if not cuda:
|
||||
self.logger.warning("no GPU device available")
|
||||
ret = {}
|
||||
ret = {} # cannot use self._loaded here, yet – first spawn all predictors
|
||||
for load_args in all_load_args:
|
||||
if isinstance(load_args, str):
|
||||
ret[load_args] = self.load_model(load_args)
|
||||
model_category = load_args
|
||||
load_args = [model_category]
|
||||
else:
|
||||
ret[load_args[0]] = self.load_model(*load_args)
|
||||
return ret
|
||||
model_category = load_args[0]
|
||||
load_kwargs = {}
|
||||
if model_category.endswith('_resized'):
|
||||
load_args[0] = model_category[:-8]
|
||||
load_kwargs["resized"] = True
|
||||
elif model_category.endswith('_patched'):
|
||||
load_args[0] = model_category[:-8]
|
||||
load_kwargs["patched"] = True
|
||||
ret[model_category] = Predictor(self.logger, self)
|
||||
ret[model_category].load_model(*load_args, **load_kwargs)
|
||||
self._loaded.update(ret)
|
||||
return self._loaded
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
model_category: str,
|
||||
model_variant: str = '',
|
||||
model_path_override: Optional[str] = None,
|
||||
patched: bool = False,
|
||||
resized: bool = False,
|
||||
) -> AnyModel:
|
||||
"""
|
||||
Load any model
|
||||
"""
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
from ..patch_encoder import (
|
||||
PatchEncoder,
|
||||
Patches,
|
||||
wrap_layout_model_patched,
|
||||
wrap_layout_model_resized,
|
||||
)
|
||||
cuda = False
|
||||
try:
|
||||
device = tf.config.list_physical_devices('GPU')[0]
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
cuda = True
|
||||
self.logger.info("using GPU %s", device.name)
|
||||
except RuntimeError:
|
||||
self.logger.exception("cannot configure GPU devices")
|
||||
if not cuda:
|
||||
self.logger.warning("no GPU device available")
|
||||
|
||||
if model_path_override:
|
||||
self.override_models((model_category, model_variant, model_path_override))
|
||||
model_path = self.model_path(model_category, model_variant)
|
||||
|
|
@ -142,26 +156,26 @@ class EynollahModelZoo:
|
|||
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
|
||||
)
|
||||
model._name = model_category
|
||||
self._loaded[model_category] = model
|
||||
# autosized for full page images is too slow (better resize on CPU in numpy):
|
||||
# if model_category in ['region_1_2', 'table', 'region_fl_np']:
|
||||
# self._loaded[model_category + '_resized'] = wrap_layout_model_resized(model)
|
||||
if model_category in ['region_1_2', 'textline']:
|
||||
self._loaded[model_category + '_patched'] = wrap_layout_model_patched(model)
|
||||
return model # type: ignore
|
||||
if resized:
|
||||
model = wrap_layout_model_resized(model)
|
||||
model._name = model_category + '_resized'
|
||||
elif patched:
|
||||
model = wrap_layout_model_patched(model)
|
||||
model._name = model_category + '_patched'
|
||||
return model
|
||||
|
||||
def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T:
|
||||
def get(self, model_category: str) -> Predictor:
|
||||
if model_category not in self._loaded:
|
||||
raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"')
|
||||
ret = self._loaded[model_category]
|
||||
if model_type:
|
||||
assert isinstance(ret, model_type)
|
||||
return ret # type: ignore # FIXME: convince typing that we're returning generic type
|
||||
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
|
||||
return self._loaded[model_category]
|
||||
|
||||
def _load_ocr_model(self, variant: str) -> AnyModel:
|
||||
"""
|
||||
Load OCR model
|
||||
"""
|
||||
from tensorflow.keras.models import Model as KerasModel
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
ocr_model_dir = self.model_path('ocr', variant)
|
||||
if variant == 'tr':
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
|
|
@ -183,10 +197,12 @@ class EynollahModelZoo:
|
|||
with open(self.model_path('num_to_char'), "r") as config_file:
|
||||
return json.load(config_file)
|
||||
|
||||
def _load_num_to_char(self) -> StringLookup:
|
||||
def _load_num_to_char(self) -> 'StringLookup':
|
||||
"""
|
||||
Load decoder for OCR
|
||||
"""
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
|
||||
characters = self._load_characters()
|
||||
# Mapping characters to integers.
|
||||
char_to_num = StringLookup(vocabulary=characters, mask_token=None)
|
||||
|
|
@ -225,4 +241,5 @@ class EynollahModelZoo:
|
|||
"""
|
||||
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
|
||||
for needle in list(self._loaded.keys()):
|
||||
self._loaded[needle].shutdown()
|
||||
del self._loaded[needle]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import threading
|
||||
from contextlib import ExitStack
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
import logging
|
||||
import logging.handlers
|
||||
import multiprocessing as mp
|
||||
|
|
@ -16,27 +14,12 @@ class Predictor(mp.context.SpawnProcess):
|
|||
"""
|
||||
singleton subprocess solely responsible for prediction with TensorFlow,
|
||||
communicates with any number of worker processes,
|
||||
acts as a shallow replacement for EynollahModelZoo
|
||||
acting as a shallow replacement for various model types in EynollahModelZoo's
|
||||
_loaded dict for each single model
|
||||
"""
|
||||
class SingleModelPredictor:
|
||||
"""
|
||||
acts as a shallow replacement for EynollahModelZoo
|
||||
"""
|
||||
def __init__(self, predictor: 'Predictor', model: str):
|
||||
self.predictor = predictor
|
||||
self.model = model
|
||||
@property
|
||||
def name(self):
|
||||
return self.model
|
||||
@property
|
||||
def output_shape(self):
|
||||
return self.predictor(self.model, {})
|
||||
def predict(self, data: dict, verbose=0):
|
||||
return self.predictor(self.model, data)
|
||||
|
||||
def __init__(self, logger, model_zoo):
|
||||
self.logger = logger
|
||||
self.loglevel = logger.level
|
||||
self.loglevel = logger.parent.level
|
||||
self.model_zoo = model_zoo
|
||||
ctxt = mp.get_context('spawn')
|
||||
self.taskq = ctxt.Queue(maxsize=QSIZE)
|
||||
|
|
@ -47,17 +30,20 @@ class Predictor(mp.context.SpawnProcess):
|
|||
# as per ocrd_utils.initLogging():
|
||||
logging.root.handlers +
|
||||
# as per eynollah_cli.main():
|
||||
self.logger.handlers
|
||||
self.logger.parent.handlers
|
||||
), respect_handler_level=False).start()
|
||||
self.stopped = ctxt.Event()
|
||||
self.closable = ctxt.Manager().list()
|
||||
super().__init__(name="EynollahPredictor", daemon=True)
|
||||
|
||||
@lru_cache
|
||||
def get(self, model: str):
|
||||
return Predictor.SingleModelPredictor(self, model)
|
||||
@property
|
||||
def output_shape(self):
|
||||
return self({})
|
||||
|
||||
def __call__(self, model: str, data: dict):
|
||||
def predict(self, data: dict, verbose=0):
|
||||
return self(data)
|
||||
|
||||
def __call__(self, data: dict):
|
||||
# unusable as per python/cpython#79967
|
||||
#with self.jobid.get_lock():
|
||||
# would work, but not public:
|
||||
|
|
@ -66,12 +52,12 @@ class Predictor(mp.context.SpawnProcess):
|
|||
self.jobid.value += 1
|
||||
jobid = self.jobid.value
|
||||
if not len(data):
|
||||
self.taskq.put((jobid, model, data))
|
||||
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, model)
|
||||
self.taskq.put((jobid, data))
|
||||
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name)
|
||||
return self.result(jobid)
|
||||
with share_ndarray(data) as shared_data:
|
||||
self.taskq.put((jobid, model, shared_data))
|
||||
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, model, shared_data)
|
||||
self.taskq.put((jobid, shared_data))
|
||||
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data)
|
||||
return self.result(jobid)
|
||||
|
||||
def result(self, jobid):
|
||||
|
|
@ -80,7 +66,7 @@ class Predictor(mp.context.SpawnProcess):
|
|||
#self.logger.debug("received result for '%d'", jobid)
|
||||
result = self.results.pop(jobid)
|
||||
if isinstance(result, Exception):
|
||||
raise Exception(f"predictor failed for {jobid}") from result
|
||||
raise Exception(f"predictor {self.name} failed for {jobid}") from result
|
||||
elif isinstance(result, dict):
|
||||
with ndarray_shared(result) as shared_result:
|
||||
result = np.copy(shared_result)
|
||||
|
|
@ -92,7 +78,7 @@ class Predictor(mp.context.SpawnProcess):
|
|||
continue
|
||||
#self.logger.debug("storing results for '%d': '%s'", jobid0, result)
|
||||
self.results[jobid0] = result
|
||||
raise Exception(f"predictor terminated while waiting on results for {jobid}")
|
||||
raise Exception(f"predictor {self.name} terminated while waiting on results for {jobid}")
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
|
|
@ -100,6 +86,7 @@ class Predictor(mp.context.SpawnProcess):
|
|||
except Exception as e:
|
||||
self.logger.exception("setup failed")
|
||||
self.stopped.set()
|
||||
return
|
||||
closing = {}
|
||||
def close_all():
|
||||
for jobid in list(self.closable):
|
||||
|
|
@ -110,63 +97,58 @@ class Predictor(mp.context.SpawnProcess):
|
|||
close_all()
|
||||
try:
|
||||
TIMEOUT = 4.5 # 1.1 is too greedy - not enough for rebatching
|
||||
jobid, model, shared_data = self.taskq.get(timeout=TIMEOUT)
|
||||
jobid, shared_data = self.taskq.get(timeout=TIMEOUT)
|
||||
except mp.queues.Empty:
|
||||
continue
|
||||
try:
|
||||
# up to what batch size fits into small (8GB) VRAM?
|
||||
# (notice we are not listing _resized/_patched models here,
|
||||
# because here inputs/outputs will have varying shapes)
|
||||
# because its inputs/outputs will have varying shapes)
|
||||
REBATCH_SIZE = {
|
||||
# small models (448x448)...
|
||||
"col_classifier": 4,
|
||||
"page": 4,
|
||||
"binarization": 5,
|
||||
"enhancement": 5,
|
||||
"reading_order": 5,
|
||||
"col_classifier": 2,
|
||||
"page": 2,
|
||||
"binarization": 4,
|
||||
"enhancement": 4,
|
||||
"reading_order": 4,
|
||||
# medium size (672x672)...
|
||||
"textline": 3,
|
||||
"textline": 2,
|
||||
# large models...
|
||||
"table": 2,
|
||||
"region_1_2": 2,
|
||||
"region_fl_np": 2,
|
||||
"region_fl": 2,
|
||||
}.get(model, 1)
|
||||
loaded_model = self.model_zoo.get(model)
|
||||
"table": 1,
|
||||
"region_1_2": 1,
|
||||
"region_fl_np": 1,
|
||||
"region_fl": 1,
|
||||
}.get(self.name, 1)
|
||||
if not len(shared_data):
|
||||
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, model)
|
||||
result = loaded_model.output_shape
|
||||
#self.logger.debug("getting '%d' output shape of model '%s'", jobid, self.name)
|
||||
result = self.model.output_shape
|
||||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
else:
|
||||
other_tasks = [] # other model, put back on q
|
||||
model_tasks = [] # same model, for rebatching
|
||||
model_tasks.append((jobid, shared_data))
|
||||
tasks = [(jobid, shared_data)]
|
||||
batch_size = shared_data['shape'][0]
|
||||
while (not self.taskq.empty() and
|
||||
# climb to target batch size
|
||||
batch_size * len(model_tasks) < REBATCH_SIZE):
|
||||
jobid0, model0, shared_data0 = self.taskq.get()
|
||||
if model0 == model and len(shared_data0):
|
||||
batch_size * len(tasks) < REBATCH_SIZE):
|
||||
jobid0, shared_data0 = self.taskq.get()
|
||||
if len(shared_data0):
|
||||
# add to our batch
|
||||
model_tasks.append((jobid0, shared_data0))
|
||||
tasks.append((jobid0, shared_data0))
|
||||
else:
|
||||
other_tasks.append((jobid0, model0, shared_data0))
|
||||
if len(other_tasks):
|
||||
self.logger.debug("requeuing %d other tasks", len(other_tasks))
|
||||
for task in other_tasks:
|
||||
self.taskq.put(task)
|
||||
if len(model_tasks) > 1:
|
||||
self.logger.debug("rebatching %d %s tasks of batch size %d", len(model_tasks), model, batch_size)
|
||||
# immediately answer
|
||||
self.resultq.put((jobid0, self.model.output_shape))
|
||||
if len(tasks) > 1:
|
||||
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
|
||||
len(tasks), self.name, batch_size)
|
||||
with ExitStack() as stack:
|
||||
data = []
|
||||
jobs = []
|
||||
for jobid, shared_data in model_tasks:
|
||||
#self.logger.debug("predicting '%d' with model '%s': ", jobid, model, shared_data)
|
||||
for jobid, shared_data in tasks:
|
||||
#self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data)
|
||||
jobs.append(jobid)
|
||||
data.append(stack.enter_context(ndarray_shared(shared_data)))
|
||||
data = np.concatenate(data)
|
||||
result = loaded_model.predict(data, verbose=0)
|
||||
result = self.model.predict(data, verbose=0)
|
||||
results = np.split(result, len(jobs))
|
||||
#self.logger.debug("sharing result array for '%d'", jobid)
|
||||
with ExitStack() as stack:
|
||||
|
|
@ -180,14 +162,17 @@ class Predictor(mp.context.SpawnProcess):
|
|||
self.resultq.put((jobid, result))
|
||||
#self.logger.debug("sent result for '%d': %s", jobid, result)
|
||||
except Exception as e:
|
||||
self.logger.error("prediction failed: %s", e.__class__.__name__)
|
||||
self.logger.error("prediction for %s failed: %s", self.name, e.__class__.__name__)
|
||||
result = e
|
||||
self.resultq.put((jobid, result))
|
||||
close_all()
|
||||
#self.logger.debug("predictor terminated")
|
||||
|
||||
def load_models(self, *loadable: List[str]):
|
||||
self.loadable = loadable
|
||||
def load_model(self, *load_args, **load_kwargs):
|
||||
assert len(load_args)
|
||||
self.name = '_'.join(list(load_args[:1]) + list(load_kwargs))
|
||||
self.load_args = load_args
|
||||
self.load_kwargs = load_kwargs
|
||||
self.start() # call run() in subprocess
|
||||
# parent context here
|
||||
del self.model_zoo # only in subprocess
|
||||
|
|
@ -200,20 +185,20 @@ class Predictor(mp.context.SpawnProcess):
|
|||
def setup(self):
|
||||
logging.root.handlers = [logging.handlers.QueueHandler(self.logq)]
|
||||
self.logger.setLevel(self.loglevel)
|
||||
self.model_zoo.load_models(*self.loadable)
|
||||
self.model = self.model_zoo.load_model(*self.load_args, **self.load_kwargs)
|
||||
|
||||
def shutdown(self):
|
||||
# do not terminate from forked processor instances
|
||||
if mp.parent_process() is None:
|
||||
self.stopped.set()
|
||||
self.terminate()
|
||||
self.logq.close()
|
||||
self.taskq.close()
|
||||
self.taskq.cancel_join_thread()
|
||||
self.resultq.close()
|
||||
self.resultq.cancel_join_thread()
|
||||
self.logq.close()
|
||||
self.terminate()
|
||||
else:
|
||||
self.model_zoo.shutdown()
|
||||
del self.model
|
||||
|
||||
def __del__(self):
|
||||
#self.logger.debug(f"deinit of {self} in {mp.current_process().name}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue