mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-27 07:44:12 +01:00
rewrite model spec data structure
This commit is contained in:
parent
146658f026
commit
d94285b3ea
1 changed files with 306 additions and 147 deletions
|
|
@ -1,154 +1,329 @@
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Literal, Optional, Tuple, List, Type, TypeVar, Union
|
from typing import Dict, Optional, Set, Tuple, List, Type, TypeVar, Union
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
from keras.layers import StringLookup
|
from keras.layers import StringLookup
|
||||||
from keras.models import Model, load_model
|
from keras.models import Model as KerasModel, load_model
|
||||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||||
|
|
||||||
from eynollah.patch_encoder import PatchEncoder, Patches
|
from eynollah.patch_encoder import PatchEncoder, Patches
|
||||||
|
|
||||||
SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model, List]
|
AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
|
|
||||||
# Dict mapping model_category to dict mapping variant (default is '') to Path
|
# NOTE: This needs to change whenever models change
|
||||||
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
|
ZENODO = "https://zenodo.org/records/17295988/files"
|
||||||
|
MODELS_VERSION = "v0_7_0"
|
||||||
|
|
||||||
"enhancement": {
|
def dist_url(dist_name: str) -> str:
|
||||||
'': "eynollah-enhancement_20210425"
|
return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip'
|
||||||
},
|
|
||||||
|
|
||||||
"binarization": {
|
@dataclass
|
||||||
'': "eynollah-binarization_20210425"
|
class EynollahModelSpec():
|
||||||
},
|
"""
|
||||||
|
Describing a single model abstractly.
|
||||||
|
"""
|
||||||
|
category: str
|
||||||
|
# Relative filename to the models_eynollah directory in the dists
|
||||||
|
filename: str
|
||||||
|
# The smallest model distribution containing this model (link to Zenodo)
|
||||||
|
dist: str
|
||||||
|
type: Type[AnyModel]
|
||||||
|
variant: str = ''
|
||||||
|
help: str = ''
|
||||||
|
|
||||||
"binarization_multi_1": {
|
class EynollahModelSpecSet():
|
||||||
'': "saved_model_2020_01_16/model_bin1",
|
"""
|
||||||
},
|
List of all used models for eynollah.
|
||||||
"binarization_multi_2": {
|
"""
|
||||||
'': "saved_model_2020_01_16/model_bin2",
|
specs: List[EynollahModelSpec]
|
||||||
},
|
|
||||||
"binarization_multi_3": {
|
|
||||||
'': "saved_model_2020_01_16/model_bin3",
|
|
||||||
},
|
|
||||||
"binarization_multi_4": {
|
|
||||||
'': "saved_model_2020_01_16/model_bin4",
|
|
||||||
},
|
|
||||||
|
|
||||||
"col_classifier": {
|
def __init__(self, specs: List[EynollahModelSpec]) -> None:
|
||||||
'': "eynollah-column-classifier_20210425",
|
self.specs = specs
|
||||||
},
|
self.categories: Set[str] = set([spec.category for spec in self.specs])
|
||||||
|
self.variants: Dict[str, Set[str]] = {
|
||||||
|
spec.category: set([x.variant for x in self.specs if x.category == spec.category])
|
||||||
|
for spec in self.specs
|
||||||
|
}
|
||||||
|
self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = {
|
||||||
|
(spec.category, spec.variant): spec
|
||||||
|
for spec in self.specs
|
||||||
|
}
|
||||||
|
|
||||||
"page": {
|
def asdict(self) -> Dict[str, Dict[str, str]]:
|
||||||
'': "model_eynollah_page_extraction_20250915",
|
return {
|
||||||
},
|
spec.category: {
|
||||||
|
spec.variant: spec.filename
|
||||||
|
}
|
||||||
|
for spec in self.specs
|
||||||
|
}
|
||||||
|
|
||||||
# TODO: What is this commented out model?
|
def get(self, category: str, variant: str) -> EynollahModelSpec:
|
||||||
#?: "eynollah-main-regions-aug-scaling_20210425",
|
if category not in self.categories:
|
||||||
|
raise ValueError(f"Unknown category '{category}', must be one of {self.categories}")
|
||||||
|
if variant not in self.variants[category]:
|
||||||
|
raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}")
|
||||||
|
return self._index_category_variant[(category, variant)]
|
||||||
|
|
||||||
# early layout
|
DEFAULT_MODEL_SPECS = EynollahModelSpecSet([# {{{
|
||||||
"region": {
|
|
||||||
'': "eynollah-main-regions-ensembled_20210425",
|
|
||||||
'extract_only_images': "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
|
|
||||||
'light': "eynollah-main-regions_20220314",
|
|
||||||
},
|
|
||||||
|
|
||||||
# early layout, non-light, 2nd part
|
EynollahModelSpec(
|
||||||
"region_p2": {
|
category="enhancement",
|
||||||
'': "eynollah-main-regions-aug-rotation_20210425",
|
variant='',
|
||||||
},
|
filename="models_eynollah/eynollah-enhancement_20210425",
|
||||||
|
dist=dist_url("enhancement"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
# early layout, light, 1-or-2-column
|
EynollahModelSpec(
|
||||||
"region_1_2": {
|
category="binarization",
|
||||||
#'': "modelens_12sp_elay_0_3_4__3_6_n"
|
variant='',
|
||||||
#'': "modelens_earlylayout_12spaltige_2_3_5_6_7_8"
|
filename="models_eynollah/eynollah-binarization_20210425",
|
||||||
#'': "modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
|
dist=dist_url("binarization"),
|
||||||
#'': "modelens_1_2_4_5_early_lay_1_2_spaltige"
|
type=KerasModel,
|
||||||
#'': "model_3_eraly_layout_no_patches_1_2_spaltige"
|
),
|
||||||
'': "modelens_e_l_all_sp_0_1_2_3_4_171024"
|
|
||||||
},
|
|
||||||
|
|
||||||
# full layout / no patches
|
EynollahModelSpec(
|
||||||
"region_fl_np": {
|
category="binarization_multi_1",
|
||||||
#'': "modelens_full_lay_1_3_031124"
|
variant='',
|
||||||
#'': "modelens_full_lay_13__3_19_241024"
|
filename="models_eynollah/saved_model_2020_01_16/model_bin1",
|
||||||
#'': "model_full_lay_13_241024"
|
dist=dist_url("binarization"),
|
||||||
#'': "modelens_full_lay_13_17_231024"
|
type=KerasModel,
|
||||||
#'': "modelens_full_lay_1_2_221024"
|
),
|
||||||
#'': "eynollah-full-regions-1column_20210425"
|
|
||||||
'': "modelens_full_lay_1__4_3_091124"
|
|
||||||
},
|
|
||||||
|
|
||||||
# full layout / with patches
|
EynollahModelSpec(
|
||||||
"region_fl": {
|
category="binarization_multi_2",
|
||||||
#'': "eynollah-full-regions-3+column_20210425"
|
variant='',
|
||||||
#'': #"model_2_full_layout_new_trans"
|
filename="models_eynollah/saved_model_2020_01_16/model_bin2",
|
||||||
#'': "modelens_full_lay_1_3_031124"
|
dist=dist_url("binarization"),
|
||||||
#'': "modelens_full_lay_13__3_19_241024"
|
type=KerasModel,
|
||||||
#'': "model_full_lay_13_241024"
|
),
|
||||||
#'': "modelens_full_lay_13_17_231024"
|
|
||||||
#'': "modelens_full_lay_1_2_221024"
|
|
||||||
#'': "modelens_full_layout_24_till_28"
|
|
||||||
#'': "model_2_full_layout_new_trans"
|
|
||||||
'': "modelens_full_lay_1__4_3_091124",
|
|
||||||
},
|
|
||||||
|
|
||||||
"reading_order": {
|
EynollahModelSpec(
|
||||||
#'': "model_mb_ro_aug_ens_11"
|
category="binarization_multi_3",
|
||||||
#'': "model_step_3200000_mb_ro"
|
variant='',
|
||||||
#'': "model_ens_reading_order_machine_based"
|
filename="models_eynollah/saved_model_2020_01_16/model_bin3",
|
||||||
#'': "model_mb_ro_aug_ens_8"
|
dist=dist_url("binarization"),
|
||||||
#'': "model_ens_reading_order_machine_based"
|
type=KerasModel,
|
||||||
'': "model_eynollah_reading_order_20250824"
|
),
|
||||||
},
|
|
||||||
|
|
||||||
"textline": {
|
EynollahModelSpec(
|
||||||
#'light': "eynollah-textline_light_20210425"
|
category="binarization_multi_4",
|
||||||
'light': "modelens_textline_0_1__2_4_16092024",
|
variant='',
|
||||||
#'': "modelens_textline_1_4_16092024"
|
filename="models_eynollah/saved_model_2020_01_16/model_bin4",
|
||||||
#'': "model_textline_ens_3_4_5_6_artificial"
|
dist=dist_url("binarization"),
|
||||||
#'': "modelens_textline_1_3_4_20240915"
|
type=KerasModel,
|
||||||
#'': "model_textline_ens_3_4_5_6_artificial"
|
),
|
||||||
#'': "modelens_textline_9_12_13_14_15"
|
|
||||||
#'': "eynollah-textline_20210425"
|
|
||||||
'': "modelens_textline_0_1__2_4_16092024"
|
|
||||||
},
|
|
||||||
|
|
||||||
"table": {
|
EynollahModelSpec(
|
||||||
'light': "modelens_table_0t4_201124",
|
category="col_classifier",
|
||||||
'': "eynollah-tables_20210319",
|
variant='',
|
||||||
},
|
filename="models_eynollah/eynollah-column-classifier_20210425",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
"ocr": {
|
EynollahModelSpec(
|
||||||
'tr': "model_eynollah_ocr_trocr_20250919",
|
category="page",
|
||||||
'': "model_eynollah_ocr_cnnrnn_20250930",
|
variant='',
|
||||||
},
|
filename="models_eynollah/model_eynollah_page_extraction_20250915",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
'trocr_processor': {
|
EynollahModelSpec(
|
||||||
'': 'microsoft/trocr-base-printed',
|
category="region",
|
||||||
'htr': "microsoft/trocr-base-handwritten",
|
variant='',
|
||||||
},
|
filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
'num_to_char': {
|
EynollahModelSpec(
|
||||||
'': 'characters_org.txt'
|
category="region",
|
||||||
},
|
variant='extract_only_images',
|
||||||
|
filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
'characters': {
|
EynollahModelSpec(
|
||||||
'': 'characters_org.txt'
|
category="region",
|
||||||
},
|
variant='light',
|
||||||
|
filename="models_eynollah/eynollah-main-regions_20220314",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
help="early layout",
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
}
|
EynollahModelSpec(
|
||||||
|
category="region_p2",
|
||||||
|
variant='',
|
||||||
|
filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
help="early layout, non-light, 2nd part",
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="region_1_2",
|
||||||
|
variant='',
|
||||||
|
#filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n",
|
||||||
|
#filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8",
|
||||||
|
#filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18",
|
||||||
|
#filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige",
|
||||||
|
#filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige",
|
||||||
|
filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
help="early layout, light, 1-or-2-column",
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="region_fl_np",
|
||||||
|
variant='',
|
||||||
|
#filename="models_eynollah/modelens_full_lay_1_3_031124",
|
||||||
|
#filename="models_eynollah/modelens_full_lay_13__3_19_241024",
|
||||||
|
#filename="models_eynollah/model_full_lay_13_241024",
|
||||||
|
#filename="models_eynollah/modelens_full_lay_13_17_231024",
|
||||||
|
#filename="models_eynollah/modelens_full_lay_1_2_221024",
|
||||||
|
#filename="models_eynollah/eynollah-full-regions-1column_20210425",
|
||||||
|
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
help="full layout / no patches",
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
# FIXME: Why is region_fl and region_fl_np the same model?
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="region_fl",
|
||||||
|
variant='',
|
||||||
|
# filename="models_eynollah/eynollah-full-regions-3+column_20210425",
|
||||||
|
# filename="models_eynollah/model_2_full_layout_new_trans",
|
||||||
|
# filename="models_eynollah/modelens_full_lay_1_3_031124",
|
||||||
|
# filename="models_eynollah/modelens_full_lay_13__3_19_241024",
|
||||||
|
# filename="models_eynollah/model_full_lay_13_241024",
|
||||||
|
# filename="models_eynollah/modelens_full_lay_13_17_231024",
|
||||||
|
# filename="models_eynollah/modelens_full_lay_1_2_221024",
|
||||||
|
# filename="models_eynollah/modelens_full_layout_24_till_28",
|
||||||
|
# filename="models_eynollah/model_2_full_layout_new_trans",
|
||||||
|
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
help="full layout / with patches",
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="reading_order",
|
||||||
|
variant='',
|
||||||
|
#filename="models_eynollah/model_mb_ro_aug_ens_11",
|
||||||
|
#filename="models_eynollah/model_step_3200000_mb_ro",
|
||||||
|
#filename="models_eynollah/model_ens_reading_order_machine_based",
|
||||||
|
#filename="models_eynollah/model_mb_ro_aug_ens_8",
|
||||||
|
#filename="models_eynollah/model_ens_reading_order_machine_based",
|
||||||
|
filename="models_eynollah/model_eynollah_reading_order_20250824",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="textline",
|
||||||
|
variant='',
|
||||||
|
#filename="models_eynollah/modelens_textline_1_4_16092024",
|
||||||
|
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
|
||||||
|
#filename="models_eynollah/modelens_textline_1_3_4_20240915",
|
||||||
|
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
|
||||||
|
#filename="models_eynollah/modelens_textline_9_12_13_14_15",
|
||||||
|
#filename="models_eynollah/eynollah-textline_20210425",
|
||||||
|
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="textline",
|
||||||
|
variant='light',
|
||||||
|
#filename="models_eynollah/eynollah-textline_light_20210425",
|
||||||
|
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="table",
|
||||||
|
variant='',
|
||||||
|
filename="models_eynollah/eynollah-tables_20210319",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="table",
|
||||||
|
variant='light',
|
||||||
|
filename="models_eynollah/modelens_table_0t4_201124",
|
||||||
|
dist=dist_url("layout"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="ocr",
|
||||||
|
variant='',
|
||||||
|
filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
|
||||||
|
dist=dist_url("ocr"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="num_to_char",
|
||||||
|
variant='',
|
||||||
|
filename="models_eynollah/characters_org.txt",
|
||||||
|
dist=dist_url("ocr"),
|
||||||
|
type=KerasModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="characters",
|
||||||
|
variant='',
|
||||||
|
filename="models_eynollah/characters_org.txt",
|
||||||
|
dist=dist_url("ocr"),
|
||||||
|
type=List,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="ocr",
|
||||||
|
variant='tr',
|
||||||
|
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
|
||||||
|
dist=dist_url("trocr"),
|
||||||
|
type=VisionEncoderDecoderModel,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="trocr_processor",
|
||||||
|
variant='',
|
||||||
|
filename="models_eynollah/microsoft/trocr-base-printed",
|
||||||
|
dist=dist_url("trocr"),
|
||||||
|
type=TrOCRProcessor,
|
||||||
|
),
|
||||||
|
|
||||||
|
EynollahModelSpec(
|
||||||
|
category="trocr_processor",
|
||||||
|
variant='htr',
|
||||||
|
filename="models_eynollah/microsoft/trocr-base-handwritten",
|
||||||
|
dist=dist_url("trocr"),
|
||||||
|
type=TrOCRProcessor,
|
||||||
|
),
|
||||||
|
|
||||||
|
])# }}}
|
||||||
|
|
||||||
class EynollahModelZoo():
|
class EynollahModelZoo():
|
||||||
"""
|
"""
|
||||||
Wrapper class that handles storage and loading of models for all eynollah runners.
|
Wrapper class that handles storage and loading of models for all eynollah runners.
|
||||||
"""
|
"""
|
||||||
model_basedir: Path
|
model_basedir: Path
|
||||||
model_versions: dict
|
specs: EynollahModelSpecSet
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -157,10 +332,10 @@ class EynollahModelZoo():
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_basedir = Path(basedir)
|
self.model_basedir = Path(basedir)
|
||||||
self.logger = logging.getLogger('eynollah.model_zoo')
|
self.logger = logging.getLogger('eynollah.model_zoo')
|
||||||
self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
|
self.specs = deepcopy(DEFAULT_MODEL_SPECS)
|
||||||
if model_overrides:
|
if model_overrides:
|
||||||
self.override_models(*model_overrides)
|
self.override_models(*model_overrides)
|
||||||
self._loaded: Dict[str, SomeEynollahModel] = {}
|
self._loaded: Dict[str, AnyModel] = {}
|
||||||
|
|
||||||
def override_models(
|
def override_models(
|
||||||
self,
|
self,
|
||||||
|
|
@ -170,39 +345,24 @@ class EynollahModelZoo():
|
||||||
Override the default model versions
|
Override the default model versions
|
||||||
"""
|
"""
|
||||||
for model_category, model_variant, model_filename in model_overrides:
|
for model_category, model_variant, model_filename in model_overrides:
|
||||||
if model_category not in DEFAULT_MODEL_VERSIONS:
|
spec = self.specs.get(model_category, model_variant)
|
||||||
raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
|
self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
|
||||||
if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
|
self.specs.get(model_category, model_variant).filename = model_filename
|
||||||
raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
|
|
||||||
self.logger.warning(
|
|
||||||
"Overriding default model %s ('%s' variant) from %s to %s",
|
|
||||||
model_category,
|
|
||||||
model_variant,
|
|
||||||
DEFAULT_MODEL_VERSIONS[model_category][model_variant],
|
|
||||||
model_filename
|
|
||||||
)
|
|
||||||
self.model_versions[model_category][model_variant] = model_filename
|
|
||||||
|
|
||||||
def model_path(
|
def model_path(
|
||||||
self,
|
self,
|
||||||
model_category: str,
|
model_category: str,
|
||||||
model_variant: str = '',
|
model_variant: str = '',
|
||||||
model_filename: str = '',
|
|
||||||
absolute: bool = True,
|
absolute: bool = True,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""
|
"""
|
||||||
Translate model_{type,variant,filename} tuple into an absolute (or relative) Path
|
Translate model_{category,variant} tuple into an absolute (or relative) Path
|
||||||
"""
|
"""
|
||||||
if model_category not in DEFAULT_MODEL_VERSIONS:
|
spec = self.specs.get(model_category, model_variant)
|
||||||
raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
|
if not Path(spec.filename).is_absolute() and absolute:
|
||||||
if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
|
model_path = Path(self.model_basedir).joinpath(spec.filename)
|
||||||
raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
|
|
||||||
if not model_filename:
|
|
||||||
model_filename = DEFAULT_MODEL_VERSIONS[model_category][model_variant]
|
|
||||||
if not Path(model_filename).is_absolute() and absolute:
|
|
||||||
model_path = Path(self.model_basedir).joinpath(model_filename)
|
|
||||||
else:
|
else:
|
||||||
model_path = Path(model_filename)
|
model_path = Path(spec.filename)
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
def load_models(
|
def load_models(
|
||||||
|
|
@ -224,12 +384,11 @@ class EynollahModelZoo():
|
||||||
self,
|
self,
|
||||||
model_category: str,
|
model_category: str,
|
||||||
model_variant: str = '',
|
model_variant: str = '',
|
||||||
model_filename: str = '',
|
) -> AnyModel:
|
||||||
) -> SomeEynollahModel:
|
|
||||||
"""
|
"""
|
||||||
Load any model
|
Load any model
|
||||||
"""
|
"""
|
||||||
model_path = self.model_path(model_category, model_variant, model_filename)
|
model_path = self.model_path(model_category, model_variant)
|
||||||
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
||||||
# prefer SavedModel over HDF5 format if it exists
|
# prefer SavedModel over HDF5 format if it exists
|
||||||
model_path = Path(model_path.stem)
|
model_path = Path(model_path.stem)
|
||||||
|
|
@ -259,17 +418,17 @@ class EynollahModelZoo():
|
||||||
assert isinstance(ret, model_type)
|
assert isinstance(ret, model_type)
|
||||||
return ret # type: ignore # FIXME: convince typing that we're returning generic type
|
return ret # type: ignore # FIXME: convince typing that we're returning generic type
|
||||||
|
|
||||||
def _load_ocr_model(self, variant: str) -> SomeEynollahModel:
|
def _load_ocr_model(self, variant: str) -> AnyModel:
|
||||||
"""
|
"""
|
||||||
Load OCR model
|
Load OCR model
|
||||||
"""
|
"""
|
||||||
ocr_model_dir = Path(self.model_basedir, self.model_versions["ocr"][variant])
|
ocr_model_dir = self.model_path('ocr', variant)
|
||||||
if variant == 'tr':
|
if variant == 'tr':
|
||||||
return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
|
return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
|
||||||
else:
|
else:
|
||||||
ocr_model = load_model(ocr_model_dir, compile=False)
|
ocr_model = load_model(ocr_model_dir, compile=False)
|
||||||
assert isinstance(ocr_model, Model)
|
assert isinstance(ocr_model, KerasModel)
|
||||||
return Model(
|
return KerasModel(
|
||||||
ocr_model.get_layer(name = "image").input, # type: ignore
|
ocr_model.get_layer(name = "image").input, # type: ignore
|
||||||
ocr_model.get_layer(name = "dense2").output) # type: ignore
|
ocr_model.get_layer(name = "dense2").output) # type: ignore
|
||||||
|
|
||||||
|
|
@ -295,7 +454,7 @@ class EynollahModelZoo():
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(json.dumps({
|
return str(json.dumps({
|
||||||
'basedir': str(self.model_basedir),
|
'basedir': str(self.model_basedir),
|
||||||
'versions': self.model_versions,
|
'versions': self.specs.asdict(),
|
||||||
}, indent=2))
|
}, indent=2))
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue