rewrite model spec data structure

This commit is contained in:
kba 2025-10-22 13:07:35 +02:00
parent 146658f026
commit d94285b3ea

View file

@ -1,154 +1,329 @@
from copy import deepcopy
from dataclasses import dataclass
import json
import logging
from pathlib import Path
from typing import Dict, Literal, Optional, Tuple, List, Type, TypeVar, Union
from copy import deepcopy
from typing import Dict, Optional, Set, Tuple, List, Type, TypeVar, Union
from keras.layers import StringLookup
from keras.models import Model, load_model
from keras.models import Model as KerasModel, load_model
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from eynollah.patch_encoder import PatchEncoder, Patches
SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model, List]
AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
T = TypeVar('T')
# Dict mapping model_category to dict mapping variant (default is '') to Path
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
# NOTE: This needs to change whenever models change
ZENODO = "https://zenodo.org/records/17295988/files"
MODELS_VERSION = "v0_7_0"
"enhancement": {
'': "eynollah-enhancement_20210425"
},
def dist_url(dist_name: str) -> str:
    """
    Build the download URL of a model distribution archive.

    Args:
        dist_name: name of the distribution, e.g. ``"layout"`` or ``"ocr"``.

    Returns:
        ``{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip``
    """
    # NOTE: no `$` before the placeholder: in a Python f-string it would be
    # emitted literally and end up in the URL.
    return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip'
"binarization": {
'': "eynollah-binarization_20210425"
},
@dataclass
class EynollahModelSpec:
    """
    Abstract description of a single model.
    """
    # Name of the model category, e.g. "region" or "ocr"
    category: str
    # Filename of the model as shipped in the dists
    # NOTE(review): the default specs use paths that start with
    # "models_eynollah/" - confirm what this is meant to be relative to
    filename: str
    # The smallest model distribution containing this model (link to Zenodo)
    dist: str
    # Class of the model, one of the members of AnyModel
    type: Type[AnyModel]
    # Variant within the category, '' being the default variant
    variant: str = ''
    # Optional free-text description of the model
    help: str = ''
"binarization_multi_1": {
'': "saved_model_2020_01_16/model_bin1",
},
"binarization_multi_2": {
'': "saved_model_2020_01_16/model_bin2",
},
"binarization_multi_3": {
'': "saved_model_2020_01_16/model_bin3",
},
"binarization_multi_4": {
'': "saved_model_2020_01_16/model_bin4",
},
class EynollahModelSpecSet():
"""
List of all used models for eynollah.
"""
specs: List[EynollahModelSpec]
"col_classifier": {
'': "eynollah-column-classifier_20210425",
},
def __init__(self, specs: List[EynollahModelSpec]) -> None:
    """
    Store the specs and build the lookup structures used by `get`.

    Args:
        specs: all model specs of this set. If several specs share the
            same (category, variant) pair, the last one wins in the index.
    """
    self.specs = specs
    # Names of all categories occurring in the specs
    self.categories: Set[str] = {spec.category for spec in specs}
    # Known variants per category, grouped in a single pass over the specs
    self.variants: Dict[str, Set[str]] = {}
    for spec in specs:
        self.variants.setdefault(spec.category, set()).add(spec.variant)
    # Direct lookup of a spec by its (category, variant) pair
    self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = {
        (spec.category, spec.variant): spec
        for spec in specs
    }
"page": {
'': "model_eynollah_page_extraction_20250915",
},
def asdict(self) -> Dict[str, Dict[str, str]]:
    """
    Return the specs as a nested dict.

    Returns:
        Dict mapping each category to a dict mapping each of its variants
        to the filename of the corresponding spec.
    """
    by_category: Dict[str, Dict[str, str]] = {}
    for spec in self.specs:
        # Group by category first, so that a category with several variants
        # keeps all of them instead of only the last one seen.
        by_category.setdefault(spec.category, {})[spec.variant] = spec.filename
    return by_category
# TODO: What is this commented out model?
#?: "eynollah-main-regions-aug-scaling_20210425",
def get(self, category: str, variant: str) -> EynollahModelSpec:
    """
    Look up the spec registered for a category and variant.

    Args:
        category: name of the model category.
        variant: name of the variant within that category.

    Returns:
        The spec indexed under ``(category, variant)``.

    Raises:
        ValueError: if the category is unknown, or if the variant is not
            known for that category.
    """
    if category in self.categories:
        if variant in self.variants[category]:
            return self._index_category_variant[(category, variant)]
        raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}")
    raise ValueError(f"Unknown category '{category}', must be one of {self.categories}")
# early layout
"region": {
'': "eynollah-main-regions-ensembled_20210425",
'extract_only_images': "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
'light': "eynollah-main-regions_20220314",
},
DEFAULT_MODEL_SPECS = EynollahModelSpecSet([# {{{
# early layout, non-light, 2nd part
"region_p2": {
'': "eynollah-main-regions-aug-rotation_20210425",
},
EynollahModelSpec(
category="enhancement",
variant='',
filename="models_eynollah/eynollah-enhancement_20210425",
dist=dist_url("enhancement"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization",
variant='',
filename="models_eynollah/eynollah-binarization_20210425",
dist=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_1",
variant='',
filename="models_eynollah/saved_model_2020_01_16/model_bin1",
dist=dist_url("binarization"),
type=KerasModel,
),
# early layout, light, 1-or-2-column
"region_1_2": {
#'': "modelens_12sp_elay_0_3_4__3_6_n"
#'': "modelens_earlylayout_12spaltige_2_3_5_6_7_8"
#'': "modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
#'': "modelens_1_2_4_5_early_lay_1_2_spaltige"
#'': "model_3_eraly_layout_no_patches_1_2_spaltige"
'': "modelens_e_l_all_sp_0_1_2_3_4_171024"
},
EynollahModelSpec(
category="binarization_multi_2",
variant='',
filename="models_eynollah/saved_model_2020_01_16/model_bin2",
dist=dist_url("binarization"),
type=KerasModel,
),
# full layout / no patches
"region_fl_np": {
#'': "modelens_full_lay_1_3_031124"
#'': "modelens_full_lay_13__3_19_241024"
#'': "model_full_lay_13_241024"
#'': "modelens_full_lay_13_17_231024"
#'': "modelens_full_lay_1_2_221024"
#'': "eynollah-full-regions-1column_20210425"
'': "modelens_full_lay_1__4_3_091124"
},
EynollahModelSpec(
category="binarization_multi_3",
variant='',
filename="models_eynollah/saved_model_2020_01_16/model_bin3",
dist=dist_url("binarization"),
type=KerasModel,
),
# full layout / with patches
"region_fl": {
#'': "eynollah-full-regions-3+column_20210425"
#'': #"model_2_full_layout_new_trans"
#'': "modelens_full_lay_1_3_031124"
#'': "modelens_full_lay_13__3_19_241024"
#'': "model_full_lay_13_241024"
#'': "modelens_full_lay_13_17_231024"
#'': "modelens_full_lay_1_2_221024"
#'': "modelens_full_layout_24_till_28"
#'': "model_2_full_layout_new_trans"
'': "modelens_full_lay_1__4_3_091124",
},
EynollahModelSpec(
category="binarization_multi_4",
variant='',
filename="models_eynollah/saved_model_2020_01_16/model_bin4",
dist=dist_url("binarization"),
type=KerasModel,
),
"reading_order": {
#'': "model_mb_ro_aug_ens_11"
#'': "model_step_3200000_mb_ro"
#'': "model_ens_reading_order_machine_based"
#'': "model_mb_ro_aug_ens_8"
#'': "model_ens_reading_order_machine_based"
'': "model_eynollah_reading_order_20250824"
},
EynollahModelSpec(
category="col_classifier",
variant='',
filename="models_eynollah/eynollah-column-classifier_20210425",
dist=dist_url("layout"),
type=KerasModel,
),
"textline": {
#'light': "eynollah-textline_light_20210425"
'light': "modelens_textline_0_1__2_4_16092024",
#'': "modelens_textline_1_4_16092024"
#'': "model_textline_ens_3_4_5_6_artificial"
#'': "modelens_textline_1_3_4_20240915"
#'': "model_textline_ens_3_4_5_6_artificial"
#'': "modelens_textline_9_12_13_14_15"
#'': "eynollah-textline_20210425"
'': "modelens_textline_0_1__2_4_16092024"
},
EynollahModelSpec(
category="page",
variant='',
filename="models_eynollah/model_eynollah_page_extraction_20250915",
dist=dist_url("layout"),
type=KerasModel,
),
"table": {
'light': "modelens_table_0t4_201124",
'': "eynollah-tables_20210319",
},
EynollahModelSpec(
category="region",
variant='',
filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
dist=dist_url("layout"),
type=KerasModel,
),
"ocr": {
'tr': "model_eynollah_ocr_trocr_20250919",
'': "model_eynollah_ocr_cnnrnn_20250930",
},
EynollahModelSpec(
category="region",
variant='extract_only_images',
filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
dist=dist_url("layout"),
type=KerasModel,
),
'trocr_processor': {
'': 'microsoft/trocr-base-printed',
'htr': "microsoft/trocr-base-handwritten",
},
EynollahModelSpec(
category="region",
variant='light',
filename="models_eynollah/eynollah-main-regions_20220314",
dist=dist_url("layout"),
help="early layout",
type=KerasModel,
),
'num_to_char': {
'': 'characters_org.txt'
},
EynollahModelSpec(
category="region_p2",
variant='',
filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425",
dist=dist_url("layout"),
help="early layout, non-light, 2nd part",
type=KerasModel,
),
'characters': {
'': 'characters_org.txt'
},
EynollahModelSpec(
category="region_1_2",
variant='',
#filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n",
#filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8",
#filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18",
#filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige",
#filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige",
filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024",
dist=dist_url("layout"),
help="early layout, light, 1-or-2-column",
type=KerasModel,
),
}
EynollahModelSpec(
category="region_fl_np",
variant='',
#'filename="models_eynollah/modelens_full_lay_1_3_031124",
#'filename="models_eynollah/modelens_full_lay_13__3_19_241024",
#'filename="models_eynollah/model_full_lay_13_241024",
#'filename="models_eynollah/modelens_full_lay_13_17_231024",
#'filename="models_eynollah/modelens_full_lay_1_2_221024",
#'filename="models_eynollah/eynollah-full-regions-1column_20210425",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist=dist_url("layout"),
help="full layout / no patches",
type=KerasModel,
),
# FIXME: Why is region_fl and region_fl_np the same model?
EynollahModelSpec(
category="region_fl",
variant='',
# filename="models_eynollah/eynollah-full-regions-3+column_20210425",
# filename="models_eynollah/model_2_full_layout_new_trans",
# filename="models_eynollah/modelens_full_lay_1_3_031124",
# filename="models_eynollah/modelens_full_lay_13__3_19_241024",
# filename="models_eynollah/model_full_lay_13_241024",
# filename="models_eynollah/modelens_full_lay_13_17_231024",
# filename="models_eynollah/modelens_full_lay_1_2_221024",
# filename="models_eynollah/modelens_full_layout_24_till_28",
# filename="models_eynollah/model_2_full_layout_new_trans",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist=dist_url("layout"),
help="full layout / with patches",
type=KerasModel,
),
EynollahModelSpec(
category="reading_order",
variant='',
#filename="models_eynollah/model_mb_ro_aug_ens_11",
#filename="models_eynollah/model_step_3200000_mb_ro",
#filename="models_eynollah/model_ens_reading_order_machine_based",
#filename="models_eynollah/model_mb_ro_aug_ens_8",
#filename="models_eynollah/model_ens_reading_order_machine_based",
filename="models_eynollah/model_eynollah_reading_order_20250824",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="textline",
variant='',
#filename="models_eynollah/modelens_textline_1_4_16092024",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_1_3_4_20240915",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_9_12_13_14_15",
#filename="models_eynollah/eynollah-textline_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="textline",
variant='light',
#filename="models_eynollah/eynollah-textline_light_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="table",
variant='',
filename="models_eynollah/eynollah-tables_20210319",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="table",
variant='light',
filename="models_eynollah/modelens_table_0t4_201124",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="ocr",
variant='',
filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
dist=dist_url("ocr"),
type=KerasModel,
),
EynollahModelSpec(
category="num_to_char",
variant='',
filename="models_eynollah/characters_org.txt",
dist=dist_url("ocr"),
type=KerasModel,
),
EynollahModelSpec(
category="characters",
variant='',
filename="models_eynollah/characters_org.txt",
dist=dist_url("ocr"),
type=List,
),
EynollahModelSpec(
category="ocr",
variant='tr',
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
dist=dist_url("trocr"),
type=KerasModel,
),
EynollahModelSpec(
category="trocr_processor",
variant='',
filename="models_eynollah/microsoft/trocr-base-printed",
dist=dist_url("trocr"),
type=KerasModel,
),
EynollahModelSpec(
category="trocr_processor",
variant='htr',
filename="models_eynollah/microsoft/trocr-base-handwritten",
dist=dist_url("trocr"),
type=TrOCRProcessor,
),
])# }}}
class EynollahModelZoo():
"""
Wrapper class that handles storage and loading of models for all eynollah runners.
"""
model_basedir: Path
model_versions: dict
specs: EynollahModelSpecSet
def __init__(
self,
@ -157,10 +332,10 @@ class EynollahModelZoo():
) -> None:
self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo')
self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
self.specs = deepcopy(DEFAULT_MODEL_SPECS)
if model_overrides:
self.override_models(*model_overrides)
self._loaded: Dict[str, SomeEynollahModel] = {}
self._loaded: Dict[str, AnyModel] = {}
def override_models(
self,
@ -170,39 +345,24 @@ class EynollahModelZoo():
Override the default model versions
"""
for model_category, model_variant, model_filename in model_overrides:
if model_category not in DEFAULT_MODEL_VERSIONS:
raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
self.logger.warning(
"Overriding default model %s ('%s' variant) from %s to %s",
model_category,
model_variant,
DEFAULT_MODEL_VERSIONS[model_category][model_variant],
model_filename
)
self.model_versions[model_category][model_variant] = model_filename
spec = self.specs.get(model_category, model_variant)
self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
self.specs.get(model_category, model_variant).filename = model_filename
def model_path(
self,
model_category: str,
model_variant: str = '',
model_filename: str = '',
absolute: bool = True,
) -> Path:
"""
Translate model_{type,variant,filename} tuple into an absolute (or relative) Path
Translate model_{type,variant} tuple into an absolute (or relative) Path
"""
if model_category not in DEFAULT_MODEL_VERSIONS:
raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
if not model_filename:
model_filename = DEFAULT_MODEL_VERSIONS[model_category][model_variant]
if not Path(model_filename).is_absolute() and absolute:
model_path = Path(self.model_basedir).joinpath(model_filename)
spec = self.specs.get(model_category, model_variant)
if not Path(spec.filename).is_absolute() and absolute:
model_path = Path(self.model_basedir).joinpath(spec.filename)
else:
model_path = Path(model_filename)
model_path = Path(spec.filename)
return model_path
def load_models(
@ -224,12 +384,11 @@ class EynollahModelZoo():
self,
model_category: str,
model_variant: str = '',
model_filename: str = '',
) -> SomeEynollahModel:
) -> AnyModel:
"""
Load any model
"""
model_path = self.model_path(model_category, model_variant, model_filename)
model_path = self.model_path(model_category, model_variant)
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
# prefer SavedModel over HDF5 format if it exists
model_path = Path(model_path.stem)
@ -259,17 +418,17 @@ class EynollahModelZoo():
assert isinstance(ret, model_type)
return ret # type: ignore # FIXME: convince typing that we're returning generic type
def _load_ocr_model(self, variant: str) -> SomeEynollahModel:
def _load_ocr_model(self, variant: str) -> AnyModel:
"""
Load OCR model
"""
ocr_model_dir = Path(self.model_basedir, self.model_versions["ocr"][variant])
ocr_model_dir = self.model_path('ocr', variant)
if variant == 'tr':
return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
else:
ocr_model = load_model(ocr_model_dir, compile=False)
assert isinstance(ocr_model, Model)
return Model(
assert isinstance(ocr_model, KerasModel)
return KerasModel(
ocr_model.get_layer(name = "image").input, # type: ignore
ocr_model.get_layer(name = "dense2").output) # type: ignore
@ -279,7 +438,7 @@ class EynollahModelZoo():
"""
with open(self.model_path('ocr') / self.model_path('num_to_char', absolute=False), "r") as config_file:
return json.load(config_file)
def _load_num_to_char(self) -> StringLookup:
"""
Load decoder for OCR
@ -295,7 +454,7 @@ class EynollahModelZoo():
def __str__(self):
    """
    Return a JSON-formatted description of this zoo, consisting of the
    base directory and the dict form of the model specs.
    """
    return json.dumps({
        'basedir': str(self.model_basedir),
        # The spec set object itself is not JSON-serializable (json.dumps
        # would raise TypeError), so dump its plain dict form.
        'versions': self.specs.asdict(),
    }, indent=2)
def shutdown(self):