reorganize model_zoo

This commit is contained in:
kba 2025-10-22 16:04:48 +02:00
parent d94285b3ea
commit 04bc4a63d0
11 changed files with 627 additions and 507 deletions

View file

@ -6,3 +6,4 @@ tensorflow < 2.13
numba <= 0.58.1 numba <= 0.58.1
scikit-image scikit-image
biopython biopython
tabulate

View file

@ -11,46 +11,13 @@ from eynollah.image_enhancer import Enhancer
from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout
from eynollah.model_zoo import EynollahModelZoo from eynollah.model_zoo import EynollahModelZoo
@dataclass from .cli_models import models_cli
class EynollahCliCtx():
model_basedir: str
model_overrides: List[Tuple[str, str, str]]
@click.group() @click.group()
def main(): def main():
pass pass
@main.command('list-models') main.add_command(models_cli, 'models')
@click.option(
"--model",
"-m",
'model_basedir',
help="directory of models",
type=click.Path(exists=True, file_okay=False),
# default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
required=True,
)
@click.option(
"--model-overrides",
"-mv",
help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. See eynollah list-models for the full list",
type=(str, str, str),
multiple=True,
)
@click.pass_context
def list_models(
ctx,
model_basedir: str,
model_overrides: List[Tuple[str, str, str]],
):
"""
List all the models in the zoo
"""
ctx.obj = EynollahCliCtx(
model_basedir=model_basedir,
model_overrides=model_overrides
)
print(EynollahModelZoo(basedir=ctx.obj.model_basedir, model_overrides=ctx.obj.model_overrides))
@main.command() @main.command()
@click.option( @click.option(
@ -143,13 +110,12 @@ def binarization(
log_level, log_level,
): ):
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
binarizer = SbbBinarizer(model_dir) binarizer = SbbBinarizer(model_dir, mode=mode)
if log_level: if log_level:
binarizer.log.setLevel(getLevelName(log_level)) binarizer.log.setLevel(getLevelName(log_level))
binarizer.run( binarizer.run(
image_path=input_image, image_path=input_image,
use_patches=patches, use_patches=patches,
mode=mode,
output=output, output=output,
dir_in=dir_in dir_in=dir_in
) )

View file

@ -0,0 +1,49 @@
from dataclasses import dataclass
from typing import List, Tuple
import click
from .model_zoo import EynollahModelZoo
@dataclass()
class EynollahCliCtx():
model_basedir: str
model_overrides: List[Tuple[str, str, str]]
@click.group()
def models_cli():
"""
Organize models for the various runners in eynollah.
"""
@models_cli.command('list')
@click.option(
"--model",
"-m",
'model_basedir',
help="directory of models",
type=click.Path(exists=True, file_okay=False),
# default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
required=True,
)
@click.option(
"--model-overrides",
"-mv",
help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. See eynollah list-models for the full list",
type=(str, str, str),
multiple=True,
)
@click.pass_context
def list_models(
ctx,
model_basedir: str,
model_overrides: List[Tuple[str, str, str]],
):
"""
List all the models in the zoo
"""
ctx.obj = EynollahCliCtx(
model_basedir=model_basedir,
model_overrides=model_overrides
)
print(EynollahModelZoo(basedir=ctx.obj.model_basedir, model_overrides=ctx.obj.model_overrides))

View file

@ -18,7 +18,11 @@ from keras.models import load_model
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import numpy as np import numpy as np
from eynollah.model_zoo import EynollahModelZoo from eynollah.model_zoo import EynollahModelZoo
import torch try:
import torch
except ImportError:
torch = None
from .utils import is_image_filename from .utils import is_image_filename
from .utils.resize import resize_image from .utils.resize import resize_image

View file

@ -1,468 +0,0 @@
from copy import deepcopy
from dataclasses import dataclass
import json
import logging
from pathlib import Path
from typing import Dict, Optional, Set, Tuple, List, Type, TypeVar, Union
from keras.layers import StringLookup
from keras.models import Model as KerasModel, load_model
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from eynollah.patch_encoder import PatchEncoder, Patches
AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
T = TypeVar('T')
# NOTE: This needs to change whenever models change
ZENODO = "https://zenodo.org/records/17295988/files"
MODELS_VERSION = "v0_7_0"
def dist_url(dist_name: str) -> str:
return f'{ZENODO}/models_{dist_name}_${MODELS_VERSION}.zip'
@dataclass
class EynollahModelSpec():
"""
Describing a single model abstractly.
"""
category: str
# Relative filename to the models_eynollah directory in the dists
filename: str
# The smallest model distribution containing this model (link to Zenodo)
dist: str
type: Type[AnyModel]
variant: str = ''
help: str = ''
class EynollahModelSpecSet():
"""
List of all used models for eynollah.
"""
specs: List[EynollahModelSpec]
def __init__(self, specs: List[EynollahModelSpec]) -> None:
self.specs = specs
self.categories: Set[str] = set([spec.category for spec in self.specs])
self.variants: Dict[str, Set[str]] = {
spec.category: set([x.variant for x in self.specs if x.category == spec.category])
for spec in self.specs
}
self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = {
(spec.category, spec.variant): spec
for spec in self.specs
}
def asdict(self) -> Dict[str, Dict[str, str]]:
return {
spec.category: {
spec.variant: spec.filename
}
for spec in self.specs
}
def get(self, category: str, variant: str) -> EynollahModelSpec:
if category not in self.categories:
raise ValueError(f"Unknown category '{category}', must be one of {self.categories}")
if variant not in self.variants[category]:
raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}")
return self._index_category_variant[(category, variant)]
DEFAULT_MODEL_SPECS = EynollahModelSpecSet([# {{{
EynollahModelSpec(
category="enhancement",
variant='',
filename="models_eynollah/eynollah-enhancement_20210425",
dist=dist_url("enhancement"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization",
variant='',
filename="models_eynollah/eynollah-binarization_20210425",
dist=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_1",
variant='',
filename="models_eynollah/saved_model_2020_01_16/model_bin1",
dist=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_2",
variant='',
filename="models_eynollah/saved_model_2020_01_16/model_bin2",
dist=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_3",
variant='',
filename="models_eynollah/saved_model_2020_01_16/model_bin3",
dist=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_4",
variant='',
filename="models_eynollah/saved_model_2020_01_16/model_bin4",
dist=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="col_classifier",
variant='',
filename="models_eynollah/eynollah-column-classifier_20210425",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="page",
variant='',
filename="models_eynollah/model_eynollah_page_extraction_20250915",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="region",
variant='',
filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="region",
variant='extract_only_images',
filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="region",
variant='light',
filename="models_eynollah/eynollah-main-regions_20220314",
dist=dist_url("layout"),
help="early layout",
type=KerasModel,
),
EynollahModelSpec(
category="region_p2",
variant='',
filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425",
dist=dist_url("layout"),
help="early layout, non-light, 2nd part",
type=KerasModel,
),
EynollahModelSpec(
category="region_1_2",
variant='',
#filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n",
#filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8",
#filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18",
#filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige",
#filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige",
filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024",
dist=dist_url("layout"),
help="early layout, light, 1-or-2-column",
type=KerasModel,
),
EynollahModelSpec(
category="region_fl_np",
variant='',
#'filename="models_eynollah/modelens_full_lay_1_3_031124",
#'filename="models_eynollah/modelens_full_lay_13__3_19_241024",
#'filename="models_eynollah/model_full_lay_13_241024",
#'filename="models_eynollah/modelens_full_lay_13_17_231024",
#'filename="models_eynollah/modelens_full_lay_1_2_221024",
#'filename="models_eynollah/eynollah-full-regions-1column_20210425",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist=dist_url("layout"),
help="full layout / no patches",
type=KerasModel,
),
# FIXME: Why is region_fl and region_fl_np the same model?
EynollahModelSpec(
category="region_fl",
variant='',
# filename="models_eynollah/eynollah-full-regions-3+column_20210425",
# filename="models_eynollah/model_2_full_layout_new_trans",
# filename="models_eynollah/modelens_full_lay_1_3_031124",
# filename="models_eynollah/modelens_full_lay_13__3_19_241024",
# filename="models_eynollah/model_full_lay_13_241024",
# filename="models_eynollah/modelens_full_lay_13_17_231024",
# filename="models_eynollah/modelens_full_lay_1_2_221024",
# filename="models_eynollah/modelens_full_layout_24_till_28",
# filename="models_eynollah/model_2_full_layout_new_trans",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist=dist_url("layout"),
help="full layout / with patches",
type=KerasModel,
),
EynollahModelSpec(
category="reading_order",
variant='',
#filename="models_eynollah/model_mb_ro_aug_ens_11",
#filename="models_eynollah/model_step_3200000_mb_ro",
#filename="models_eynollah/model_ens_reading_order_machine_based",
#filename="models_eynollah/model_mb_ro_aug_ens_8",
#filename="models_eynollah/model_ens_reading_order_machine_based",
filename="models_eynollah/model_eynollah_reading_order_20250824",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="textline",
variant='',
#filename="models_eynollah/modelens_textline_1_4_16092024",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_1_3_4_20240915",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_9_12_13_14_15",
#filename="models_eynollah/eynollah-textline_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="textline",
variant='light',
#filename="models_eynollah/eynollah-textline_light_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="table",
variant='',
filename="models_eynollah/eynollah-tables_20210319",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="table",
variant='light',
filename="models_eynollah/modelens_table_0t4_201124",
dist=dist_url("layout"),
type=KerasModel,
),
EynollahModelSpec(
category="ocr",
variant='',
filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
dist=dist_url("ocr"),
type=KerasModel,
),
EynollahModelSpec(
category="num_to_char",
variant='',
filename="models_eynollah/characters_org.txt",
dist=dist_url("ocr"),
type=KerasModel,
),
EynollahModelSpec(
category="characters",
variant='',
filename="models_eynollah/characters_org.txt",
dist=dist_url("ocr"),
type=List,
),
EynollahModelSpec(
category="ocr",
variant='tr',
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
dist=dist_url("trocr"),
type=KerasModel,
),
EynollahModelSpec(
category="trocr_processor",
variant='',
filename="models_eynollah/microsoft/trocr-base-printed",
dist=dist_url("trocr"),
type=KerasModel,
),
EynollahModelSpec(
category="trocr_processor",
variant='htr',
filename="models_eynollah/microsoft/trocr-base-handwritten",
dist=dist_url("trocr"),
type=TrOCRProcessor,
),
])# }}}
class EynollahModelZoo():
"""
Wrapper class that handles storage and loading of models for all eynollah runners.
"""
model_basedir: Path
specs: EynollahModelSpecSet
def __init__(
self,
basedir: str,
model_overrides: Optional[List[Tuple[str, str, str]]]=None,
) -> None:
self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo')
self.specs = deepcopy(DEFAULT_MODEL_SPECS)
if model_overrides:
self.override_models(*model_overrides)
self._loaded: Dict[str, AnyModel] = {}
def override_models(
self,
*model_overrides: Tuple[str, str, str],
):
"""
Override the default model versions
"""
for model_category, model_variant, model_filename in model_overrides:
spec = self.specs.get(model_category, model_variant)
self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
self.specs.get(model_category, model_variant).filename = model_filename
def model_path(
self,
model_category: str,
model_variant: str = '',
absolute: bool = True,
) -> Path:
"""
Translate model_{type,variant} tuple into an absolute (or relative) Path
"""
spec = self.specs.get(model_category, model_variant)
if not Path(spec.filename).is_absolute() and absolute:
model_path = Path(self.model_basedir).joinpath(spec.filename)
else:
model_path = Path(spec.filename)
return model_path
def load_models(
self,
*all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
) -> Dict:
"""
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
"""
ret = {}
for load_args in all_load_args:
if isinstance(load_args, str):
ret[load_args] = self.load_model(load_args)
else:
ret[load_args[0]] = self.load_model(*load_args)
return ret
def load_model(
self,
model_category: str,
model_variant: str = '',
) -> AnyModel:
"""
Load any model
"""
model_path = self.model_path(model_category, model_variant)
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
# prefer SavedModel over HDF5 format if it exists
model_path = Path(model_path.stem)
if model_category == 'ocr':
model = self._load_ocr_model(variant=model_variant)
elif model_category == 'num_to_char':
model = self._load_num_to_char()
elif model_category == 'characters':
model = self._load_characters()
elif model_category == 'trocr_processor':
return TrOCRProcessor.from_pretrained(self.model_path(...))
else:
try:
model = load_model(model_path, compile=False)
except Exception as e:
self.logger.exception(e)
model = load_model(model_path, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
self._loaded[model_category] = model
return model # type: ignore
def get(self, model_category: str, model_type: Optional[Type[T]]=None) -> T:
if model_category not in self._loaded:
raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"')
ret = self._loaded[model_category]
if model_type:
assert isinstance(ret, model_type)
return ret # type: ignore # FIXME: convince typing that we're returning generic type
def _load_ocr_model(self, variant: str) -> AnyModel:
"""
Load OCR model
"""
ocr_model_dir = self.model_path('ocr', variant)
if variant == 'tr':
return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
else:
ocr_model = load_model(ocr_model_dir, compile=False)
assert isinstance(ocr_model, KerasModel)
return KerasModel(
ocr_model.get_layer(name = "image").input, # type: ignore
ocr_model.get_layer(name = "dense2").output) # type: ignore
def _load_characters(self) -> List[str]:
"""
Load encoding for OCR
"""
with open(self.model_path('ocr') / self.model_path('num_to_char', absolute=False), "r") as config_file:
return json.load(config_file)
def _load_num_to_char(self) -> StringLookup:
"""
Load decoder for OCR
"""
characters = self._load_characters()
# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=characters, mask_token=None)
# Mapping integers back to original characters.
return StringLookup(
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
def __str__(self):
return str(json.dumps({
'basedir': str(self.model_basedir),
'versions': self.specs,
}, indent=2))
def shutdown(self):
"""
Ensure that a loaded models is not referenced by ``self._loaded`` anymore
"""
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
for needle in self._loaded:
if self._loaded[needle]:
del self._loaded[needle]

View file

@ -0,0 +1,4 @@
__all__ = [
'EynollahModelZoo',
]
from .model_zoo import EynollahModelZoo

View file

@ -0,0 +1,314 @@
from .specs import EynollahModelSpec, EynollahModelSpecSet
from .types import KerasModel, TrOCRProcessor, List
# NOTE: This needs to change whenever models/versions change
ZENODO = "https://zenodo.org/records/17295988/files"
MODELS_VERSION = "v0_7_0"
def dist_url(dist_name: str) -> str:
return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip'
DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
EynollahModelSpec(
category="enhancement",
variant='',
filename="models_eynollah/eynollah-enhancement_20210425",
dists=['enhancement', 'layout'],
dist_url=dist_url("enhancement"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization",
variant='',
filename="models_eynollah/eynollah-binarization-hybrid_20230504",
dists=['layout', 'binarization'],
dist_url=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization",
variant='20210309',
filename="models_eynollah/eynollah-binarization_20210309",
dists=['binarization'],
dist_url=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization",
variant='augment',
filename="models_eynollah/eynollah-binarization_20210425",
dists=['binarization'],
dist_url=dist_url("binarization"),
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_1",
variant='',
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin1",
dist_url=dist_url("binarization"),
dists=['binarization'],
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_2",
variant='',
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin2",
dist_url=dist_url("binarization"),
dists=['binarization'],
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_3",
variant='',
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin3",
dist_url=dist_url("binarization"),
dists=['binarization'],
type=KerasModel,
),
EynollahModelSpec(
category="binarization_multi_4",
variant='',
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin4",
dist_url=dist_url("binarization"),
dists=['binarization'],
type=KerasModel,
),
EynollahModelSpec(
category="col_classifier",
variant='',
filename="models_eynollah/eynollah-column-classifier_20210425",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="page",
variant='',
filename="models_eynollah/model_eynollah_page_extraction_20250915",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="region",
variant='',
filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="region",
variant='extract_only_images',
filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="region",
variant='light',
filename="models_eynollah/eynollah-main-regions_20220314",
dist_url=dist_url("layout"),
help="early layout",
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="region_p2",
variant='',
filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425",
dist_url=dist_url("layout"),
help="early layout, non-light, 2nd part",
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="region_1_2",
variant='',
#filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n",
#filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8",
#filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18",
#filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige",
#filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige",
filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024",
dist_url=dist_url("layout"),
dists=['layout'],
help="early layout, light, 1-or-2-column",
type=KerasModel,
),
EynollahModelSpec(
category="region_fl_np",
variant='',
#'filename="models_eynollah/modelens_full_lay_1_3_031124",
#'filename="models_eynollah/modelens_full_lay_13__3_19_241024",
#'filename="models_eynollah/model_full_lay_13_241024",
#'filename="models_eynollah/modelens_full_lay_13_17_231024",
#'filename="models_eynollah/modelens_full_lay_1_2_221024",
#'filename="models_eynollah/eynollah-full-regions-1column_20210425",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist_url=dist_url("layout"),
help="full layout / no patches",
dists=['layout'],
type=KerasModel,
),
# FIXME: Why is region_fl and region_fl_np the same model?
EynollahModelSpec(
category="region_fl",
variant='',
# filename="models_eynollah/eynollah-full-regions-3+column_20210425",
# filename="models_eynollah/model_2_full_layout_new_trans",
# filename="models_eynollah/modelens_full_lay_1_3_031124",
# filename="models_eynollah/modelens_full_lay_13__3_19_241024",
# filename="models_eynollah/model_full_lay_13_241024",
# filename="models_eynollah/modelens_full_lay_13_17_231024",
# filename="models_eynollah/modelens_full_lay_1_2_221024",
# filename="models_eynollah/modelens_full_layout_24_till_28",
# filename="models_eynollah/model_2_full_layout_new_trans",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist_url=dist_url("layout"),
help="full layout / with patches",
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="reading_order",
variant='',
#filename="models_eynollah/model_mb_ro_aug_ens_11",
#filename="models_eynollah/model_step_3200000_mb_ro",
#filename="models_eynollah/model_ens_reading_order_machine_based",
#filename="models_eynollah/model_mb_ro_aug_ens_8",
#filename="models_eynollah/model_ens_reading_order_machine_based",
filename="models_eynollah/model_eynollah_reading_order_20250824",
dist_url=dist_url("reading_order"),
dists=['layout', 'reading_order'],
type=KerasModel,
),
EynollahModelSpec(
category="textline",
variant='',
#filename="models_eynollah/modelens_textline_1_4_16092024",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_1_3_4_20240915",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_9_12_13_14_15",
#filename="models_eynollah/eynollah-textline_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="textline",
variant='light',
#filename="models_eynollah/eynollah-textline_light_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="table",
variant='',
filename="models_eynollah/eynollah-tables_20210319",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="table",
variant='light',
filename="models_eynollah/modelens_table_0t4_201124",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
),
EynollahModelSpec(
category="ocr",
variant='',
filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
dist_url=dist_url("ocr"),
dists=['layout', 'ocr'],
type=KerasModel,
),
EynollahModelSpec(
category="ocr",
variant='degraded',
filename="models_eynollah/model_eynollah_ocr_cnnrnn__degraded_20250805/",
help="slightly better at degraded Fraktur",
dist_url=dist_url("ocr"),
dists=['ocr'],
type=KerasModel,
),
EynollahModelSpec(
category="num_to_char",
variant='',
filename="characters_org.txt",
dist_url=dist_url("ocr"),
dists=['ocr'],
type=KerasModel,
),
EynollahModelSpec(
category="characters",
variant='',
filename="characters_org.txt",
dist_url=dist_url("ocr"),
dists=['ocr'],
type=list,
),
EynollahModelSpec(
category="ocr",
variant='tr',
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
dist_url=dist_url("trocr"),
help='much slower transformer-based',
dists=['trocr'],
type=KerasModel,
),
EynollahModelSpec(
category="trocr_processor",
variant='',
filename="models_eynollah/microsoft/trocr-base-printed",
dist_url=dist_url("trocr"),
dists=['trocr'],
type=KerasModel,
),
EynollahModelSpec(
category="trocr_processor",
variant='htr',
filename="models_eynollah/microsoft/trocr-base-handwritten",
dist_url=dist_url("trocr"),
dists=['trocr'],
type=TrOCRProcessor,
),
])

View file

@ -0,0 +1,189 @@
import json
import logging
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union
from keras.layers import StringLookup
from keras.models import Model as KerasModel
from keras.models import load_model
from tabulate import tabulate
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ..patch_encoder import PatchEncoder, Patches
from .specs import EynollahModelSpecSet
from .default_specs import DEFAULT_MODEL_SPECS
from .types import AnyModel, T
class EynollahModelZoo:
"""
Wrapper class that handles storage and loading of models for all eynollah runners.
"""
model_basedir: Path
specs: EynollahModelSpecSet
def __init__(
self,
basedir: str,
model_overrides: Optional[List[Tuple[str, str, str]]] = None,
) -> None:
self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo')
self.specs = deepcopy(DEFAULT_MODEL_SPECS)
if model_overrides:
self.override_models(*model_overrides)
self._loaded: Dict[str, AnyModel] = {}
def override_models(
self,
*model_overrides: Tuple[str, str, str],
):
"""
Override the default model versions
"""
for model_category, model_variant, model_filename in model_overrides:
spec = self.specs.get(model_category, model_variant)
self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
self.specs.get(model_category, model_variant).filename = model_filename
def model_path(
self,
model_category: str,
model_variant: str = '',
absolute: bool = True,
) -> Path:
"""
Translate model_{type,variant} tuple into an absolute (or relative) Path
"""
spec = self.specs.get(model_category, model_variant)
if spec.category in ('characters', 'num_to_char'):
return self.model_path('ocr') / spec.filename
if not Path(spec.filename).is_absolute() and absolute:
model_path = Path(self.model_basedir).joinpath(spec.filename)
else:
model_path = Path(spec.filename)
return model_path
def load_models(
self,
*all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
) -> Dict:
"""
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
"""
ret = {}
for load_args in all_load_args:
if isinstance(load_args, str):
ret[load_args] = self.load_model(load_args)
else:
ret[load_args[0]] = self.load_model(*load_args)
return ret
def load_model(
self,
model_category: str,
model_variant: str = '',
) -> AnyModel:
"""
Load any model
"""
model_path = self.model_path(model_category, model_variant)
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
# prefer SavedModel over HDF5 format if it exists
model_path = Path(model_path.stem)
if model_category == 'ocr':
model = self._load_ocr_model(variant=model_variant)
elif model_category == 'num_to_char':
model = self._load_num_to_char()
elif model_category == 'characters':
model = self._load_characters()
elif model_category == 'trocr_processor':
return TrOCRProcessor.from_pretrained(self.model_path(...))
else:
try:
model = load_model(model_path, compile=False)
except Exception as e:
self.logger.exception(e)
model = load_model(
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
)
self._loaded[model_category] = model
return model # type: ignore
def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T:
if model_category not in self._loaded:
raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"')
ret = self._loaded[model_category]
if model_type:
assert isinstance(ret, model_type)
return ret # type: ignore # FIXME: convince typing that we're returning generic type
def _load_ocr_model(self, variant: str) -> AnyModel:
"""
Load OCR model
"""
ocr_model_dir = self.model_path('ocr', variant)
if variant == 'tr':
return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
else:
ocr_model = load_model(ocr_model_dir, compile=False)
assert isinstance(ocr_model, KerasModel)
return KerasModel(
ocr_model.get_layer(name="image").input, # type: ignore
ocr_model.get_layer(name="dense2").output, # type: ignore
)
def _load_characters(self) -> List[str]:
"""
Load encoding for OCR
"""
with open(self.model_path('num_to_char'), "r") as config_file:
return json.load(config_file)
def _load_num_to_char(self) -> StringLookup:
"""
Load decoder for OCR
"""
characters = self._load_characters()
# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=characters, mask_token=None)
# Mapping integers back to original characters.
return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)
def __str__(self):
return tabulate(
[
[
spec.type.__name__,
spec.category,
spec.variant,
spec.help,
', '.join(spec.dists),
f'Yes, at {self.model_path(spec.category, spec.variant)}'
if self.model_path(spec.category, spec.variant).exists()
else f'No, download {spec.dist_url}',
# self.model_path(spec.category, spec.variant),
]
for spec in sorted(self.specs.specs, key=lambda x: x.category + '0' + x.variant)
],
headers=[
'Type',
'Category',
'Variant',
'Help',
'Used in',
'Installed',
],
tablefmt='github',
)
def shutdown(self):
"""
Ensure that a loaded models is not referenced by ``self._loaded`` anymore
"""
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
for needle in self._loaded:
if self._loaded[needle]:
del self._loaded[needle]

View file

@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple, Type
from .types import AnyModel
@dataclass
class EynollahModelSpec():
"""
Describing a single model abstractly.
"""
category: str
# Relative filename to the models_eynollah directory in the dists
filename: str
# basename of the ZIP files that should contain this model
dists: List[str]
# URL to the smallest model distribution containing this model (link to Zenodo)
dist_url: str
type: Type[AnyModel]
variant: str = ''
help: str = ''
class EynollahModelSpecSet():
"""
List of all used models for eynollah.
"""
specs: List[EynollahModelSpec]
def __init__(self, specs: List[EynollahModelSpec]) -> None:
self.specs = specs
self.categories: Set[str] = set([spec.category for spec in self.specs])
self.variants: Dict[str, Set[str]] = {
spec.category: set([x.variant for x in self.specs if x.category == spec.category])
for spec in self.specs
}
self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = {
(spec.category, spec.variant): spec
for spec in self.specs
}
def asdict(self) -> Dict[str, Dict[str, str]]:
return {
spec.category: {
spec.variant: spec.filename
}
for spec in self.specs
}
def get(self, category: str, variant: str) -> EynollahModelSpec:
if category not in self.categories:
raise ValueError(f"Unknown category '{category}', must be one of {self.categories}")
if variant not in self.variants[category]:
raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}")
return self._index_category_variant[(category, variant)]

View file

@ -0,0 +1,6 @@
from typing import List, TypeVar, Union
from keras.models import Model as KerasModel
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
T = TypeVar('T')

View file

@ -24,7 +24,7 @@ def resize_image(img_in, input_height, input_width):
class SbbBinarizer: class SbbBinarizer:
def __init__(self, model_dir, mode='single', logger=None): def __init__(self, model_dir: str, mode: str, logger=None):
if mode not in ('single', 'multi'): if mode not in ('single', 'multi'):
raise ValueError(f"'mode' must be either 'multi' or 'single', not {mode}") raise ValueError(f"'mode' must be either 'multi' or 'single', not {mode}")
self.log = logger if logger else logging.getLogger('SbbBinarizer') self.log = logger if logger else logging.getLogger('SbbBinarizer')