rewrite model spec data structure

This commit is contained in:
kba 2025-10-22 13:07:35 +02:00
parent 146658f026
commit d94285b3ea

View file

@ -1,154 +1,329 @@
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
import json import json
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Literal, Optional, Tuple, List, Type, TypeVar, Union from typing import Dict, Optional, Set, Tuple, List, Type, TypeVar, Union
from copy import deepcopy
from keras.layers import StringLookup from keras.layers import StringLookup
from keras.models import Model, load_model from keras.models import Model as KerasModel, load_model
from transformers import TrOCRProcessor, VisionEncoderDecoderModel from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from eynollah.patch_encoder import PatchEncoder, Patches from eynollah.patch_encoder import PatchEncoder, Patches
SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model, List] AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
T = TypeVar('T') T = TypeVar('T')
# Dict mapping model_category to dict mapping variant (default is '') to Path # NOTE: This needs to change whenever models change
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = { ZENODO = "https://zenodo.org/records/17295988/files"
MODELS_VERSION = "v0_7_0"
"enhancement": { def dist_url(dist_name: str) -> str:
'': "eynollah-enhancement_20210425" return f'{ZENODO}/models_{dist_name}_${MODELS_VERSION}.zip'
},
"binarization": { @dataclass
'': "eynollah-binarization_20210425" class EynollahModelSpec():
}, """
Describing a single model abstractly.
"""
category: str
# Relative filename to the models_eynollah directory in the dists
filename: str
# The smallest model distribution containing this model (link to Zenodo)
dist: str
type: Type[AnyModel]
variant: str = ''
help: str = ''
"binarization_multi_1": { class EynollahModelSpecSet():
'': "saved_model_2020_01_16/model_bin1", """
}, List of all used models for eynollah.
"binarization_multi_2": { """
'': "saved_model_2020_01_16/model_bin2", specs: List[EynollahModelSpec]
},
"binarization_multi_3": {
'': "saved_model_2020_01_16/model_bin3",
},
"binarization_multi_4": {
'': "saved_model_2020_01_16/model_bin4",
},
"col_classifier": { def __init__(self, specs: List[EynollahModelSpec]) -> None:
'': "eynollah-column-classifier_20210425", self.specs = specs
}, self.categories: Set[str] = set([spec.category for spec in self.specs])
self.variants: Dict[str, Set[str]] = {
spec.category: set([x.variant for x in self.specs if x.category == spec.category])
for spec in self.specs
}
self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = {
(spec.category, spec.variant): spec
for spec in self.specs
}
"page": { def asdict(self) -> Dict[str, Dict[str, str]]:
'': "model_eynollah_page_extraction_20250915", return {
}, spec.category: {
spec.variant: spec.filename
}
for spec in self.specs
}
# TODO: What is this commented out model? def get(self, category: str, variant: str) -> EynollahModelSpec:
#?: "eynollah-main-regions-aug-scaling_20210425", if category not in self.categories:
raise ValueError(f"Unknown category '{category}', must be one of {self.categories}")
if variant not in self.variants[category]:
raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}")
return self._index_category_variant[(category, variant)]
# early layout DEFAULT_MODEL_SPECS = EynollahModelSpecSet([# {{{
"region": {
'': "eynollah-main-regions-ensembled_20210425",
'extract_only_images': "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
'light': "eynollah-main-regions_20220314",
},
# early layout, non-light, 2nd part EynollahModelSpec(
"region_p2": { category="enhancement",
'': "eynollah-main-regions-aug-rotation_20210425", variant='',
}, filename="models_eynollah/eynollah-enhancement_20210425",
dist=dist_url("enhancement"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization",
variant='',
filename="models_eynollah/eynollah-binarization_20210425",
dist=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_1",
variant='',
filename="models_eynollah/saved_model_2020_01_16/model_bin1",
dist=dist_url("binarization"),
type=KerasModel,
),
# early layout, light, 1-or-2-column EynollahModelSpec(
"region_1_2": { category="binarization_multi_2",
#'': "modelens_12sp_elay_0_3_4__3_6_n" variant='',
#'': "modelens_earlylayout_12spaltige_2_3_5_6_7_8" filename="models_eynollah/saved_model_2020_01_16/model_bin2",
#'': "modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18" dist=dist_url("binarization"),
#'': "modelens_1_2_4_5_early_lay_1_2_spaltige" type=KerasModel,
#'': "model_3_eraly_layout_no_patches_1_2_spaltige" ),
'': "modelens_e_l_all_sp_0_1_2_3_4_171024"
},
# full layout / no patches EynollahModelSpec(
"region_fl_np": { category="binarization_multi_3",
#'': "modelens_full_lay_1_3_031124" variant='',
#'': "modelens_full_lay_13__3_19_241024" filename="models_eynollah/saved_model_2020_01_16/model_bin3",
#'': "model_full_lay_13_241024" dist=dist_url("binarization"),
#'': "modelens_full_lay_13_17_231024" type=KerasModel,
#'': "modelens_full_lay_1_2_221024" ),
#'': "eynollah-full-regions-1column_20210425"
'': "modelens_full_lay_1__4_3_091124"
},
# full layout / with patches EynollahModelSpec(
"region_fl": { category="binarization_multi_4",
#'': "eynollah-full-regions-3+column_20210425" variant='',
#'': #"model_2_full_layout_new_trans" filename="models_eynollah/saved_model_2020_01_16/model_bin4",
#'': "modelens_full_lay_1_3_031124" dist=dist_url("binarization"),
#'': "modelens_full_lay_13__3_19_241024" type=KerasModel,
#'': "model_full_lay_13_241024" ),
#'': "modelens_full_lay_13_17_231024"
#'': "modelens_full_lay_1_2_221024"
#'': "modelens_full_layout_24_till_28"
#'': "model_2_full_layout_new_trans"
'': "modelens_full_lay_1__4_3_091124",
},
"reading_order": { EynollahModelSpec(
#'': "model_mb_ro_aug_ens_11" category="col_classifier",
#'': "model_step_3200000_mb_ro" variant='',
#'': "model_ens_reading_order_machine_based" filename="models_eynollah/eynollah-column-classifier_20210425",
#'': "model_mb_ro_aug_ens_8" dist=dist_url("layout"),
#'': "model_ens_reading_order_machine_based" type=KerasModel,
'': "model_eynollah_reading_order_20250824" ),
},
"textline": { EynollahModelSpec(
#'light': "eynollah-textline_light_20210425" category="page",
'light': "modelens_textline_0_1__2_4_16092024", variant='',
#'': "modelens_textline_1_4_16092024" filename="models_eynollah/model_eynollah_page_extraction_20250915",
#'': "model_textline_ens_3_4_5_6_artificial" dist=dist_url("layout"),
#'': "modelens_textline_1_3_4_20240915" type=KerasModel,
#'': "model_textline_ens_3_4_5_6_artificial" ),
#'': "modelens_textline_9_12_13_14_15"
#'': "eynollah-textline_20210425"
'': "modelens_textline_0_1__2_4_16092024"
},
"table": { EynollahModelSpec(
'light': "modelens_table_0t4_201124", category="region",
'': "eynollah-tables_20210319", variant='',
}, filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
dist=dist_url("layout"),
type=KerasModel,
),
"ocr": { EynollahModelSpec(
'tr': "model_eynollah_ocr_trocr_20250919", category="region",
'': "model_eynollah_ocr_cnnrnn_20250930", variant='extract_only_images',
}, filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
dist=dist_url("layout"),
type=KerasModel,
),
'trocr_processor': { EynollahModelSpec(
'': 'microsoft/trocr-base-printed', category="region",
'htr': "microsoft/trocr-base-handwritten", variant='light',
}, filename="models_eynollah/eynollah-main-regions_20220314",
dist=dist_url("layout"),
help="early layout",
type=KerasModel,
),
'num_to_char': { EynollahModelSpec(
'': 'characters_org.txt' category="region_p2",
}, variant='',
filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425",
dist=dist_url("layout"),
help="early layout, non-light, 2nd part",
type=KerasModel,
),
'characters': { EynollahModelSpec(
'': 'characters_org.txt' category="region_1_2",
}, variant='',
#filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n",
#filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8",
#filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18",
#filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige",
#filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige",
filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024",
dist=dist_url("layout"),
help="early layout, light, 1-or-2-column",
type=KerasModel,
),
} EynollahModelSpec(
category="region_fl_np",
variant='',
#'filename="models_eynollah/modelens_full_lay_1_3_031124",
#'filename="models_eynollah/modelens_full_lay_13__3_19_241024",
#'filename="models_eynollah/model_full_lay_13_241024",
#'filename="models_eynollah/modelens_full_lay_13_17_231024",
#'filename="models_eynollah/modelens_full_lay_1_2_221024",
#'filename="models_eynollah/eynollah-full-regions-1column_20210425",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist=dist_url("layout"),
help="full layout / no patches",
type=KerasModel,
),
# FIXME: Why is region_fl and region_fl_np the same model?
EynollahModelSpec(
category="region_fl",
variant='',
# filename="models_eynollah/eynollah-full-regions-3+column_20210425",
# filename="models_eynollah/model_2_full_layout_new_trans",
# filename="models_eynollah/modelens_full_lay_1_3_031124",
# filename="models_eynollah/modelens_full_lay_13__3_19_241024",
# filename="models_eynollah/model_full_lay_13_241024",
# filename="models_eynollah/modelens_full_lay_13_17_231024",
# filename="models_eynollah/modelens_full_lay_1_2_221024",
# filename="models_eynollah/modelens_full_layout_24_till_28",
# filename="models_eynollah/model_2_full_layout_new_trans",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist=dist_url("layout"),
help="full layout / with patches",
type=KerasModel,
),
EynollahModelSpec(
category="reading_order",
variant='',
#filename="models_eynollah/model_mb_ro_aug_ens_11",
#filename="models_eynollah/model_step_3200000_mb_ro",
#filename="models_eynollah/model_ens_reading_order_machine_based",
#filename="models_eynollah/model_mb_ro_aug_ens_8",
#filename="models_eynollah/model_ens_reading_order_machine_based",
filename="models_eynollah/model_eynollah_reading_order_20250824",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="textline",
variant='',
#filename="models_eynollah/modelens_textline_1_4_16092024",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_1_3_4_20240915",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_9_12_13_14_15",
#filename="models_eynollah/eynollah-textline_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="textline",
variant='light',
#filename="models_eynollah/eynollah-textline_light_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="table",
variant='',
filename="models_eynollah/eynollah-tables_20210319",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="table",
variant='light',
filename="models_eynollah/modelens_table_0t4_201124",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="ocr",
variant='',
filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
dist=dist_url("ocr"),
type=KerasModel,
),
EynollahModelSpec(
category="num_to_char",
variant='',
filename="models_eynollah/characters_org.txt",
dist=dist_url("ocr"),
type=KerasModel,
),
EynollahModelSpec(
category="characters",
variant='',
filename="models_eynollah/characters_org.txt",
dist=dist_url("ocr"),
type=List,
),
EynollahModelSpec(
category="ocr",
variant='tr',
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
dist=dist_url("trocr"),
type=KerasModel,
),
EynollahModelSpec(
category="trocr_processor",
variant='',
filename="models_eynollah/microsoft/trocr-base-printed",
dist=dist_url("trocr"),
type=KerasModel,
),
EynollahModelSpec(
category="trocr_processor",
variant='htr',
filename="models_eynollah/microsoft/trocr-base-handwritten",
dist=dist_url("trocr"),
type=TrOCRProcessor,
),
])# }}}
class EynollahModelZoo(): class EynollahModelZoo():
""" """
Wrapper class that handles storage and loading of models for all eynollah runners. Wrapper class that handles storage and loading of models for all eynollah runners.
""" """
model_basedir: Path model_basedir: Path
model_versions: dict specs: EynollahModelSpecSet
def __init__( def __init__(
self, self,
@ -157,10 +332,10 @@ class EynollahModelZoo():
) -> None: ) -> None:
self.model_basedir = Path(basedir) self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo') self.logger = logging.getLogger('eynollah.model_zoo')
self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS) self.specs = deepcopy(DEFAULT_MODEL_SPECS)
if model_overrides: if model_overrides:
self.override_models(*model_overrides) self.override_models(*model_overrides)
self._loaded: Dict[str, SomeEynollahModel] = {} self._loaded: Dict[str, AnyModel] = {}
def override_models( def override_models(
self, self,
@ -170,39 +345,24 @@ class EynollahModelZoo():
Override the default model versions Override the default model versions
""" """
for model_category, model_variant, model_filename in model_overrides: for model_category, model_variant, model_filename in model_overrides:
if model_category not in DEFAULT_MODEL_VERSIONS: spec = self.specs.get(model_category, model_variant)
raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}") self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]: self.specs.get(model_category, model_variant).filename = model_filename
raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
self.logger.warning(
"Overriding default model %s ('%s' variant) from %s to %s",
model_category,
model_variant,
DEFAULT_MODEL_VERSIONS[model_category][model_variant],
model_filename
)
self.model_versions[model_category][model_variant] = model_filename
def model_path( def model_path(
self, self,
model_category: str, model_category: str,
model_variant: str = '', model_variant: str = '',
model_filename: str = '',
absolute: bool = True, absolute: bool = True,
) -> Path: ) -> Path:
""" """
Translate model_{type,variant,filename} tuple into an absolute (or relative) Path Translate model_{type,variant} tuple into an absolute (or relative) Path
""" """
if model_category not in DEFAULT_MODEL_VERSIONS: spec = self.specs.get(model_category, model_variant)
raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}") if not Path(spec.filename).is_absolute() and absolute:
if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]: model_path = Path(self.model_basedir).joinpath(spec.filename)
raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
if not model_filename:
model_filename = DEFAULT_MODEL_VERSIONS[model_category][model_variant]
if not Path(model_filename).is_absolute() and absolute:
model_path = Path(self.model_basedir).joinpath(model_filename)
else: else:
model_path = Path(model_filename) model_path = Path(spec.filename)
return model_path return model_path
def load_models( def load_models(
@ -224,12 +384,11 @@ class EynollahModelZoo():
self, self,
model_category: str, model_category: str,
model_variant: str = '', model_variant: str = '',
model_filename: str = '', ) -> AnyModel:
) -> SomeEynollahModel:
""" """
Load any model Load any model
""" """
model_path = self.model_path(model_category, model_variant, model_filename) model_path = self.model_path(model_category, model_variant)
if model_path.suffix == '.h5' and Path(model_path.stem).exists(): if model_path.suffix == '.h5' and Path(model_path.stem).exists():
# prefer SavedModel over HDF5 format if it exists # prefer SavedModel over HDF5 format if it exists
model_path = Path(model_path.stem) model_path = Path(model_path.stem)
@ -259,17 +418,17 @@ class EynollahModelZoo():
assert isinstance(ret, model_type) assert isinstance(ret, model_type)
return ret # type: ignore # FIXME: convince typing that we're returning generic type return ret # type: ignore # FIXME: convince typing that we're returning generic type
def _load_ocr_model(self, variant: str) -> SomeEynollahModel: def _load_ocr_model(self, variant: str) -> AnyModel:
""" """
Load OCR model Load OCR model
""" """
ocr_model_dir = Path(self.model_basedir, self.model_versions["ocr"][variant]) ocr_model_dir = self.model_path('ocr', variant)
if variant == 'tr': if variant == 'tr':
return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
else: else:
ocr_model = load_model(ocr_model_dir, compile=False) ocr_model = load_model(ocr_model_dir, compile=False)
assert isinstance(ocr_model, Model) assert isinstance(ocr_model, KerasModel)
return Model( return KerasModel(
ocr_model.get_layer(name = "image").input, # type: ignore ocr_model.get_layer(name = "image").input, # type: ignore
ocr_model.get_layer(name = "dense2").output) # type: ignore ocr_model.get_layer(name = "dense2").output) # type: ignore
@ -279,7 +438,7 @@ class EynollahModelZoo():
""" """
with open(self.model_path('ocr') / self.model_path('num_to_char', absolute=False), "r") as config_file: with open(self.model_path('ocr') / self.model_path('num_to_char', absolute=False), "r") as config_file:
return json.load(config_file) return json.load(config_file)
def _load_num_to_char(self) -> StringLookup: def _load_num_to_char(self) -> StringLookup:
""" """
Load decoder for OCR Load decoder for OCR
@ -295,7 +454,7 @@ class EynollahModelZoo():
def __str__(self): def __str__(self):
return str(json.dumps({ return str(json.dumps({
'basedir': str(self.model_basedir), 'basedir': str(self.model_basedir),
'versions': self.model_versions, 'versions': self.specs,
}, indent=2)) }, indent=2))
def shutdown(self): def shutdown(self):