diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py index 9ae909f..bd2d807 100644 --- a/src/eynollah/cli.py +++ b/src/eynollah/cli.py @@ -1,9 +1,11 @@ from dataclasses import dataclass +import os import sys import click import logging from ocrd_utils import initLogging, getLevelName, getLogger -from eynollah.model_zoo import EynollahModelZoo + +from .model_zoo import EynollahModelZoo from .cli_models import models_cli @@ -11,15 +13,13 @@ from .cli_models import models_cli class EynollahCliCtx: model_zoo: EynollahModelZoo - @click.group() @click.option( "--model-basedir", "-m", help="directory of models", type=click.Path(exists=True, file_okay=False), - # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment", - required=True, + default=f'{os.getcwd()}/models_eynollah', ) @click.option( "--model-overrides", diff --git a/src/eynollah/cli_models.py b/src/eynollah/cli_models.py index 2f6eded..f3de596 100644 --- a/src/eynollah/cli_models.py +++ b/src/eynollah/cli_models.py @@ -1,16 +1,13 @@ -from dataclasses import dataclass from pathlib import Path -from typing import List, Set, Tuple +from typing import Set, Tuple import click from eynollah.model_zoo.default_specs import MODELS_VERSION -from .model_zoo import EynollahModelZoo @click.group() +@click.pass_context def models_cli( ctx, - model_basedir: str, - model_overrides: List[Tuple[str, str, str]], ): """ Organize models for the various runners in eynollah. 
@@ -26,6 +23,8 @@ def list_models( """ List all the models in the zoo """ + print(f"Model basedir: {ctx.obj.model_zoo.model_basedir}") + print(f"Model overrides: {ctx.obj.model_zoo.model_overrides}") print(ctx.obj.model_zoo) diff --git a/src/eynollah/model_zoo/default_specs.py b/src/eynollah/model_zoo/default_specs.py index fa67393..8daa270 100644 --- a/src/eynollah/model_zoo/default_specs.py +++ b/src/eynollah/model_zoo/default_specs.py @@ -1,5 +1,5 @@ from .specs import EynollahModelSpec, EynollahModelSpecSet -from .types import KerasModel, TrOCRProcessor, List +from .types import KerasModel # NOTE: This needs to change whenever models/versions change ZENODO = "https://zenodo.org/records/17295988/files" @@ -16,7 +16,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-enhancement_20210425", dists=['enhancement', 'layout', 'ci'], dist_url=dist_url("enhancement"), - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -25,7 +25,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-binarization-hybrid_20230504/model_bin_hybrid_trans_cnn_sbb_ens", dists=['layout', 'binarization', ], dist_url=dist_url("binarization"), - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -34,7 +34,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-binarization_20210309", dists=['binarization'], dist_url=dist_url("binarization"), - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -43,7 +43,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-binarization_20210425", dists=['binarization'], dist_url=dist_url("binarization"), - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -52,7 +52,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin1", dist_url=dist_url("binarization"), dists=['binarization'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -61,7 
+61,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin2", dist_url=dist_url("binarization"), dists=['binarization'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -70,7 +70,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin3", dist_url=dist_url("binarization"), dists=['binarization'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -79,7 +79,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin4", dist_url=dist_url("binarization"), dists=['binarization'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -88,7 +88,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-column-classifier_20210425", dist_url=dist_url("layout"), dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -97,7 +97,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/model_eynollah_page_extraction_20250915", dist_url=dist_url("layout"), dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -106,7 +106,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-main-regions-ensembled_20210425", dist_url=dist_url("layout"), dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -115,7 +115,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18", dist_url=dist_url("layout"), dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -125,7 +125,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ dist_url=dist_url("layout"), help="early layout", dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -135,7 +135,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ 
dist_url=dist_url("layout"), help="early layout, non-light, 2nd part", dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -150,7 +150,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ dist_url=dist_url("layout"), dists=['layout'], help="early layout, light, 1-or-2-column", - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -166,7 +166,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ dist_url=dist_url("layout"), help="full layout / no patches", dists=['layout'], - type=KerasModel, + type='Keras', ), # FIXME: Why is region_fl and region_fl_np the same model? @@ -186,7 +186,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ dist_url=dist_url("layout"), help="full layout / with patches", dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -200,7 +200,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/model_eynollah_reading_order_20250824", dist_url=dist_url("reading_order"), dists=['layout', 'reading_order'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -215,7 +215,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/modelens_textline_0_1__2_4_16092024", dist_url=dist_url("layout"), dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -225,7 +225,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/modelens_textline_0_1__2_4_16092024", dist_url=dist_url("layout"), dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -234,7 +234,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/eynollah-tables_20210319", dist_url=dist_url("layout"), dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -243,7 +243,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/modelens_table_0t4_201124", dist_url=dist_url("layout"), dists=['layout'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -252,7 +252,7 @@ 
DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930", dist_url=dist_url("ocr"), dists=['layout', 'ocr'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -262,7 +262,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ help="slightly better at degraded Fraktur", dist_url=dist_url("ocr"), dists=['ocr'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -271,7 +271,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="characters_org.txt", dist_url=dist_url("ocr"), dists=['ocr'], - type=KerasModel, + type='decoder', ), EynollahModelSpec( @@ -280,7 +280,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="characters_org.txt", dist_url=dist_url("ocr"), dists=['ocr'], - type=list, + type='List[str]', ), EynollahModelSpec( @@ -290,7 +290,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ dist_url=dist_url("trocr"), help='much slower transformer-based', dists=['trocr'], - type=KerasModel, + type='Keras', ), EynollahModelSpec( @@ -299,7 +299,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/microsoft/trocr-base-printed", dist_url=dist_url("trocr"), dists=['trocr'], - type=KerasModel, + type='TrOCRProcessor', ), EynollahModelSpec( @@ -308,7 +308,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ filename="models_eynollah/microsoft/trocr-base-handwritten", dist_url=dist_url("trocr"), dists=['trocr'], - type=TrOCRProcessor, + type='TrOCRProcessor', ), ]) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 40e979f..512bf1a 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -32,11 +32,18 @@ class EynollahModelZoo: ) -> None: self.model_basedir = Path(basedir) self.logger = logging.getLogger('eynollah.model_zoo') + if not self.model_basedir.exists(): + self.logger.warning(f"Model basedir does not exist: {basedir}. 
Set eynollah --model-basedir to the correct directory.") self.specs = deepcopy(DEFAULT_MODEL_SPECS) + self._overrides = [] if model_overrides: self.override_models(*model_overrides) self._loaded: Dict[str, AnyModel] = {} + @property + def model_overrides(self): + return self._overrides + def override_models( self, *model_overrides: Tuple[str, str, str], @@ -48,6 +55,7 @@ class EynollahModelZoo: spec = self.specs.get(model_category, model_variant) self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename) self.specs.get(model_category, model_variant).filename = model_filename + self._overrides += model_overrides def model_path( self, @@ -164,7 +172,7 @@ class EynollahModelZoo: return tabulate( [ [ - spec.type.__name__, + spec.type, spec.category, spec.variant, spec.help, diff --git a/src/eynollah/model_zoo/specs.py b/src/eynollah/model_zoo/specs.py index 322afa4..415e55d 100644 --- a/src/eynollah/model_zoo/specs.py +++ b/src/eynollah/model_zoo/specs.py @@ -15,7 +15,7 @@ class EynollahModelSpec(): dists: List[str] # URL to the smallest model distribution containing this model (link to Zenodo) dist_url: str - type: Type[AnyModel] + type: str variant: str = '' help: str = '' diff --git a/tests/test_model_zoo.py b/tests/test_model_zoo.py index 81e84f6..2042b28 100644 --- a/tests/test_model_zoo.py +++ b/tests/test_model_zoo.py @@ -1,19 +1,16 @@ -from pathlib import Path +from eynollah.model_zoo import EynollahModelZoo -from eynollah.model_zoo import EynollahModelZoo, TrOCRProcessor, VisionEncoderDecoderModel - -testdir = Path(__file__).parent.resolve() -MODELS_DIR = testdir.parent - -def test_trocr1(): - model_zoo = EynollahModelZoo(str(MODELS_DIR)) - model_zoo.load_model('trocr_processor') - proc = model_zoo.get('trocr_processor', TrOCRProcessor) - assert isinstance(proc, TrOCRProcessor) - - model_zoo.load_model('ocr', 'tr') - model = model_zoo.get('ocr') - assert isinstance(model, VisionEncoderDecoderModel) - print(proc) - -test_trocr1() 
+def test_trocr1(
+    model_dir,
+):
+    """Check the zoo loads the TrOCR processor and model; skip if transformers is missing."""
+    import pytest
+    pytest.importorskip("transformers")
+    from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+    model_zoo = EynollahModelZoo(model_dir)
+    model_zoo.load_model('trocr_processor')
+    proc = model_zoo.get('trocr_processor', TrOCRProcessor)
+    assert isinstance(proc, TrOCRProcessor)
+    model_zoo.load_model('ocr', 'tr')
+    model = model_zoo.get('ocr', VisionEncoderDecoderModel)
+    assert isinstance(model, VisionEncoderDecoderModel)