model_zoo: make type str to reduce importing overhead

This commit is contained in:
kba 2025-10-29 19:08:32 +01:00
parent a913bdf7dc
commit 5e22e9db64
6 changed files with 63 additions and 59 deletions

View file

@ -1,9 +1,11 @@
from dataclasses import dataclass from dataclasses import dataclass
import os
import sys import sys
import click import click
import logging import logging
from ocrd_utils import initLogging, getLevelName, getLogger from ocrd_utils import initLogging, getLevelName, getLogger
from eynollah.model_zoo import EynollahModelZoo
from .model_zoo import EynollahModelZoo
from .cli_models import models_cli from .cli_models import models_cli
@ -11,15 +13,13 @@ from .cli_models import models_cli
class EynollahCliCtx: class EynollahCliCtx:
model_zoo: EynollahModelZoo model_zoo: EynollahModelZoo
@click.group() @click.group()
@click.option( @click.option(
"--model-basedir", "--model-basedir",
"-m", "-m",
help="directory of models", help="directory of models",
type=click.Path(exists=True, file_okay=False), type=click.Path(exists=True, file_okay=False),
# default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment", default=f'{os.getcwd()}/models_eynollah',
required=True,
) )
@click.option( @click.option(
"--model-overrides", "--model-overrides",

View file

@ -1,16 +1,13 @@
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import List, Set, Tuple from typing import Set, Tuple
import click import click
from eynollah.model_zoo.default_specs import MODELS_VERSION from eynollah.model_zoo.default_specs import MODELS_VERSION
from .model_zoo import EynollahModelZoo
@click.group() @click.group()
@click.pass_context
def models_cli( def models_cli(
ctx, ctx,
model_basedir: str,
model_overrides: List[Tuple[str, str, str]],
): ):
""" """
Organize models for the various runners in eynollah. Organize models for the various runners in eynollah.
@ -26,6 +23,8 @@ def list_models(
""" """
List all the models in the zoo List all the models in the zoo
""" """
print(f"Model basedir: {ctx.obj.model_zoo.model_basedir}")
print(f"Model overrides: {ctx.obj.model_zoo.model_overrides}")
print(ctx.obj.model_zoo) print(ctx.obj.model_zoo)

View file

@ -1,5 +1,5 @@
from .specs import EynollahModelSpec, EynollahModelSpecSet from .specs import EynollahModelSpec, EynollahModelSpecSet
from .types import KerasModel, TrOCRProcessor, List from .types import KerasModel
# NOTE: This needs to change whenever models/versions change # NOTE: This needs to change whenever models/versions change
ZENODO = "https://zenodo.org/records/17295988/files" ZENODO = "https://zenodo.org/records/17295988/files"
@ -16,7 +16,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-enhancement_20210425", filename="models_eynollah/eynollah-enhancement_20210425",
dists=['enhancement', 'layout', 'ci'], dists=['enhancement', 'layout', 'ci'],
dist_url=dist_url("enhancement"), dist_url=dist_url("enhancement"),
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -25,7 +25,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-hybrid_20230504/model_bin_hybrid_trans_cnn_sbb_ens", filename="models_eynollah/eynollah-binarization-hybrid_20230504/model_bin_hybrid_trans_cnn_sbb_ens",
dists=['layout', 'binarization', ], dists=['layout', 'binarization', ],
dist_url=dist_url("binarization"), dist_url=dist_url("binarization"),
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -34,7 +34,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization_20210309", filename="models_eynollah/eynollah-binarization_20210309",
dists=['binarization'], dists=['binarization'],
dist_url=dist_url("binarization"), dist_url=dist_url("binarization"),
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -43,7 +43,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization_20210425", filename="models_eynollah/eynollah-binarization_20210425",
dists=['binarization'], dists=['binarization'],
dist_url=dist_url("binarization"), dist_url=dist_url("binarization"),
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -52,7 +52,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin1", filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin1",
dist_url=dist_url("binarization"), dist_url=dist_url("binarization"),
dists=['binarization'], dists=['binarization'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -61,7 +61,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin2", filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin2",
dist_url=dist_url("binarization"), dist_url=dist_url("binarization"),
dists=['binarization'], dists=['binarization'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -70,7 +70,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin3", filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin3",
dist_url=dist_url("binarization"), dist_url=dist_url("binarization"),
dists=['binarization'], dists=['binarization'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -79,7 +79,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin4", filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin4",
dist_url=dist_url("binarization"), dist_url=dist_url("binarization"),
dists=['binarization'], dists=['binarization'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -88,7 +88,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-column-classifier_20210425", filename="models_eynollah/eynollah-column-classifier_20210425",
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -97,7 +97,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/model_eynollah_page_extraction_20250915", filename="models_eynollah/model_eynollah_page_extraction_20250915",
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -106,7 +106,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-main-regions-ensembled_20210425", filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -115,7 +115,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18", filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -125,7 +125,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
help="early layout", help="early layout",
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -135,7 +135,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
help="early layout, non-light, 2nd part", help="early layout, non-light, 2nd part",
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -150,7 +150,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
dists=['layout'], dists=['layout'],
help="early layout, light, 1-or-2-column", help="early layout, light, 1-or-2-column",
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -166,7 +166,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
help="full layout / no patches", help="full layout / no patches",
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
# FIXME: Why is region_fl and region_fl_np the same model? # FIXME: Why is region_fl and region_fl_np the same model?
@ -186,7 +186,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
help="full layout / with patches", help="full layout / with patches",
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -200,7 +200,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/model_eynollah_reading_order_20250824", filename="models_eynollah/model_eynollah_reading_order_20250824",
dist_url=dist_url("reading_order"), dist_url=dist_url("reading_order"),
dists=['layout', 'reading_order'], dists=['layout', 'reading_order'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -215,7 +215,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/modelens_textline_0_1__2_4_16092024", filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -225,7 +225,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/modelens_textline_0_1__2_4_16092024", filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -234,7 +234,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-tables_20210319", filename="models_eynollah/eynollah-tables_20210319",
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -243,7 +243,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/modelens_table_0t4_201124", filename="models_eynollah/modelens_table_0t4_201124",
dist_url=dist_url("layout"), dist_url=dist_url("layout"),
dists=['layout'], dists=['layout'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -252,7 +252,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930", filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
dist_url=dist_url("ocr"), dist_url=dist_url("ocr"),
dists=['layout', 'ocr'], dists=['layout', 'ocr'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -262,7 +262,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
help="slightly better at degraded Fraktur", help="slightly better at degraded Fraktur",
dist_url=dist_url("ocr"), dist_url=dist_url("ocr"),
dists=['ocr'], dists=['ocr'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -271,7 +271,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="characters_org.txt", filename="characters_org.txt",
dist_url=dist_url("ocr"), dist_url=dist_url("ocr"),
dists=['ocr'], dists=['ocr'],
type=KerasModel, type='decoder',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -280,7 +280,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="characters_org.txt", filename="characters_org.txt",
dist_url=dist_url("ocr"), dist_url=dist_url("ocr"),
dists=['ocr'], dists=['ocr'],
type=list, type='List[str]',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -290,7 +290,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("trocr"), dist_url=dist_url("trocr"),
help='much slower transformer-based', help='much slower transformer-based',
dists=['trocr'], dists=['trocr'],
type=KerasModel, type='Keras',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -299,7 +299,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/microsoft/trocr-base-printed", filename="models_eynollah/microsoft/trocr-base-printed",
dist_url=dist_url("trocr"), dist_url=dist_url("trocr"),
dists=['trocr'], dists=['trocr'],
type=KerasModel, type='TrOCRProcessor',
), ),
EynollahModelSpec( EynollahModelSpec(
@ -308,7 +308,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/microsoft/trocr-base-handwritten", filename="models_eynollah/microsoft/trocr-base-handwritten",
dist_url=dist_url("trocr"), dist_url=dist_url("trocr"),
dists=['trocr'], dists=['trocr'],
type=TrOCRProcessor, type='TrOCRProcessor',
), ),
]) ])

View file

@ -32,11 +32,18 @@ class EynollahModelZoo:
) -> None: ) -> None:
self.model_basedir = Path(basedir) self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo') self.logger = logging.getLogger('eynollah.model_zoo')
if not self.model_basedir.exists():
self.logger.warning(f"Model basedir does not exist: {basedir}. Set eynollah --model-basedir to the correct directory.")
self.specs = deepcopy(DEFAULT_MODEL_SPECS) self.specs = deepcopy(DEFAULT_MODEL_SPECS)
self._overrides = []
if model_overrides: if model_overrides:
self.override_models(*model_overrides) self.override_models(*model_overrides)
self._loaded: Dict[str, AnyModel] = {} self._loaded: Dict[str, AnyModel] = {}
@property
def model_overrides(self):
    """
    Model overrides recorded on this zoo.

    Returns the instance's ``_overrides`` list itself (not a copy).
    """
    recorded = self._overrides
    return recorded
def override_models( def override_models(
self, self,
*model_overrides: Tuple[str, str, str], *model_overrides: Tuple[str, str, str],
@ -48,6 +55,7 @@ class EynollahModelZoo:
spec = self.specs.get(model_category, model_variant) spec = self.specs.get(model_category, model_variant)
self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename) self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
self.specs.get(model_category, model_variant).filename = model_filename self.specs.get(model_category, model_variant).filename = model_filename
self._overrides += model_overrides
def model_path( def model_path(
self, self,
@ -164,7 +172,7 @@ class EynollahModelZoo:
return tabulate( return tabulate(
[ [
[ [
spec.type.__name__, spec.type,
spec.category, spec.category,
spec.variant, spec.variant,
spec.help, spec.help,

View file

@ -15,7 +15,7 @@ class EynollahModelSpec():
dists: List[str] dists: List[str]
# URL to the smallest model distribution containing this model (link to Zenodo) # URL to the smallest model distribution containing this model (link to Zenodo)
dist_url: str dist_url: str
type: Type[AnyModel] type: str
variant: str = '' variant: str = ''
help: str = '' help: str = ''

View file

@ -1,19 +1,16 @@
from pathlib import Path from eynollah.model_zoo import EynollahModelZoo
from eynollah.model_zoo import EynollahModelZoo, TrOCRProcessor, VisionEncoderDecoderModel def test_trocr1(
model_dir,
testdir = Path(__file__).parent.resolve() ):
MODELS_DIR = testdir.parent model_zoo = EynollahModelZoo(model_dir)
try:
def test_trocr1(): from transformers import TrOCRProcessor, VisionEncoderDecoderModel
model_zoo = EynollahModelZoo(str(MODELS_DIR)) model_zoo.load_model('trocr_processor')
model_zoo.load_model('trocr_processor') proc = model_zoo.get('trocr_processor', TrOCRProcessor)
proc = model_zoo.get('trocr_processor', TrOCRProcessor) assert isinstance(proc, TrOCRProcessor)
assert isinstance(proc, TrOCRProcessor) model_zoo.load_model('ocr', 'tr')
model = model_zoo.get('ocr', VisionEncoderDecoderModel)
model_zoo.load_model('ocr', 'tr') assert isinstance(model, VisionEncoderDecoderModel)
model = model_zoo.get('ocr') except ImportError:
assert isinstance(model, VisionEncoderDecoderModel) pass
print(proc)
test_trocr1()