factor model loading in Eynollah to EynollahModelZoo

kba 2025-10-20 18:34:44 +02:00
parent 38c028c6b5
commit a850ef39ea
6 changed files with 389 additions and 171 deletions

src/eynollah/cli.py

@@ -1,16 +1,57 @@
+from dataclasses import dataclass
 import sys
+import os
 import click
 import logging
+from typing import Tuple, List
 from ocrd_utils import initLogging, getLevelName, getLogger
 from eynollah.eynollah import Eynollah, Eynollah_ocr
 from eynollah.sbb_binarize import SbbBinarizer
 from eynollah.image_enhancer import Enhancer
 from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout
+from eynollah.model_zoo import EynollahModelZoo
+@dataclass
+class EynollahCliCtx():
+    model_basedir: str
+    model_overrides: List[Tuple[str, str, str]]
 @click.group()
 def main():
     pass
+@main.command('list-models')
+@click.option(
+    "--model",
+    "-m",
+    'model_basedir',
+    help="directory of models",
+    type=click.Path(exists=True, file_okay=False),
+    # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
+    required=True,
+)
+@click.option(
+    "--model-overrides",
+    "-mv",
+    help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. See eynollah list-models for the full list",
+    type=(str, str, str),
+    multiple=True,
+)
+@click.pass_context
+def list_models(
+    ctx,
+    model_basedir: str,
+    model_overrides: List[Tuple[str, str, str]],
+):
+    """
+    List all the models in the zoo
+    """
+    ctx.obj = EynollahCliCtx(
+        model_basedir=model_basedir,
+        model_overrides=model_overrides
+    )
+    print(EynollahModelZoo(basedir=ctx.obj.model_basedir, model_overrides=ctx.obj.model_overrides))
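For illustration (not part of the commit; paths are placeholders), the new subcommand would be invoked as

    eynollah list-models -m /path/to/models -mv ocr tr /path/to/custom_trocr

which prints the zoo's resolved categories, variants and model versions as JSON, with the 'ocr'/'tr' entry overridden.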
 @main.command()
 @click.option(
     "--input",
@@ -198,15 +239,17 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
 @click.option(
     "--model",
     "-m",
+    'model_basedir',
     help="directory of models",
     type=click.Path(exists=True, file_okay=False),
+    # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
     required=True,
 )
 @click.option(
     "--model_version",
     "-mv",
-    help="override default versions of model categories",
-    type=(str, str),
+    help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. See eynollah list-models for the full list",
+    type=(str, str, str),
     multiple=True,
 )
 @click.option(

@@ -411,7 +454,7 @@ def layout(image, out, overwrite, dir_in, model, model_version, save_images, sav
     assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
     eynollah = Eynollah(
         model,
-        model_versions=model_version,
+        model_overrides=model_version,
         extract_only_images=extract_only_images,
         enable_plotting=enable_plotting,
         allow_enhancement=allow_enhancement,
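For illustration (not part of the commit; paths are placeholders), the renamed keyword surfaces in the API like this minimal sketch:

    from eynollah.eynollah import Eynollah

    # model_overrides entries are (category, variant, path) triples
    eynollah = Eynollah(
        "/path/to/models",
        model_overrides=[("region", "light", "/path/to/custom_region_model")],
    )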

src/eynollah/eynollah.py

@@ -2,12 +2,15 @@
 # pylint: disable=too-many-locals,wrong-import-position,too-many-lines,too-many-statements,chained-comparison,fixme,broad-except,c-extension-no-member
 # pylint: disable=too-many-public-methods,too-many-arguments,too-many-instance-attributes,too-many-public-methods,
 # pylint: disable=consider-using-enumerate
+# pyright: reportUnnecessaryTypeIgnoreComment=true
+# pyright: reportPossiblyUnboundVariable=false
 """
 document layout analysis (segmentation) with output in PAGE-XML
 """
 # cannot use importlib.resources until we move to 3.9+ for importlib.resources.files
 import sys
 if sys.version_info < (3, 10):
     import importlib_resources
 else:
@@ -19,7 +22,7 @@ import math
 import os
 import sys
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Union, List, Optional, Tuple
 import atexit
 import warnings
 from functools import partial
@@ -58,8 +61,7 @@ except ImportError:
 #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
 tf_disable_interactive_logs()
 import tensorflow as tf
-from tensorflow.python.keras import backend as K
-from tensorflow.keras.models import load_model
+from keras.models import load_model
 tf.get_logger().setLevel("ERROR")
 warnings.filterwarnings("ignore")
 # use tf1 compatibility for keras backend
@@ -67,6 +69,7 @@ from tensorflow.compat.v1.keras.backend import set_session
 from tensorflow.keras import layers
 from tensorflow.keras.layers import StringLookup
+from .model_zoo import EynollahModelZoo
 from .utils.contour import (
     filter_contours_area_of_image,
     filter_contours_area_of_image_tables,
@@ -155,59 +158,12 @@ patch_size = 1
 num_patches = 21*21 #14*14#28*28#14*14#28*28
-class Patches(layers.Layer):
-    def __init__(self, **kwargs):
-        super(Patches, self).__init__()
-        self.patch_size = patch_size
-    def call(self, images):
-        batch_size = tf.shape(images)[0]
-        patches = tf.image.extract_patches(
-            images=images,
-            sizes=[1, self.patch_size, self.patch_size, 1],
-            strides=[1, self.patch_size, self.patch_size, 1],
-            rates=[1, 1, 1, 1],
-            padding="VALID",
-        )
-        patch_dims = patches.shape[-1]
-        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
-        return patches
-    def get_config(self):
-        config = super().get_config().copy()
-        config.update({
-            'patch_size': self.patch_size,
-        })
-        return config
-class PatchEncoder(layers.Layer):
-    def __init__(self, **kwargs):
-        super(PatchEncoder, self).__init__()
-        self.num_patches = num_patches
-        self.projection = layers.Dense(units=projection_dim)
-        self.position_embedding = layers.Embedding(
-            input_dim=num_patches, output_dim=projection_dim
-        )
-    def call(self, patch):
-        positions = tf.range(start=0, limit=self.num_patches, delta=1)
-        encoded = self.projection(patch) + self.position_embedding(positions)
-        return encoded
-    def get_config(self):
-        config = super().get_config().copy()
-        config.update({
-            'num_patches': self.num_patches,
-            'projection': self.projection,
-            'position_embedding': self.position_embedding,
-        })
-        return config
 class Eynollah:
     def __init__(
         self,
         dir_models : str,
-        model_versions: List[Tuple[str, str]] = [],
+        model_overrides: List[Tuple[str, str, str]] = [],
         extract_only_images : bool =False,
         enable_plotting : bool = False,
         allow_enhancement : bool = False,
@@ -232,6 +188,7 @@ class Eynollah:
         skip_layout_and_reading_order : bool = False,
     ):
         self.logger = getLogger('eynollah')
+        self.model_zoo = EynollahModelZoo(basedir=dir_models)
         self.plotter = None
         if skip_layout_and_reading_order:
@@ -297,93 +254,13 @@ class Eynollah:
             self.logger.warning("no GPU device available")
         self.logger.info("Loading models...")
-        self.setup_models(dir_models, model_versions)
+        self.setup_models(*model_overrides)
         self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
-    @staticmethod
-    def our_load_model(model_file, basedir=""):
-        if basedir:
-            model_file = os.path.join(basedir, model_file)
-        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
-            # prefer SavedModel over HDF5 format if it exists
-            model_file = model_file[:-3]
-        try:
-            model = load_model(model_file, compile=False)
-        except:
-            model = load_model(model_file, compile=False, custom_objects={
-                "PatchEncoder": PatchEncoder, "Patches": Patches})
-        return model
-    def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []):
+    def setup_models(self, *model_overrides: Tuple[str, str, str]):
-        self.model_versions = {
-            "enhancement": "eynollah-enhancement_20210425",
-            "binarization": "eynollah-binarization_20210425",
-            "col_classifier": "eynollah-column-classifier_20210425",
-            "page": "model_eynollah_page_extraction_20250915",
-            #?: "eynollah-main-regions-aug-scaling_20210425",
-            "region": ( # early layout
-                "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else
-                "eynollah-main-regions_20220314" if self.light_version else
-                "eynollah-main-regions-ensembled_20210425"),
-            "region_p2": ( # early layout, non-light, 2nd part
-                "eynollah-main-regions-aug-rotation_20210425"),
-            "region_1_2": ( # early layout, light, 1-or-2-column
-                #"modelens_12sp_elay_0_3_4__3_6_n"
-                #"modelens_earlylayout_12spaltige_2_3_5_6_7_8"
-                #"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
-                #"modelens_1_2_4_5_early_lay_1_2_spaltige"
-                #"model_3_eraly_layout_no_patches_1_2_spaltige"
-                "modelens_e_l_all_sp_0_1_2_3_4_171024"),
-            "region_fl_np": ( # full layout / no patches
-                #"modelens_full_lay_1_3_031124"
-                #"modelens_full_lay_13__3_19_241024"
-                #"model_full_lay_13_241024"
-                #"modelens_full_lay_13_17_231024"
-                #"modelens_full_lay_1_2_221024"
-                #"eynollah-full-regions-1column_20210425"
-                "modelens_full_lay_1__4_3_091124"),
-            "region_fl": ( # full layout / with patches
-                #"eynollah-full-regions-3+column_20210425"
-                ##"model_2_full_layout_new_trans"
-                #"modelens_full_lay_1_3_031124"
-                #"modelens_full_lay_13__3_19_241024"
-                #"model_full_lay_13_241024"
-                #"modelens_full_lay_13_17_231024"
-                #"modelens_full_lay_1_2_221024"
-                #"modelens_full_layout_24_till_28"
-                #"model_2_full_layout_new_trans"
-                "modelens_full_lay_1__4_3_091124"),
-            "reading_order": (
-                #"model_mb_ro_aug_ens_11"
-                #"model_step_3200000_mb_ro"
-                #"model_ens_reading_order_machine_based"
-                #"model_mb_ro_aug_ens_8"
-                #"model_ens_reading_order_machine_based"
-                "model_eynollah_reading_order_20250824"),
-            "textline": (
-                #"modelens_textline_1_4_16092024"
-                #"model_textline_ens_3_4_5_6_artificial"
-                #"modelens_textline_1_3_4_20240915"
-                #"model_textline_ens_3_4_5_6_artificial"
-                #"modelens_textline_9_12_13_14_15"
-                #"eynollah-textline_light_20210425"
-                "modelens_textline_0_1__2_4_16092024" if self.textline_light else
-                #"eynollah-textline_20210425"
-                "modelens_textline_0_1__2_4_16092024"),
-            "table": (
-                None if not self.tables else
-                "modelens_table_0t4_201124" if self.light_version else
-                "eynollah-tables_20210319"),
-            "ocr": (
-                None if not self.ocr else
-                "model_eynollah_ocr_trocr_20250919" if self.tr else
-                "model_eynollah_ocr_cnnrnn_20250930")
-        }
         # override defaults from CLI
-        for key, val in model_versions:
-            assert key in self.model_versions, "unknown model category '%s'" % key
-            self.logger.warning("overriding default model %s version %s to %s", key, self.model_versions[key], val)
-            self.model_versions[key] = val
+        self.model_zoo.override_models(*model_overrides)
         # load models, depending on modes
         # (note: loading too many models can cause OOM on GPU/CUDA,
         # thus, we try set up the minimal configuration for the current mode)
@@ -391,10 +268,10 @@ class Eynollah:
             "col_classifier",
             "binarization",
             "page",
-            "region"
+            ("region", 'extract_only_images' if self.extract_only_images else 'light' if self.light_version else '')
         ]
         if not self.extract_only_images:
-            loadable.append("textline")
+            loadable.append(("textline", 'light' if self.light_version else ''))
             if self.light_version:
                 loadable.append("region_1_2")
             else:
@@ -407,38 +284,24 @@ class Eynollah:
         if self.reading_order_machine_based:
             loadable.append("reading_order")
         if self.tables:
-            loadable.append("table")
+            loadable.append(("table", 'light' if self.light_version else ''))
-        self.models = {name: self.our_load_model(self.model_versions[name], basedir)
-                       for name in loadable
-        }
         if self.ocr:
-            ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"])
             if self.tr:
-                self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
+                loadable.append(('ocr', 'tr'))
+                loadable.append(('ocr_tr_processor', 'tr'))
+                # TODO why here and why only for tr?
                 if torch.cuda.is_available():
                     self.logger.info("Using GPU acceleration")
                     self.device = torch.device("cuda:0")
                 else:
                     self.logger.info("Using CPU processing")
                     self.device = torch.device("cpu")
-                #self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
-                self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
             else:
-                ocr_model = load_model(ocr_model_dir, compile=False)
-                self.models["ocr"] = tf.keras.models.Model(
-                    ocr_model.get_layer(name = "image").input,
-                    ocr_model.get_layer(name = "dense2").output)
-                with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file:
-                    characters = json.load(config_file)
-                # Mapping characters to integers.
-                char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
-                # Mapping integers back to original characters.
-                self.num_to_char = StringLookup(
-                    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
-                )
+                loadable.append('ocr')
+                loadable.append('num_to_char')
+        self.models = self.model_zoo.load_models(*loadable)

     def __del__(self):
         if hasattr(self, 'executor') and getattr(self, 'executor'):
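For illustration (not part of the commit): with light_version=True, ocr=True and tr=False, `loadable` would resolve to roughly

    loadable = [
        "col_classifier", "binarization", "page",
        ("region", "light"),
        ("textline", "light"),
        "region_1_2",
        "reading_order",        # if reading_order_machine_based
        "ocr", "num_to_char",   # CNN/RNN OCR branch
    ]

and `self.model_zoo.load_models(*loadable)` returns a dict keyed by category, so later code reads `self.models["ocr"]` and `self.models["num_to_char"]`.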
@@ -4261,7 +4124,7 @@ class Eynollah:
                 gc.collect()
                 ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                     image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)),
-                    self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True)
+                    self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], textline_light=True)
             else:
                 ocr_all_textlines = None
@@ -4770,27 +4633,27 @@ class Eynollah:
                 if len(all_found_textline_polygons):
                     ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons, all_box_coord,
-                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
                 if len(all_found_textline_polygons_marginals_left):
                     ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left,
-                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
                 if len(all_found_textline_polygons_marginals_right):
                     ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right,
-                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
                 if self.full_layout and len(all_found_textline_polygons):
                     ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons_h, all_box_coord_h,
-                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
                 if self.full_layout and len(polygons_of_drop_capitals):
                     ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)),
-                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
             else:
                 if self.light_version:
@@ -4839,7 +4702,7 @@ class Eynollah:
                         img_croped = img_poly_on_img[y:y+h, x:x+w, :]
                         #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
                         text_ocr = self.return_ocr_of_textline_without_common_section(
-                            img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot)
+                            img_croped, self.models["ocr"], self.models['ocr_tr_processor'], self.device, w, h2w_ratio, ind_tot)
                         ocr_textline_in_textregion.append(text_ocr)
                         ind_tot = ind_tot +1
             ocr_all_textlines.append(ocr_textline_in_textregion)

src/eynollah/image_enhancer.py

@@ -22,7 +22,7 @@ from .utils import (
     is_image_filename,
     crop_image_inside_box
 )
-from .eynollah import PatchEncoder, Patches
+from .patch_encoder import PatchEncoder, Patches
 DPI_THRESHOLD = 298
 KERNEL = np.ones((5, 5), np.uint8)

src/eynollah/mb_ro_on_layout.py

@@ -23,7 +23,7 @@ from .utils.contour import (
     return_parent_contours,
 )
 from .utils import is_xml_filename
-from .eynollah import PatchEncoder, Patches
+from .patch_encoder import PatchEncoder, Patches
 DPI_THRESHOLD = 298
 KERNEL = np.ones((5, 5), np.uint8)

src/eynollah/model_zoo.py Normal file

@@ -0,0 +1,260 @@
from dataclasses import dataclass
import json
import logging
from pathlib import Path
from types import MappingProxyType
from typing import Dict, Literal, Optional, Tuple, List, Union
from copy import deepcopy

from keras.layers import StringLookup
from keras.models import Model, load_model
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from eynollah.patch_encoder import PatchEncoder, Patches

# Dict mapping model_category to dict mapping variant (default is '') to Path
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
    "enhancement": {
        '': "eynollah-enhancement_20210425"
    },
    "binarization": {
        '': "eynollah-binarization_20210425"
    },
    "col_classifier": {
        '': "eynollah-column-classifier_20210425",
    },
    "page": {
        '': "model_eynollah_page_extraction_20250915",
    },
    # TODO: What is this commented out model?
    #?: "eynollah-main-regions-aug-scaling_20210425",
    # early layout
    "region": {
        '': "eynollah-main-regions-ensembled_20210425",
        'extract_only_images': "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
        'light': "eynollah-main-regions_20220314",
    },
    # early layout, non-light, 2nd part
    "region_p2": {
        '': "eynollah-main-regions-aug-rotation_20210425",
    },
    # early layout, light, 1-or-2-column
    "region_1_2": {
        #'': "modelens_12sp_elay_0_3_4__3_6_n"
        #'': "modelens_earlylayout_12spaltige_2_3_5_6_7_8"
        #'': "modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
        #'': "modelens_1_2_4_5_early_lay_1_2_spaltige"
        #'': "model_3_eraly_layout_no_patches_1_2_spaltige"
        '': "modelens_e_l_all_sp_0_1_2_3_4_171024"
    },
    # full layout / no patches
    "region_fl_np": {
        #'': "modelens_full_lay_1_3_031124"
        #'': "modelens_full_lay_13__3_19_241024"
        #'': "model_full_lay_13_241024"
        #'': "modelens_full_lay_13_17_231024"
        #'': "modelens_full_lay_1_2_221024"
        #'': "eynollah-full-regions-1column_20210425"
        '': "modelens_full_lay_1__4_3_091124"
    },
    # full layout / with patches
    "region_fl": {
        #'': "eynollah-full-regions-3+column_20210425"
        #'': #"model_2_full_layout_new_trans"
        #'': "modelens_full_lay_1_3_031124"
        #'': "modelens_full_lay_13__3_19_241024"
        #'': "model_full_lay_13_241024"
        #'': "modelens_full_lay_13_17_231024"
        #'': "modelens_full_lay_1_2_221024"
        #'': "modelens_full_layout_24_till_28"
        #'': "model_2_full_layout_new_trans"
        '': "modelens_full_lay_1__4_3_091124",
    },
    "reading_order": {
        #'': "model_mb_ro_aug_ens_11"
        #'': "model_step_3200000_mb_ro"
        #'': "model_ens_reading_order_machine_based"
        #'': "model_mb_ro_aug_ens_8"
        #'': "model_ens_reading_order_machine_based"
        '': "model_eynollah_reading_order_20250824"
    },
    "textline": {
        #'light': "eynollah-textline_light_20210425"
        'light': "modelens_textline_0_1__2_4_16092024",
        #'': "modelens_textline_1_4_16092024"
        #'': "model_textline_ens_3_4_5_6_artificial"
        #'': "modelens_textline_1_3_4_20240915"
        #'': "model_textline_ens_3_4_5_6_artificial"
        #'': "modelens_textline_9_12_13_14_15"
        #'': "eynollah-textline_20210425"
        '': "modelens_textline_0_1__2_4_16092024"
    },
    "table": {
        'light': "modelens_table_0t4_201124",
        '': "eynollah-tables_20210319",
    },
    "ocr": {
        'tr': "model_eynollah_ocr_trocr_20250919",
        '': "model_eynollah_ocr_cnnrnn_20250930",
    },
    'ocr_tr_processor': {
        '': 'microsoft/trocr-base-printed',
        'htr': "microsoft/trocr-base-handwritten",
    },
    'num_to_char': {
        '': 'model_eynollah_ocr_cnnrnn_20250930/characters_org.txt'
    },
}
class EynollahModelZoo():
    """
    Wrapper class that handles storage and loading of models for all eynollah runners.
    """

    model_basedir: Path
    model_versions: dict

    def __init__(
        self,
        basedir: str,
        model_overrides: List[Tuple[str, str, str]],
    ) -> None:
        self.model_basedir = Path(basedir)
        self.logger = logging.getLogger('eynollah.model_zoo')
        self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
        if model_overrides:
            self.override_models(*model_overrides)

    def override_models(self, *model_overrides: Tuple[str, str, str]):
        """
        Override the default model versions
        """
        for model_category, model_variant, model_filename in model_overrides:
            if model_category not in DEFAULT_MODEL_VERSIONS:
                raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
            if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
                raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
            self.logger.warning(
                "Overriding default model %s ('%s' variant) from %s to %s",
                model_category,
                model_variant,
                DEFAULT_MODEL_VERSIONS[model_category][model_variant],
                model_filename
            )
            self.model_versions[model_category][model_variant] = model_filename
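For illustration (not part of the commit; the override path is hypothetical):

    zoo = EynollahModelZoo(
        basedir="/path/to/models",
        model_overrides=[("ocr", "tr", "/path/to/finetuned_trocr")],
    )
    # zoo.model_versions["ocr"]["tr"] is now "/path/to/finetuned_trocr";
    # DEFAULT_MODEL_VERSIONS itself is untouched, since __init__ deep-copies it.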
    def model_path(
        self,
        model_category: str,
        model_variant: str = '',
        model_filename: str = '',
        absolute: bool = True,
    ) -> Path:
        """
        Translate model_{type,variant,filename} tuple into an absolute (or relative) Path
        """
        if model_category not in DEFAULT_MODEL_VERSIONS:
            raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
        if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
            raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
        if not model_filename:
            model_filename = DEFAULT_MODEL_VERSIONS[model_category][model_variant]
        if not Path(model_filename).is_absolute() and absolute:
            model_path = Path(self.model_basedir).joinpath(model_filename)
        else:
            model_path = Path(model_filename)
        return model_path
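A worked example of the resolution rule, continuing the sketch above (filenames from DEFAULT_MODEL_VERSIONS; the basedir is a placeholder):

    zoo.model_path("region", "light")
    # -> Path("/path/to/models/eynollah-main-regions_20220314")
    zoo.model_path("region", "light", absolute=False)
    # -> Path("eynollah-main-regions_20220314")
    zoo.model_path("ocr", "tr", "/abs/override")  # absolute filenames win
    # -> Path("/abs/override")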
    def load_models(
        self,
        *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
    ) -> Dict:
        """
        Load all models by calling load_model and return a dictionary mapping model_category to loaded model
        """
        ret = {}
        for load_args in all_load_args:
            if isinstance(load_args, str):
                ret[load_args] = self.load_model(load_args)
            else:
                ret[load_args[0]] = self.load_model(*load_args)
        return ret
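So callers can mix bare category names with tuples, as Eynollah.setup_models does; a sketch (the custom path is hypothetical):

    models = zoo.load_models(
        "page",                                         # default '' variant
        ("region", "light"),                            # explicit variant
        ("binarization", "", "/path/to/custom_model"),  # filename override
    )
    # keys are category names: models["page"], models["region"], models["binarization"]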
    def load_model(
        self,
        model_category: str,
        model_variant: str = '',
        model_filename: str = '',
    ) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
        """
        Load any model
        """
        model_path = self.model_path(model_category, model_variant, model_filename)
        if model_path.suffix == '.h5' and Path(model_path.stem).exists():
            # prefer SavedModel over HDF5 format if it exists
            model_path = Path(model_path.stem)
        if model_category == 'ocr':
            model = self._load_ocr_model(variant=model_variant)
        elif model_category == 'num_to_char':
            model = self._load_num_to_char()
        elif model_category == 'tr_processor':
            return TrOCRProcessor.from_pretrained(self.model_path(...))
        else:
            try:
                model = load_model(model_path, compile=False)
            except Exception as e:
                self.logger.exception(e)
                model = load_model(model_path, compile=False, custom_objects={
                    "PatchEncoder": PatchEncoder, "Patches": Patches})
        return model  # type: ignore

    def _load_ocr_model(self, variant: str) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
        """
        Load OCR model
        """
        ocr_model_dir = Path(self.model_basedir, self.model_versions["ocr"][variant])
        if variant == 'tr':
            return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
        else:
            ocr_model = load_model(ocr_model_dir, compile=False)
            assert isinstance(ocr_model, Model)
            return Model(
                ocr_model.get_layer(name = "image").input,  # type: ignore
                ocr_model.get_layer(name = "dense2").output)  # type: ignore

    def _load_num_to_char(self):
        """
        Load decoder for OCR
        """
        with open(self.model_path('ocr') / self.model_path('ocr', 'num_to_char', absolute=False), "r") as config_file:
            characters = json.load(config_file)
        # Mapping characters to integers.
        char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
        # Mapping integers back to original characters.
        return StringLookup(
            vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
        )

    def __str__(self):
        return str(json.dumps({
            'basedir': str(self.model_basedir),
            'versions': self.model_versions,
        }, indent=2))

src/eynollah/patch_encoder.py Normal file

@@ -0,0 +1,52 @@
from keras import layers
import tensorflow as tf

projection_dim = 64
patch_size = 1
num_patches = 21*21 #14*14#28*28#14*14#28*28

class PatchEncoder(layers.Layer):
    def __init__(self):
        super().__init__()
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)

    def call(self, patch):
        positions = tf.range(start=0, limit=num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_patches': num_patches,
            'projection': self.projection,
            'position_embedding': self.position_embedding,
        })
        return config

class Patches(layers.Layer):
    def __init__(self, **kwargs):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config
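Because these layers close over the module-level projection_dim/patch_size/num_patches, Keras cannot rebuild them from a saved config alone; models containing them are therefore reloaded with custom_objects, as EynollahModelZoo.load_model does on its fallback path. A minimal sketch (the model path is a placeholder):

    from keras.models import load_model
    from eynollah.patch_encoder import PatchEncoder, Patches

    model = load_model(
        "/path/to/model_dir", compile=False,
        custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches},
    )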