model_zoo: make the spec `type` a str to reduce import overhead

This commit is contained in:
kba 2025-10-29 19:08:32 +01:00
parent a913bdf7dc
commit 5e22e9db64
6 changed files with 63 additions and 59 deletions

View file

@ -1,9 +1,11 @@
from dataclasses import dataclass
import os
import sys
import click
import logging
from ocrd_utils import initLogging, getLevelName, getLogger
from eynollah.model_zoo import EynollahModelZoo
from .model_zoo import EynollahModelZoo
from .cli_models import models_cli
@ -11,15 +13,13 @@ from .cli_models import models_cli
class EynollahCliCtx:
model_zoo: EynollahModelZoo
@click.group()
@click.option(
"--model-basedir",
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
# default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
required=True,
default=f'{os.getcwd()}/models_eynollah',
)
@click.option(
"--model-overrides",

View file

@ -1,16 +1,13 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List, Set, Tuple
from typing import Set, Tuple
import click
from eynollah.model_zoo.default_specs import MODELS_VERSION
from .model_zoo import EynollahModelZoo
@click.group()
@click.pass_context
def models_cli(
ctx,
model_basedir: str,
model_overrides: List[Tuple[str, str, str]],
):
"""
Organize models for the various runners in eynollah.
@ -26,6 +23,8 @@ def list_models(
"""
List all the models in the zoo
"""
print(f"Model basedir: {ctx.obj.model_zoo.model_basedir}")
print(f"Model overrides: {ctx.obj.model_zoo.model_overrides}")
print(ctx.obj.model_zoo)

View file

@ -1,5 +1,5 @@
from .specs import EynollahModelSpec, EynollahModelSpecSet
from .types import KerasModel, TrOCRProcessor, List
from .types import KerasModel
# NOTE: This needs to change whenever models/versions change
ZENODO = "https://zenodo.org/records/17295988/files"
@ -16,7 +16,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-enhancement_20210425",
dists=['enhancement', 'layout', 'ci'],
dist_url=dist_url("enhancement"),
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -25,7 +25,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-hybrid_20230504/model_bin_hybrid_trans_cnn_sbb_ens",
dists=['layout', 'binarization', ],
dist_url=dist_url("binarization"),
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -34,7 +34,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization_20210309",
dists=['binarization'],
dist_url=dist_url("binarization"),
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -43,7 +43,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization_20210425",
dists=['binarization'],
dist_url=dist_url("binarization"),
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -52,7 +52,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin1",
dist_url=dist_url("binarization"),
dists=['binarization'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -61,7 +61,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin2",
dist_url=dist_url("binarization"),
dists=['binarization'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -70,7 +70,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin3",
dist_url=dist_url("binarization"),
dists=['binarization'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -79,7 +79,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin4",
dist_url=dist_url("binarization"),
dists=['binarization'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -88,7 +88,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-column-classifier_20210425",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -97,7 +97,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/model_eynollah_page_extraction_20250915",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -106,7 +106,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -115,7 +115,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -125,7 +125,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"),
help="early layout",
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -135,7 +135,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"),
help="early layout, non-light, 2nd part",
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -150,7 +150,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"),
dists=['layout'],
help="early layout, light, 1-or-2-column",
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -166,7 +166,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"),
help="full layout / no patches",
dists=['layout'],
type=KerasModel,
type='Keras',
),
# FIXME: Why are region_fl and region_fl_np the same model?
@ -186,7 +186,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("layout"),
help="full layout / with patches",
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -200,7 +200,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/model_eynollah_reading_order_20250824",
dist_url=dist_url("reading_order"),
dists=['layout', 'reading_order'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -215,7 +215,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -225,7 +225,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -234,7 +234,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/eynollah-tables_20210319",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -243,7 +243,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/modelens_table_0t4_201124",
dist_url=dist_url("layout"),
dists=['layout'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -252,7 +252,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
dist_url=dist_url("ocr"),
dists=['layout', 'ocr'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -262,7 +262,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
help="slightly better at degraded Fraktur",
dist_url=dist_url("ocr"),
dists=['ocr'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -271,7 +271,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="characters_org.txt",
dist_url=dist_url("ocr"),
dists=['ocr'],
type=KerasModel,
type='decoder',
),
EynollahModelSpec(
@ -280,7 +280,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="characters_org.txt",
dist_url=dist_url("ocr"),
dists=['ocr'],
type=list,
type='List[str]',
),
EynollahModelSpec(
@ -290,7 +290,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
dist_url=dist_url("trocr"),
help='much slower transformer-based',
dists=['trocr'],
type=KerasModel,
type='Keras',
),
EynollahModelSpec(
@ -299,7 +299,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/microsoft/trocr-base-printed",
dist_url=dist_url("trocr"),
dists=['trocr'],
type=KerasModel,
type='TrOCRProcessor',
),
EynollahModelSpec(
@ -308,7 +308,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
filename="models_eynollah/microsoft/trocr-base-handwritten",
dist_url=dist_url("trocr"),
dists=['trocr'],
type=TrOCRProcessor,
type='TrOCRProcessor',
),
])

View file

@ -32,11 +32,18 @@ class EynollahModelZoo:
) -> None:
self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo')
if not self.model_basedir.exists():
self.logger.warning(f"Model basedir does not exist: {basedir}. Set eynollah --model-basedir to the correct directory.")
self.specs = deepcopy(DEFAULT_MODEL_SPECS)
self._overrides = []
if model_overrides:
self.override_models(*model_overrides)
self._loaded: Dict[str, AnyModel] = {}
# Exposes the override tuples recorded on this zoo. The list starts empty in
# __init__ and override_models() appends every batch of overrides it is given.
# The internal list itself is returned, not a copy.
@property
def model_overrides(self):
return self._overrides
def override_models(
self,
*model_overrides: Tuple[str, str, str],
@ -48,6 +55,7 @@ class EynollahModelZoo:
spec = self.specs.get(model_category, model_variant)
self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
self.specs.get(model_category, model_variant).filename = model_filename
self._overrides += model_overrides
def model_path(
self,
@ -164,7 +172,7 @@ class EynollahModelZoo:
return tabulate(
[
[
spec.type.__name__,
spec.type,
spec.category,
spec.variant,
spec.help,

View file

@ -15,7 +15,7 @@ class EynollahModelSpec():
dists: List[str]
# URL to the smallest model distribution containing this model (link to Zenodo)
dist_url: str
type: Type[AnyModel]
type: str
variant: str = ''
help: str = ''

View file

@ -1,19 +1,16 @@
from pathlib import Path
from eynollah.model_zoo import EynollahModelZoo
from eynollah.model_zoo import EynollahModelZoo, TrOCRProcessor, VisionEncoderDecoderModel
testdir = Path(__file__).parent.resolve()
MODELS_DIR = testdir.parent
def test_trocr1():
model_zoo = EynollahModelZoo(str(MODELS_DIR))
model_zoo.load_model('trocr_processor')
proc = model_zoo.get('trocr_processor', TrOCRProcessor)
assert isinstance(proc, TrOCRProcessor)
model_zoo.load_model('ocr', 'tr')
model = model_zoo.get('ocr')
assert isinstance(model, VisionEncoderDecoderModel)
print(proc)
test_trocr1()
def test_trocr1(model_dir):
    """
    Load the TrOCR processor and the 'tr' OCR model through the model zoo and
    check that each one is an instance of the expected transformers class.

    `model_dir` is passed straight to `EynollahModelZoo` as its base directory.
    """
    zoo = EynollahModelZoo(model_dir)
    try:
        # transformers is imported inside the try so that a missing package is
        # caught by the ImportError handler below.
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        zoo.load_model('trocr_processor')
        processor = zoo.get('trocr_processor', TrOCRProcessor)
        assert isinstance(processor, TrOCRProcessor)

        zoo.load_model('ocr', 'tr')
        ocr_model = zoo.get('ocr', VisionEncoderDecoderModel)
        assert isinstance(ocr_model, VisionEncoderDecoderModel)
    except ImportError:
        # NOTE(review): any ImportError raised above (the import itself or one
        # surfacing from load_model/get) ends the test as "passed", not
        # "skipped" -- consider an explicit skip so missing dependencies show
        # up in the test report.
        pass