From d94285b3ea2972fa2a402bc3b88222a2d6d4a164 Mon Sep 17 00:00:00 2001 From: kba Date: Wed, 22 Oct 2025 13:07:35 +0200 Subject: [PATCH] rewrite model spec data structure --- src/eynollah/model_zoo.py | 453 +++++++++++++++++++++++++------------- 1 file changed, 306 insertions(+), 147 deletions(-) diff --git a/src/eynollah/model_zoo.py b/src/eynollah/model_zoo.py index 100d974..6bb06d3 100644 --- a/src/eynollah/model_zoo.py +++ b/src/eynollah/model_zoo.py @@ -1,154 +1,329 @@ +from copy import deepcopy from dataclasses import dataclass import json import logging from pathlib import Path -from typing import Dict, Literal, Optional, Tuple, List, Type, TypeVar, Union -from copy import deepcopy +from typing import Dict, Optional, Set, Tuple, List, Type, TypeVar, Union from keras.layers import StringLookup -from keras.models import Model, load_model +from keras.models import Model as KerasModel, load_model from transformers import TrOCRProcessor, VisionEncoderDecoderModel from eynollah.patch_encoder import PatchEncoder, Patches -SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model, List] +AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List] T = TypeVar('T') -# Dict mapping model_category to dict mapping variant (default is '') to Path -DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = { +# NOTE: This needs to change whenever models change +ZENODO = "https://zenodo.org/records/17295988/files" +MODELS_VERSION = "v0_7_0" - "enhancement": { - '': "eynollah-enhancement_20210425" - }, +def dist_url(dist_name: str) -> str: + return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip' - "binarization": { - '': "eynollah-binarization_20210425" - }, +@dataclass +class EynollahModelSpec(): + """ + Describing a single model abstractly. 
+ """ + category: str + # Relative filename to the models_eynollah directory in the dists + filename: str + # The smallest model distribution containing this model (link to Zenodo) + dist: str + type: Type[AnyModel] + variant: str = '' + help: str = '' - "binarization_multi_1": { - '': "saved_model_2020_01_16/model_bin1", - }, - "binarization_multi_2": { - '': "saved_model_2020_01_16/model_bin2", - }, - "binarization_multi_3": { - '': "saved_model_2020_01_16/model_bin3", - }, - "binarization_multi_4": { - '': "saved_model_2020_01_16/model_bin4", - }, +class EynollahModelSpecSet(): + """ + List of all used models for eynollah. + """ + specs: List[EynollahModelSpec] - "col_classifier": { - '': "eynollah-column-classifier_20210425", - }, + def __init__(self, specs: List[EynollahModelSpec]) -> None: + self.specs = specs + self.categories: Set[str] = set([spec.category for spec in self.specs]) + self.variants: Dict[str, Set[str]] = { + spec.category: set([x.variant for x in self.specs if x.category == spec.category]) + for spec in self.specs + } + self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = { + (spec.category, spec.variant): spec + for spec in self.specs + } - "page": { - '': "model_eynollah_page_extraction_20250915", - }, + def asdict(self) -> Dict[str, Dict[str, str]]: + """Map every category to a dict of all its variants' filenames.""" + return { + category: {spec.variant: spec.filename + for spec in self.specs if spec.category == category} + for category in self.variants + } - # TODO: What is this commented out model? - #?: "eynollah-main-regions-aug-scaling_20210425", + def get(self, category: str, variant: str) -> EynollahModelSpec: + if category not in self.categories: + raise ValueError(f"Unknown category '{category}', must be one of {self.categories}") + if variant not in self.variants[category]: + raise ValueError(f"Unknown variant {variant} for {category}. 
Known variants: {self.variants[category]}") + return self._index_category_variant[(category, variant)] - # early layout - "region": { - '': "eynollah-main-regions-ensembled_20210425", - 'extract_only_images': "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18", - 'light': "eynollah-main-regions_20220314", - }, +DEFAULT_MODEL_SPECS = EynollahModelSpecSet([# {{{ - # early layout, non-light, 2nd part - "region_p2": { - '': "eynollah-main-regions-aug-rotation_20210425", - }, + EynollahModelSpec( + category="enhancement", + variant='', + filename="models_eynollah/eynollah-enhancement_20210425", + dist=dist_url("enhancement"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization", + variant='', + filename="models_eynollah/eynollah-binarization_20210425", + dist=dist_url("binarization"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization_multi_1", + variant='', + filename="models_eynollah/saved_model_2020_01_16/model_bin1", + dist=dist_url("binarization"), + type=KerasModel, + ), - # early layout, light, 1-or-2-column - "region_1_2": { - #'': "modelens_12sp_elay_0_3_4__3_6_n" - #'': "modelens_earlylayout_12spaltige_2_3_5_6_7_8" - #'': "modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18" - #'': "modelens_1_2_4_5_early_lay_1_2_spaltige" - #'': "model_3_eraly_layout_no_patches_1_2_spaltige" - '': "modelens_e_l_all_sp_0_1_2_3_4_171024" - }, + EynollahModelSpec( + category="binarization_multi_2", + variant='', + filename="models_eynollah/saved_model_2020_01_16/model_bin2", + dist=dist_url("binarization"), + type=KerasModel, + ), - # full layout / no patches - "region_fl_np": { - #'': "modelens_full_lay_1_3_031124" - #'': "modelens_full_lay_13__3_19_241024" - #'': "model_full_lay_13_241024" - #'': "modelens_full_lay_13_17_231024" - #'': "modelens_full_lay_1_2_221024" - #'': "eynollah-full-regions-1column_20210425" - '': "modelens_full_lay_1__4_3_091124" - }, + EynollahModelSpec( + category="binarization_multi_3", + variant='', + 
filename="models_eynollah/saved_model_2020_01_16/model_bin3", + dist=dist_url("binarization"), + type=KerasModel, + ), - # full layout / with patches - "region_fl": { - #'': "eynollah-full-regions-3+column_20210425" - #'': #"model_2_full_layout_new_trans" - #'': "modelens_full_lay_1_3_031124" - #'': "modelens_full_lay_13__3_19_241024" - #'': "model_full_lay_13_241024" - #'': "modelens_full_lay_13_17_231024" - #'': "modelens_full_lay_1_2_221024" - #'': "modelens_full_layout_24_till_28" - #'': "model_2_full_layout_new_trans" - '': "modelens_full_lay_1__4_3_091124", - }, + EynollahModelSpec( + category="binarization_multi_4", + variant='', + filename="models_eynollah/saved_model_2020_01_16/model_bin4", + dist=dist_url("binarization"), + type=KerasModel, + ), - "reading_order": { - #'': "model_mb_ro_aug_ens_11" - #'': "model_step_3200000_mb_ro" - #'': "model_ens_reading_order_machine_based" - #'': "model_mb_ro_aug_ens_8" - #'': "model_ens_reading_order_machine_based" - '': "model_eynollah_reading_order_20250824" - }, + EynollahModelSpec( + category="col_classifier", + variant='', + filename="models_eynollah/eynollah-column-classifier_20210425", + dist=dist_url("layout"), + type=KerasModel, + ), - "textline": { - #'light': "eynollah-textline_light_20210425" - 'light': "modelens_textline_0_1__2_4_16092024", - #'': "modelens_textline_1_4_16092024" - #'': "model_textline_ens_3_4_5_6_artificial" - #'': "modelens_textline_1_3_4_20240915" - #'': "model_textline_ens_3_4_5_6_artificial" - #'': "modelens_textline_9_12_13_14_15" - #'': "eynollah-textline_20210425" - '': "modelens_textline_0_1__2_4_16092024" - }, + EynollahModelSpec( + category="page", + variant='', + filename="models_eynollah/model_eynollah_page_extraction_20250915", + dist=dist_url("layout"), + type=KerasModel, + ), - "table": { - 'light': "modelens_table_0t4_201124", - '': "eynollah-tables_20210319", - }, + EynollahModelSpec( + category="region", + variant='', + 
filename="models_eynollah/eynollah-main-regions-ensembled_20210425", + dist=dist_url("layout"), + type=KerasModel, + ), - "ocr": { - 'tr': "model_eynollah_ocr_trocr_20250919", - '': "model_eynollah_ocr_cnnrnn_20250930", - }, + EynollahModelSpec( + category="region", + variant='extract_only_images', + filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18", + dist=dist_url("layout"), + type=KerasModel, + ), - 'trocr_processor': { - '': 'microsoft/trocr-base-printed', - 'htr': "microsoft/trocr-base-handwritten", - }, + EynollahModelSpec( + category="region", + variant='light', + filename="models_eynollah/eynollah-main-regions_20220314", + dist=dist_url("layout"), + help="early layout", + type=KerasModel, + ), - 'num_to_char': { - '': 'characters_org.txt' - }, + EynollahModelSpec( + category="region_p2", + variant='', + filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425", + dist=dist_url("layout"), + help="early layout, non-light, 2nd part", + type=KerasModel, + ), - 'characters': { - '': 'characters_org.txt' - }, + EynollahModelSpec( + category="region_1_2", + variant='', + #filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n", + #filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8", + #filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18", + #filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige", + #filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige", + filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024", + dist=dist_url("layout"), + help="early layout, light, 1-or-2-column", + type=KerasModel, + ), -} + EynollahModelSpec( + category="region_fl_np", + variant='', + #filename="models_eynollah/modelens_full_lay_1_3_031124", + #filename="models_eynollah/modelens_full_lay_13__3_19_241024", + #filename="models_eynollah/model_full_lay_13_241024", + #filename="models_eynollah/modelens_full_lay_13_17_231024", + 
#filename="models_eynollah/modelens_full_lay_1_2_221024", + #filename="models_eynollah/eynollah-full-regions-1column_20210425", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist=dist_url("layout"), + help="full layout / no patches", + type=KerasModel, + ), + # FIXME: Why are region_fl and region_fl_np the same model? + EynollahModelSpec( + category="region_fl", + variant='', + # filename="models_eynollah/eynollah-full-regions-3+column_20210425", + # filename="models_eynollah/model_2_full_layout_new_trans", + # filename="models_eynollah/modelens_full_lay_1_3_031124", + # filename="models_eynollah/modelens_full_lay_13__3_19_241024", + # filename="models_eynollah/model_full_lay_13_241024", + # filename="models_eynollah/modelens_full_lay_13_17_231024", + # filename="models_eynollah/modelens_full_lay_1_2_221024", + # filename="models_eynollah/modelens_full_layout_24_till_28", + # filename="models_eynollah/model_2_full_layout_new_trans", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist=dist_url("layout"), + help="full layout / with patches", + type=KerasModel, + ), + + EynollahModelSpec( + category="reading_order", + variant='', + #filename="models_eynollah/model_mb_ro_aug_ens_11", + #filename="models_eynollah/model_step_3200000_mb_ro", + #filename="models_eynollah/model_ens_reading_order_machine_based", + #filename="models_eynollah/model_mb_ro_aug_ens_8", + #filename="models_eynollah/model_ens_reading_order_machine_based", + filename="models_eynollah/model_eynollah_reading_order_20250824", + dist=dist_url("layout"), + type=KerasModel, + ), + + EynollahModelSpec( + category="textline", + variant='', + #filename="models_eynollah/modelens_textline_1_4_16092024", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_1_3_4_20240915", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_9_12_13_14_15", + 
#filename="models_eynollah/eynollah-textline_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist=dist_url("layout"), + type=KerasModel, + ), + + EynollahModelSpec( + category="textline", + variant='light', + #filename="models_eynollah/eynollah-textline_light_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist=dist_url("layout"), + type=KerasModel, + ), + + EynollahModelSpec( + category="table", + variant='', + filename="models_eynollah/eynollah-tables_20210319", + dist=dist_url("layout"), + type=KerasModel, + ), + + EynollahModelSpec( + category="table", + variant='light', + filename="models_eynollah/modelens_table_0t4_201124", + dist=dist_url("layout"), + type=KerasModel, + ), + + EynollahModelSpec( + category="ocr", + variant='', + filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930", + dist=dist_url("ocr"), + type=KerasModel, + ), + + EynollahModelSpec( + category="num_to_char", + variant='', + filename="models_eynollah/characters_org.txt", + dist=dist_url("ocr"), + type=KerasModel, + ), + + EynollahModelSpec( + category="characters", + variant='', + filename="models_eynollah/characters_org.txt", + dist=dist_url("ocr"), + type=List, + ), + + EynollahModelSpec( + category="ocr", + variant='tr', + filename="models_eynollah/model_eynollah_ocr_trocr_20250919", + dist=dist_url("trocr"), + type=VisionEncoderDecoderModel, + ), + + EynollahModelSpec( + category="trocr_processor", + variant='', + filename="models_eynollah/microsoft/trocr-base-printed", + dist=dist_url("trocr"), + type=TrOCRProcessor, + ), + + EynollahModelSpec( + category="trocr_processor", + variant='htr', + filename="models_eynollah/microsoft/trocr-base-handwritten", + dist=dist_url("trocr"), + type=TrOCRProcessor, + ), + +])# }}} class EynollahModelZoo(): """ Wrapper class that handles storage and loading of models for all eynollah runners. 
""" model_basedir: Path - model_versions: dict + specs: EynollahModelSpecSet def __init__( self, @@ -157,10 +332,10 @@ class EynollahModelZoo(): ) -> None: self.model_basedir = Path(basedir) self.logger = logging.getLogger('eynollah.model_zoo') - self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS) + self.specs = deepcopy(DEFAULT_MODEL_SPECS) if model_overrides: self.override_models(*model_overrides) - self._loaded: Dict[str, SomeEynollahModel] = {} + self._loaded: Dict[str, AnyModel] = {} def override_models( self, @@ -170,39 +345,24 @@ class EynollahModelZoo(): Override the default model versions """ for model_category, model_variant, model_filename in model_overrides: - if model_category not in DEFAULT_MODEL_VERSIONS: - raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}") - if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]: - raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}") - self.logger.warning( - "Overriding default model %s ('%s' variant) from %s to %s", - model_category, - model_variant, - DEFAULT_MODEL_VERSIONS[model_category][model_variant], - model_filename - ) - self.model_versions[model_category][model_variant] = model_filename + spec = self.specs.get(model_category, model_variant) + self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename) + spec.filename = model_filename def model_path( self, model_category: str, model_variant: str = '', - model_filename: str = '', absolute: bool = True, ) -> Path: """ - Translate model_{type,variant,filename} tuple into an absolute (or relative) Path + Translate model_{category,variant} tuple into an absolute (or relative) Path """ - if model_category not in DEFAULT_MODEL_VERSIONS: - raise ValueError(f"Unknown model_category '{model_category}', must be one of 
{DEFAULT_MODEL_VERSIONS.keys()}") - if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]: - raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}") - if not model_filename: - model_filename = DEFAULT_MODEL_VERSIONS[model_category][model_variant] - if not Path(model_filename).is_absolute() and absolute: - model_path = Path(self.model_basedir).joinpath(model_filename) + spec = self.specs.get(model_category, model_variant) + if not Path(spec.filename).is_absolute() and absolute: + model_path = Path(self.model_basedir).joinpath(spec.filename) else: - model_path = Path(model_filename) + model_path = Path(spec.filename) return model_path def load_models( @@ -224,12 +384,11 @@ class EynollahModelZoo(): self, model_category: str, model_variant: str = '', - model_filename: str = '', - ) -> SomeEynollahModel: + ) -> AnyModel: """ Load any model """ - model_path = self.model_path(model_category, model_variant, model_filename) + model_path = self.model_path(model_category, model_variant) if model_path.suffix == '.h5' and Path(model_path.stem).exists(): # prefer SavedModel over HDF5 format if it exists model_path = Path(model_path.stem) @@ -259,17 +418,17 @@ class EynollahModelZoo(): assert isinstance(ret, model_type) return ret # type: ignore # FIXME: convince typing that we're returning generic type - def _load_ocr_model(self, variant: str) -> SomeEynollahModel: + def _load_ocr_model(self, variant: str) -> AnyModel: """ Load OCR model """ - ocr_model_dir = Path(self.model_basedir, self.model_versions["ocr"][variant]) + ocr_model_dir = self.model_path('ocr', variant) if variant == 'tr': return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) else: ocr_model = load_model(ocr_model_dir, compile=False) - assert isinstance(ocr_model, Model) - return Model( + assert isinstance(ocr_model, KerasModel) + return KerasModel( ocr_model.get_layer(name = "image").input, # type: ignore 
ocr_model.get_layer(name = "dense2").output) # type: ignore @@ -279,7 +438,7 @@ class EynollahModelZoo(): """ with open(self.model_path('ocr') / self.model_path('num_to_char', absolute=False), "r") as config_file: return json.load(config_file) - + def _load_num_to_char(self) -> StringLookup: """ Load decoder for OCR @@ -295,7 +454,7 @@ class EynollahModelZoo(): def __str__(self): return str(json.dumps({ 'basedir': str(self.model_basedir), - 'versions': self.model_versions, + 'versions': self.specs.asdict(), }, indent=2)) def shutdown(self):