From 04bc4a63d0365bd93b5c302abaec431219b66776 Mon Sep 17 00:00:00 2001 From: kba Date: Wed, 22 Oct 2025 16:04:48 +0200 Subject: [PATCH] reorganize model_zoo --- requirements.txt | 1 + src/eynollah/cli.py | 40 +- src/eynollah/cli_models.py | 49 +++ src/eynollah/eynollah_ocr.py | 6 +- src/eynollah/model_zoo.py | 468 ------------------------ src/eynollah/model_zoo/__init__.py | 4 + src/eynollah/model_zoo/default_specs.py | 314 ++++++++++++++++ src/eynollah/model_zoo/model_zoo.py | 189 ++++++++++ src/eynollah/model_zoo/specs.py | 55 +++ src/eynollah/model_zoo/types.py | 6 + src/eynollah/sbb_binarize.py | 2 +- 11 files changed, 627 insertions(+), 507 deletions(-) create mode 100644 src/eynollah/cli_models.py delete mode 100644 src/eynollah/model_zoo.py create mode 100644 src/eynollah/model_zoo/__init__.py create mode 100644 src/eynollah/model_zoo/default_specs.py create mode 100644 src/eynollah/model_zoo/model_zoo.py create mode 100644 src/eynollah/model_zoo/specs.py create mode 100644 src/eynollah/model_zoo/types.py diff --git a/requirements.txt b/requirements.txt index db1d7df..bbacd48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ tensorflow < 2.13 numba <= 0.58.1 scikit-image biopython +tabulate diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py index c7d4bd9..595f0ee 100644 --- a/src/eynollah/cli.py +++ b/src/eynollah/cli.py @@ -11,46 +11,13 @@ from eynollah.image_enhancer import Enhancer from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout from eynollah.model_zoo import EynollahModelZoo -@dataclass -class EynollahCliCtx(): - model_basedir: str - model_overrides: List[Tuple[str, str, str]] +from .cli_models import models_cli @click.group() def main(): pass -@main.command('list-models') -@click.option( - "--model", - "-m", - 'model_basedir', - help="directory of models", - type=click.Path(exists=True, file_okay=False), - # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment", - required=True, -) -@click.option( - "--model-overrides", - "-mv", - help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. See eynollah list-models for the full list", - type=(str, str, str), - multiple=True, -) -@click.pass_context -def list_models( - ctx, - model_basedir: str, - model_overrides: List[Tuple[str, str, str]], -): - """ - List all the models in the zoo - """ - ctx.obj = EynollahCliCtx( - model_basedir=model_basedir, - model_overrides=model_overrides - ) - print(EynollahModelZoo(basedir=ctx.obj.model_basedir, model_overrides=ctx.obj.model_overrides)) +main.add_command(models_cli, 'models') @main.command() @click.option( @@ -143,13 +110,12 @@ def binarization( log_level, ): assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." 
-    binarizer = SbbBinarizer(model_dir)
+    binarizer = SbbBinarizer(model_dir, mode=mode)
     if log_level:
         binarizer.log.setLevel(getLevelName(log_level))
     binarizer.run(
         image_path=input_image,
         use_patches=patches,
-        mode=mode,
         output=output,
         dir_in=dir_in
     )
diff --git a/src/eynollah/cli_models.py b/src/eynollah/cli_models.py
new file mode 100644
index 0000000..b67a3ef
--- /dev/null
+++ b/src/eynollah/cli_models.py
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+from typing import List, Tuple
+import click
+from .model_zoo import EynollahModelZoo
+
+@dataclass()
+class EynollahCliCtx():
+    model_basedir: str
+    model_overrides: List[Tuple[str, str, str]]
+
+
+@click.group()
+def models_cli():
+    """
+    Organize models for the various runners in eynollah.
+    """
+
+@models_cli.command('list')
+@click.option(
+    "--model",
+    "-m",
+    'model_basedir',
+    help="directory of models",
+    type=click.Path(exists=True, file_okay=False),
+    # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
+    required=True,
+)
+@click.option(
+    "--model-overrides",
+    "-mv",
+    help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g. 'region light /path/to/model'. See 'eynollah models list' for the full list",
+    type=(str, str, str),
+    multiple=True,
+)
+@click.pass_context
+def list_models(
+    ctx,
+    model_basedir: str,
+    model_overrides: List[Tuple[str, str, str]],
+):
+    """
+    List all the models in the zoo
+    """
+    ctx.obj = EynollahCliCtx(
+        model_basedir=model_basedir,
+        model_overrides=model_overrides
+    )
+    print(EynollahModelZoo(basedir=ctx.obj.model_basedir, model_overrides=ctx.obj.model_overrides))
+
diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py
index b021e92..cfd410c 100644
--- a/src/eynollah/eynollah_ocr.py
+++ b/src/eynollah/eynollah_ocr.py
@@ -18,7 +18,11 @@ from keras.models import load_model
 from PIL import Image, ImageDraw, ImageFont
 import numpy as np
 from eynollah.model_zoo import EynollahModelZoo
-import torch
+try:
+    import torch
+except ImportError:
+    torch = None
+
 from .utils import is_image_filename
 from .utils.resize import resize_image
diff --git a/src/eynollah/model_zoo.py b/src/eynollah/model_zoo.py
deleted file mode 100644
index 6bb06d3..0000000
--- a/src/eynollah/model_zoo.py
+++ /dev/null
@@ -1,468 +0,0 @@
-from copy import deepcopy
-from dataclasses import dataclass
-import json
-import logging
-from pathlib import Path
-from typing import Dict, Optional, Set, Tuple, List, Type, TypeVar, Union
-
-from keras.layers import StringLookup
-from keras.models import Model as KerasModel, load_model
-from transformers import TrOCRProcessor, VisionEncoderDecoderModel
-
-from eynollah.patch_encoder import PatchEncoder, Patches
-
-AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
-T = TypeVar('T')
-
-# NOTE: This needs to change whenever models change
-ZENODO = "https://zenodo.org/records/17295988/files"
-MODELS_VERSION = "v0_7_0"
-
-def dist_url(dist_name: str) -> str:
-    return f'{ZENODO}/models_{dist_name}_${MODELS_VERSION}.zip'
-
-@dataclass
-class EynollahModelSpec():
-    """
-    Describing a single model abstractly.
-    """
-    category: str
-    # Relative filename to the models_eynollah directory in the dists
-    filename: str
-    # The smallest model distribution containing this model (link to Zenodo)
-    dist: str
-    type: Type[AnyModel]
-    variant: str = ''
-    help: str = ''
-
-class EynollahModelSpecSet():
-    """
-    List of all used models for eynollah.
- """ - specs: List[EynollahModelSpec] - - def __init__(self, specs: List[EynollahModelSpec]) -> None: - self.specs = specs - self.categories: Set[str] = set([spec.category for spec in self.specs]) - self.variants: Dict[str, Set[str]] = { - spec.category: set([x.variant for x in self.specs if x.category == spec.category]) - for spec in self.specs - } - self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = { - (spec.category, spec.variant): spec - for spec in self.specs - } - - def asdict(self) -> Dict[str, Dict[str, str]]: - return { - spec.category: { - spec.variant: spec.filename - } - for spec in self.specs - } - - def get(self, category: str, variant: str) -> EynollahModelSpec: - if category not in self.categories: - raise ValueError(f"Unknown category '{category}', must be one of {self.categories}") - if variant not in self.variants[category]: - raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}") - return self._index_category_variant[(category, variant)] - -DEFAULT_MODEL_SPECS = EynollahModelSpecSet([# {{{ - - EynollahModelSpec( - category="enhancement", - variant='', - filename="models_eynollah/eynollah-enhancement_20210425", - dist=dist_url("enhancement"), - type=KerasModel, - ), - - EynollahModelSpec( - category="binarization", - variant='', - filename="models_eynollah/eynollah-binarization_20210425", - dist=dist_url("binarization"), - type=KerasModel, - ), - - EynollahModelSpec( - category="binarization_multi_1", - variant='', - filename="models_eynollah/saved_model_2020_01_16/model_bin1", - dist=dist_url("binarization"), - type=KerasModel, - ), - - EynollahModelSpec( - category="binarization_multi_2", - variant='', - filename="models_eynollah/saved_model_2020_01_16/model_bin2", - dist=dist_url("binarization"), - type=KerasModel, - ), - - EynollahModelSpec( - category="binarization_multi_3", - variant='', - filename="models_eynollah/saved_model_2020_01_16/model_bin3", - dist=dist_url("binarization"), - type=KerasModel, - ), - - EynollahModelSpec( - category="binarization_multi_4", - variant='', - filename="models_eynollah/saved_model_2020_01_16/model_bin4", - dist=dist_url("binarization"), - type=KerasModel, - ), - - EynollahModelSpec( - category="col_classifier", - variant='', - filename="models_eynollah/eynollah-column-classifier_20210425", - dist=dist_url("layout"), - type=KerasModel, - ), - - EynollahModelSpec( - category="page", - variant='', - filename="models_eynollah/model_eynollah_page_extraction_20250915", - dist=dist_url("layout"), - type=KerasModel, - ), - - EynollahModelSpec( - category="region", - variant='', - filename="models_eynollah/eynollah-main-regions-ensembled_20210425", - dist=dist_url("layout"), - type=KerasModel, - ), - - EynollahModelSpec( - category="region", - variant='extract_only_images', - filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18", - dist=dist_url("layout"), - type=KerasModel, - ), - - EynollahModelSpec( - category="region", - variant='light', - filename="models_eynollah/eynollah-main-regions_20220314", - dist=dist_url("layout"), - help="early layout", - type=KerasModel, - ), - - EynollahModelSpec( - category="region_p2", - variant='', - filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425", - dist=dist_url("layout"), - help="early layout, non-light, 2nd part", - type=KerasModel, - ), - - EynollahModelSpec( - category="region_1_2", - variant='', - #filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n", - 
#filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8", - #filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18", - #filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige", - #filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige", - filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024", - dist=dist_url("layout"), - help="early layout, light, 1-or-2-column", - type=KerasModel, - ), - - EynollahModelSpec( - category="region_fl_np", - variant='', - #'filename="models_eynollah/modelens_full_lay_1_3_031124", - #'filename="models_eynollah/modelens_full_lay_13__3_19_241024", - #'filename="models_eynollah/model_full_lay_13_241024", - #'filename="models_eynollah/modelens_full_lay_13_17_231024", - #'filename="models_eynollah/modelens_full_lay_1_2_221024", - #'filename="models_eynollah/eynollah-full-regions-1column_20210425", - filename="models_eynollah/modelens_full_lay_1__4_3_091124", - dist=dist_url("layout"), - help="full layout / no patches", - type=KerasModel, - ), - - # FIXME: Why is region_fl and region_fl_np the same model? - EynollahModelSpec( - category="region_fl", - variant='', - # filename="models_eynollah/eynollah-full-regions-3+column_20210425", - # filename="models_eynollah/model_2_full_layout_new_trans", - # filename="models_eynollah/modelens_full_lay_1_3_031124", - # filename="models_eynollah/modelens_full_lay_13__3_19_241024", - # filename="models_eynollah/model_full_lay_13_241024", - # filename="models_eynollah/modelens_full_lay_13_17_231024", - # filename="models_eynollah/modelens_full_lay_1_2_221024", - # filename="models_eynollah/modelens_full_layout_24_till_28", - # filename="models_eynollah/model_2_full_layout_new_trans", - filename="models_eynollah/modelens_full_lay_1__4_3_091124", - dist=dist_url("layout"), - help="full layout / with patches", - type=KerasModel, - ), - - EynollahModelSpec( - category="reading_order", - variant='', - #filename="models_eynollah/model_mb_ro_aug_ens_11", - #filename="models_eynollah/model_step_3200000_mb_ro", - #filename="models_eynollah/model_ens_reading_order_machine_based", - #filename="models_eynollah/model_mb_ro_aug_ens_8", - #filename="models_eynollah/model_ens_reading_order_machine_based", - filename="models_eynollah/model_eynollah_reading_order_20250824", - dist=dist_url("layout"), - type=KerasModel, - ), - - EynollahModelSpec( - category="textline", - variant='', - #filename="models_eynollah/modelens_textline_1_4_16092024", - #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", - #filename="models_eynollah/modelens_textline_1_3_4_20240915", - #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", - #filename="models_eynollah/modelens_textline_9_12_13_14_15", - #filename="models_eynollah/eynollah-textline_20210425", - filename="models_eynollah/modelens_textline_0_1__2_4_16092024", - dist=dist_url("layout"), - type=KerasModel, - ), - - EynollahModelSpec( - category="textline", - variant='light', - #filename="models_eynollah/eynollah-textline_light_20210425", - filename="models_eynollah/modelens_textline_0_1__2_4_16092024", - dist=dist_url("layout"), - type=KerasModel, - ), - - EynollahModelSpec( - category="table", - variant='', - filename="models_eynollah/eynollah-tables_20210319", - dist=dist_url("layout"), - type=KerasModel, - ), - - EynollahModelSpec( - category="table", - variant='light', - filename="models_eynollah/modelens_table_0t4_201124", - dist=dist_url("layout"), - type=KerasModel, - ), - - EynollahModelSpec( - 
category="ocr", - variant='', - filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930", - dist=dist_url("ocr"), - type=KerasModel, - ), - - EynollahModelSpec( - category="num_to_char", - variant='', - filename="models_eynollah/characters_org.txt", - dist=dist_url("ocr"), - type=KerasModel, - ), - - EynollahModelSpec( - category="characters", - variant='', - filename="models_eynollah/characters_org.txt", - dist=dist_url("ocr"), - type=List, - ), - - EynollahModelSpec( - category="ocr", - variant='tr', - filename="models_eynollah/model_eynollah_ocr_trocr_20250919", - dist=dist_url("trocr"), - type=KerasModel, - ), - - EynollahModelSpec( - category="trocr_processor", - variant='', - filename="models_eynollah/microsoft/trocr-base-printed", - dist=dist_url("trocr"), - type=KerasModel, - ), - - EynollahModelSpec( - category="trocr_processor", - variant='htr', - filename="models_eynollah/microsoft/trocr-base-handwritten", - dist=dist_url("trocr"), - type=TrOCRProcessor, - ), - -])# }}} - -class EynollahModelZoo(): - """ - Wrapper class that handles storage and loading of models for all eynollah runners. - """ - model_basedir: Path - specs: EynollahModelSpecSet - - def __init__( - self, - basedir: str, - model_overrides: Optional[List[Tuple[str, str, str]]]=None, - ) -> None: - self.model_basedir = Path(basedir) - self.logger = logging.getLogger('eynollah.model_zoo') - self.specs = deepcopy(DEFAULT_MODEL_SPECS) - if model_overrides: - self.override_models(*model_overrides) - self._loaded: Dict[str, AnyModel] = {} - - def override_models( - self, - *model_overrides: Tuple[str, str, str], - ): - """ - Override the default model versions - """ - for model_category, model_variant, model_filename in model_overrides: - spec = self.specs.get(model_category, model_variant) - self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename) - self.specs.get(model_category, model_variant).filename = model_filename - - def model_path( - self, - model_category: str, - model_variant: str = '', - absolute: bool = True, - ) -> Path: - """ - Translate model_{type,variant} tuple into an absolute (or relative) Path - """ - spec = self.specs.get(model_category, model_variant) - if not Path(spec.filename).is_absolute() and absolute: - model_path = Path(self.model_basedir).joinpath(spec.filename) - else: - model_path = Path(spec.filename) - return model_path - - def load_models( - self, - *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]], - ) -> Dict: - """ - Load all models by calling load_model and return a dictionary mapping model_category to loaded model - """ - ret = {} - for load_args in all_load_args: - if isinstance(load_args, str): - ret[load_args] = self.load_model(load_args) - else: - ret[load_args[0]] = self.load_model(*load_args) - return ret - - def load_model( - self, - model_category: str, - model_variant: str = '', - ) -> AnyModel: - """ - Load any model - """ - model_path = self.model_path(model_category, model_variant) - if model_path.suffix == '.h5' and Path(model_path.stem).exists(): - # prefer SavedModel over HDF5 format if it exists - model_path = Path(model_path.stem) - if model_category == 'ocr': - model = self._load_ocr_model(variant=model_variant) - elif model_category == 'num_to_char': - model = self._load_num_to_char() - elif model_category == 'characters': - model = self._load_characters() - elif model_category == 'trocr_processor': - return TrOCRProcessor.from_pretrained(self.model_path(...)) - else: - try: - model = 
load_model(model_path, compile=False) - except Exception as e: - self.logger.exception(e) - model = load_model(model_path, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - self._loaded[model_category] = model - return model # type: ignore - - def get(self, model_category: str, model_type: Optional[Type[T]]=None) -> T: - if model_category not in self._loaded: - raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"') - ret = self._loaded[model_category] - if model_type: - assert isinstance(ret, model_type) - return ret # type: ignore # FIXME: convince typing that we're returning generic type - - def _load_ocr_model(self, variant: str) -> AnyModel: - """ - Load OCR model - """ - ocr_model_dir = self.model_path('ocr', variant) - if variant == 'tr': - return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) - else: - ocr_model = load_model(ocr_model_dir, compile=False) - assert isinstance(ocr_model, KerasModel) - return KerasModel( - ocr_model.get_layer(name = "image").input, # type: ignore - ocr_model.get_layer(name = "dense2").output) # type: ignore - - def _load_characters(self) -> List[str]: - """ - Load encoding for OCR - """ - with open(self.model_path('ocr') / self.model_path('num_to_char', absolute=False), "r") as config_file: - return json.load(config_file) - - def _load_num_to_char(self) -> StringLookup: - """ - Load decoder for OCR - """ - characters = self._load_characters() - # Mapping characters to integers. - char_to_num = StringLookup(vocabulary=characters, mask_token=None) - # Mapping integers back to original characters. - return StringLookup( - vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True - ) - - def __str__(self): - return str(json.dumps({ - 'basedir': str(self.model_basedir), - 'versions': self.specs, - }, indent=2)) - - def shutdown(self): - """ - Ensure that a loaded models is not referenced by ``self._loaded`` anymore - """ - if hasattr(self, '_loaded') and getattr(self, '_loaded'): - for needle in self._loaded: - if self._loaded[needle]: - del self._loaded[needle] - diff --git a/src/eynollah/model_zoo/__init__.py b/src/eynollah/model_zoo/__init__.py new file mode 100644 index 0000000..e1dc985 --- /dev/null +++ b/src/eynollah/model_zoo/__init__.py @@ -0,0 +1,4 @@ +__all__ = [ + 'EynollahModelZoo', +] +from .model_zoo import EynollahModelZoo diff --git a/src/eynollah/model_zoo/default_specs.py b/src/eynollah/model_zoo/default_specs.py new file mode 100644 index 0000000..e06c829 --- /dev/null +++ b/src/eynollah/model_zoo/default_specs.py @@ -0,0 +1,314 @@ +from .specs import EynollahModelSpec, EynollahModelSpecSet +from .types import KerasModel, TrOCRProcessor, List + +# NOTE: This needs to change whenever models/versions change +ZENODO = "https://zenodo.org/records/17295988/files" +MODELS_VERSION = "v0_7_0" + +def dist_url(dist_name: str) -> str: + return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip' + +DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ + + EynollahModelSpec( + category="enhancement", + variant='', + filename="models_eynollah/eynollah-enhancement_20210425", + dists=['enhancement', 'layout'], + dist_url=dist_url("enhancement"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization", + variant='', + filename="models_eynollah/eynollah-binarization-hybrid_20230504", + dists=['layout', 'binarization'], + dist_url=dist_url("binarization"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization", + variant='20210309', + 
filename="models_eynollah/eynollah-binarization_20210309", + dists=['binarization'], + dist_url=dist_url("binarization"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization", + variant='augment', + filename="models_eynollah/eynollah-binarization_20210425", + dists=['binarization'], + dist_url=dist_url("binarization"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization_multi_1", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin1", + dist_url=dist_url("binarization"), + dists=['binarization'], + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization_multi_2", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin2", + dist_url=dist_url("binarization"), + dists=['binarization'], + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization_multi_3", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin3", + dist_url=dist_url("binarization"), + dists=['binarization'], + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization_multi_4", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin4", + dist_url=dist_url("binarization"), + dists=['binarization'], + type=KerasModel, + ), + + EynollahModelSpec( + category="col_classifier", + variant='', + filename="models_eynollah/eynollah-column-classifier_20210425", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="page", + variant='', + filename="models_eynollah/model_eynollah_page_extraction_20250915", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region", + variant='', + filename="models_eynollah/eynollah-main-regions-ensembled_20210425", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region", + variant='extract_only_images', + filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region", + variant='light', + filename="models_eynollah/eynollah-main-regions_20220314", + dist_url=dist_url("layout"), + help="early layout", + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region_p2", + variant='', + filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425", + dist_url=dist_url("layout"), + help="early layout, non-light, 2nd part", + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region_1_2", + variant='', + #filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n", + #filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8", + #filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18", + #filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige", + #filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige", + filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024", + dist_url=dist_url("layout"), + dists=['layout'], + help="early layout, light, 1-or-2-column", + type=KerasModel, + ), + + EynollahModelSpec( + category="region_fl_np", + variant='', + #'filename="models_eynollah/modelens_full_lay_1_3_031124", + #'filename="models_eynollah/modelens_full_lay_13__3_19_241024", + #'filename="models_eynollah/model_full_lay_13_241024", + 
#'filename="models_eynollah/modelens_full_lay_13_17_231024", + #'filename="models_eynollah/modelens_full_lay_1_2_221024", + #'filename="models_eynollah/eynollah-full-regions-1column_20210425", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist_url=dist_url("layout"), + help="full layout / no patches", + dists=['layout'], + type=KerasModel, + ), + + # FIXME: Why is region_fl and region_fl_np the same model? + EynollahModelSpec( + category="region_fl", + variant='', + # filename="models_eynollah/eynollah-full-regions-3+column_20210425", + # filename="models_eynollah/model_2_full_layout_new_trans", + # filename="models_eynollah/modelens_full_lay_1_3_031124", + # filename="models_eynollah/modelens_full_lay_13__3_19_241024", + # filename="models_eynollah/model_full_lay_13_241024", + # filename="models_eynollah/modelens_full_lay_13_17_231024", + # filename="models_eynollah/modelens_full_lay_1_2_221024", + # filename="models_eynollah/modelens_full_layout_24_till_28", + # filename="models_eynollah/model_2_full_layout_new_trans", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist_url=dist_url("layout"), + help="full layout / with patches", + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="reading_order", + variant='', + #filename="models_eynollah/model_mb_ro_aug_ens_11", + #filename="models_eynollah/model_step_3200000_mb_ro", + #filename="models_eynollah/model_ens_reading_order_machine_based", + #filename="models_eynollah/model_mb_ro_aug_ens_8", + #filename="models_eynollah/model_ens_reading_order_machine_based", + filename="models_eynollah/model_eynollah_reading_order_20250824", + dist_url=dist_url("reading_order"), + dists=['layout', 'reading_order'], + type=KerasModel, + ), + + EynollahModelSpec( + category="textline", + variant='', + #filename="models_eynollah/modelens_textline_1_4_16092024", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_1_3_4_20240915", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_9_12_13_14_15", + #filename="models_eynollah/eynollah-textline_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="textline", + variant='light', + #filename="models_eynollah/eynollah-textline_light_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="table", + variant='', + filename="models_eynollah/eynollah-tables_20210319", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="table", + variant='light', + filename="models_eynollah/modelens_table_0t4_201124", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="ocr", + variant='', + filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930", + dist_url=dist_url("ocr"), + dists=['layout', 'ocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="ocr", + variant='degraded', + filename="models_eynollah/model_eynollah_ocr_cnnrnn__degraded_20250805/", + help="slightly better at degraded Fraktur", + dist_url=dist_url("ocr"), + dists=['ocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="num_to_char", + variant='', + filename="characters_org.txt", + 
dist_url=dist_url("ocr"), + dists=['ocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="characters", + variant='', + filename="characters_org.txt", + dist_url=dist_url("ocr"), + dists=['ocr'], + type=list, + ), + + EynollahModelSpec( + category="ocr", + variant='tr', + filename="models_eynollah/model_eynollah_ocr_trocr_20250919", + dist_url=dist_url("trocr"), + help='much slower transformer-based', + dists=['trocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="trocr_processor", + variant='', + filename="models_eynollah/microsoft/trocr-base-printed", + dist_url=dist_url("trocr"), + dists=['trocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="trocr_processor", + variant='htr', + filename="models_eynollah/microsoft/trocr-base-handwritten", + dist_url=dist_url("trocr"), + dists=['trocr'], + type=TrOCRProcessor, + ), + +]) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py new file mode 100644 index 0000000..7cfaa3a --- /dev/null +++ b/src/eynollah/model_zoo/model_zoo.py @@ -0,0 +1,189 @@ +import json +import logging +from copy import deepcopy +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Type, Union + +from keras.layers import StringLookup +from keras.models import Model as KerasModel +from keras.models import load_model +from tabulate import tabulate +from transformers import TrOCRProcessor, VisionEncoderDecoderModel + +from ..patch_encoder import PatchEncoder, Patches +from .specs import EynollahModelSpecSet +from .default_specs import DEFAULT_MODEL_SPECS +from .types import AnyModel, T + + +class EynollahModelZoo: + """ + Wrapper class that handles storage and loading of models for all eynollah runners. + """ + + model_basedir: Path + specs: EynollahModelSpecSet + + def __init__( + self, + basedir: str, + model_overrides: Optional[List[Tuple[str, str, str]]] = None, + ) -> None: + self.model_basedir = Path(basedir) + self.logger = logging.getLogger('eynollah.model_zoo') + self.specs = deepcopy(DEFAULT_MODEL_SPECS) + if model_overrides: + self.override_models(*model_overrides) + self._loaded: Dict[str, AnyModel] = {} + + def override_models( + self, + *model_overrides: Tuple[str, str, str], + ): + """ + Override the default model versions + """ + for model_category, model_variant, model_filename in model_overrides: + spec = self.specs.get(model_category, model_variant) + self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename) + self.specs.get(model_category, model_variant).filename = model_filename + + def model_path( + self, + model_category: str, + model_variant: str = '', + absolute: bool = True, + ) -> Path: + """ + Translate model_{type,variant} tuple into an absolute (or relative) Path + """ + spec = self.specs.get(model_category, model_variant) + if spec.category in ('characters', 'num_to_char'): + return self.model_path('ocr') / spec.filename + if not Path(spec.filename).is_absolute() and absolute: + model_path = Path(self.model_basedir).joinpath(spec.filename) + else: + model_path = Path(spec.filename) + return model_path + + def load_models( + self, + *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]], + ) -> Dict: + """ + Load all models by calling load_model and return a dictionary mapping model_category to loaded model + """ + ret = {} + for load_args in all_load_args: + if isinstance(load_args, str): + ret[load_args] = self.load_model(load_args) + else: + ret[load_args[0]] = self.load_model(*load_args) + return 
ret
+
+    def load_model(
+        self,
+        model_category: str,
+        model_variant: str = '',
+    ) -> AnyModel:
+        """
+        Load any model
+        """
+        model_path = self.model_path(model_category, model_variant)
+        if model_path.suffix == '.h5' and Path(model_path.stem).exists():
+            # prefer SavedModel over HDF5 format if it exists
+            model_path = Path(model_path.stem)
+        if model_category == 'ocr':
+            model = self._load_ocr_model(variant=model_variant)
+        elif model_category == 'num_to_char':
+            model = self._load_num_to_char()
+        elif model_category == 'characters':
+            model = self._load_characters()
+        elif model_category == 'trocr_processor':
+            return TrOCRProcessor.from_pretrained(self.model_path(model_category, model_variant))
+        else:
+            try:
+                model = load_model(model_path, compile=False)
+            except Exception as e:
+                self.logger.exception(e)
+                model = load_model(
+                    model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
+                )
+        self._loaded[model_category] = model
+        return model  # type: ignore
+
+    def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T:
+        if model_category not in self._loaded:
+            raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
+        ret = self._loaded[model_category]
+        if model_type:
+            assert isinstance(ret, model_type)
+        return ret  # type: ignore # FIXME: convince typing that we're returning generic type
+
+    def _load_ocr_model(self, variant: str) -> AnyModel:
+        """
+        Load OCR model
+        """
+        ocr_model_dir = self.model_path('ocr', variant)
+        if variant == 'tr':
+            return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
+        else:
+            ocr_model = load_model(ocr_model_dir, compile=False)
+            assert isinstance(ocr_model, KerasModel)
+            return KerasModel(
+                ocr_model.get_layer(name="image").input,  # type: ignore
+                ocr_model.get_layer(name="dense2").output,  # type: ignore
+            )
+
+    def _load_characters(self) -> List[str]:
+        """
+        Load encoding for OCR
+        """
+        with open(self.model_path('num_to_char'), "r") as config_file:
+            return json.load(config_file)
+
+    def _load_num_to_char(self) -> StringLookup:
+        """
+        Load decoder for OCR
+        """
+        characters = self._load_characters()
+        # Mapping characters to integers.
+        char_to_num = StringLookup(vocabulary=characters, mask_token=None)
+        # Mapping integers back to original characters.
+        return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)
+
+    def __str__(self):
+        return tabulate(
+            [
+                [
+                    spec.type.__name__,
+                    spec.category,
+                    spec.variant,
+                    spec.help,
+                    ', '.join(spec.dists),
+                    f'Yes, at {self.model_path(spec.category, spec.variant)}'
+                    if self.model_path(spec.category, spec.variant).exists()
+                    else f'No, download {spec.dist_url}',
+                    # self.model_path(spec.category, spec.variant),
+                ]
+                for spec in sorted(self.specs.specs, key=lambda x: x.category + '0' + x.variant)
+            ],
+            headers=[
+                'Type',
+                'Category',
+                'Variant',
+                'Help',
+                'Used in',
+                'Installed',
+            ],
+            tablefmt='github',
+        )
+
+    def shutdown(self):
+        """
+        Ensure that loaded models are no longer referenced by ``self._loaded``
+        """
+        if hasattr(self, '_loaded') and getattr(self, '_loaded'):
+            # deleting keys while iterating over the dict would raise a
+            # RuntimeError, so drop all references in one go
+            self._loaded.clear()
diff --git a/src/eynollah/model_zoo/specs.py b/src/eynollah/model_zoo/specs.py
new file mode 100644
index 0000000..4f8cffa
--- /dev/null
+++ b/src/eynollah/model_zoo/specs.py
@@ -0,0 +1,55 @@
+from dataclasses import dataclass
+from typing import Dict, List, Set, Tuple, Type
+from .types import AnyModel
+
+
+@dataclass
+class EynollahModelSpec():
+    """
+    Describing a single model abstractly.
+    """
+    category: str
+    # Relative filename to the models_eynollah directory in the dists
+    filename: str
+    # basename of the ZIP files that should contain this model
+    dists: List[str]
+    # URL to the smallest model distribution containing this model (link to Zenodo)
+    dist_url: str
+    type: Type[AnyModel]
+    variant: str = ''
+    help: str = ''
+
+class EynollahModelSpecSet():
+    """
+    List of all used models for eynollah.
+    """
+    specs: List[EynollahModelSpec]
+
+    def __init__(self, specs: List[EynollahModelSpec]) -> None:
+        self.specs = specs
+        self.categories: Set[str] = set([spec.category for spec in self.specs])
+        self.variants: Dict[str, Set[str]] = {
+            spec.category: set([x.variant for x in self.specs if x.category == spec.category])
+            for spec in self.specs
+        }
+        self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = {
+            (spec.category, spec.variant): spec
+            for spec in self.specs
+        }
+
+    def asdict(self) -> Dict[str, Dict[str, str]]:
+        # group by category so that multiple variants of the same
+        # category do not overwrite each other
+        ret: Dict[str, Dict[str, str]] = {}
+        for spec in self.specs:
+            ret.setdefault(spec.category, {})[spec.variant] = spec.filename
+        return ret
+
+    def get(self, category: str, variant: str) -> EynollahModelSpec:
+        if category not in self.categories:
+            raise ValueError(f"Unknown category '{category}', must be one of {self.categories}")
+        if variant not in self.variants[category]:
+            raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}")
+        return self._index_category_variant[(category, variant)]
+
+
diff --git a/src/eynollah/model_zoo/types.py b/src/eynollah/model_zoo/types.py
new file mode 100644
index 0000000..5c3685e
--- /dev/null
+++ b/src/eynollah/model_zoo/types.py
@@ -0,0 +1,6 @@
+from typing import List, TypeVar, Union
+from keras.models import Model as KerasModel
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+
+AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
+T = TypeVar('T')
diff --git a/src/eynollah/sbb_binarize.py b/src/eynollah/sbb_binarize.py
index f8898a1..48dc7b1 100644
--- a/src/eynollah/sbb_binarize.py
+++ b/src/eynollah/sbb_binarize.py
@@ -24,7 +24,7 @@ def resize_image(img_in, input_height, input_width):
 
 class SbbBinarizer:
 
-    def __init__(self, model_dir, mode='single', logger=None):
+    def __init__(self, model_dir: str, mode: str, logger=None):
         if mode not in ('single', 'multi'):
             raise ValueError(f"'mode' must be either 'multi' or 'single', not {mode}")
         self.log = logger if logger else logging.getLogger('SbbBinarizer')
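
Usage sketch: the snippet below exercises only the API introduced by this
patch (EynollahModelZoo, load_model, get, shutdown, and the
`eynollah models list` command); the model directory and the override path
are placeholder values, not files shipped with eynollah.

    from eynollah.model_zoo import EynollahModelZoo

    # Point the zoo at a local models directory (placeholder path) and override
    # one default model, using the same 'CATEGORY VARIANT PATH' triple that the
    # --model-overrides CLI option expects.
    zoo = EynollahModelZoo(
        basedir="/path/to/models_eynollah",
        model_overrides=[("region", "light", "/path/to/custom_region_model")],
    )

    # Load a model once, then fetch the loaded instance wherever it is needed.
    zoo.load_model("region", "light")
    region_model = zoo.get("region")

    # Printing the zoo renders the tabulated overview, i.e. the same output as
    # `eynollah models list -m /path/to/models_eynollah`.
    print(zoo)

    # Drop references to the loaded models when done.
    zoo.shutdown()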