factor model loading out of Eynollah into EynollahModelZoo

kba 2025-10-20 18:34:44 +02:00
parent 38c028c6b5
commit a850ef39ea
6 changed files with 389 additions and 171 deletions
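In rough terms (a sketch distilled from the diff below, not an exact excerpt; directory and model paths are illustrative), model resolution moves out of Eynollah itself into a reusable zoo object:

# before: Eynollah resolved and loaded model files itself
eynollah = Eynollah(dir_models, model_versions=[("region", "eynollah-main-regions_20220314")])

# after: EynollahModelZoo owns the category -> variant -> filename table and the loading
from eynollah.model_zoo import EynollahModelZoo
zoo = EynollahModelZoo(basedir=dir_models, model_overrides=[("region", "light", "/path/to/model")])
models = zoo.load_models("page", ("region", "light"))  # dict keyed by category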

src/eynollah/cli.py
View file

@@ -1,16 +1,57 @@
from dataclasses import dataclass
import sys
import os
import click
import logging
from typing import Tuple, List
from ocrd_utils import initLogging, getLevelName, getLogger
from eynollah.eynollah import Eynollah, Eynollah_ocr
from eynollah.sbb_binarize import SbbBinarizer
from eynollah.image_enhancer import Enhancer
from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout
from eynollah.model_zoo import EynollahModelZoo
@dataclass
class EynollahCliCtx():
model_basedir: str
model_overrides: List[Tuple[str, str, str]]
@click.group()
def main():
pass
@main.command('list-models')
@click.option(
"--model",
"-m",
'model_basedir',
help="directory of models",
type=click.Path(exists=True, file_okay=False),
# default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
required=True,
)
@click.option(
"--model-overrides",
"-mv",
help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. See eynollah list-models for the full list",
type=(str, str, str),
multiple=True,
)
@click.pass_context
def list_models(
ctx,
model_basedir: str,
model_overrides: List[Tuple[str, str, str]],
):
"""
List all the models in the zoo
"""
ctx.obj = EynollahCliCtx(
model_basedir=model_basedir,
model_overrides=model_overrides
)
print(EynollahModelZoo(basedir=ctx.obj.model_basedir, model_overrides=ctx.obj.model_overrides))
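For reference, the new subcommand is roughly equivalent to building a zoo and printing it; a minimal sketch (model directory and override path are illustrative):

# eynollah list-models -m /path/to/models -mv region light /path/to/model
zoo = EynollahModelZoo(
    basedir="/path/to/models",                                # -m / --model
    model_overrides=[("region", "light", "/path/to/model")],  # -mv / --model-overrides
)
print(zoo)  # JSON dump of the basedir plus the category -> variant -> filename table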
@main.command()
@click.option(
"--input",
@@ -198,15 +239,17 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
@click.option(
"--model",
"-m",
'model_basedir',
help="directory of models",
type=click.Path(exists=True, file_okay=False),
# default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
required=True,
)
@click.option(
"--model_version",
"-mv",
help="override default versions of model categories",
type=(str, str),
help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. See eynollah list-models for the full list",
type=(str, str, str),
multiple=True,
)
@click.option(
@@ -411,7 +454,7 @@ def layout(image, out, overwrite, dir_in, model, model_version, save_images, sav
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
eynollah = Eynollah(
model,
model_versions=model_version,
model_overrides=model_version,
extract_only_images=extract_only_images,
enable_plotting=enable_plotting,
allow_enhancement=allow_enhancement,
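With the widened option type, each -mv occurrence on the layout command now supplies a (category, variant, path) triple; a sketch of what the CLI hands to Eynollah (paths are illustrative):

# eynollah layout -m /path/to/models -mv region light /path/to/alt_region_model ...
eynollah = Eynollah(
    "/path/to/models",                                                    # -m
    model_overrides=[("region", "light", "/path/to/alt_region_model")],  # -mv, repeatable
)
# Eynollah.__init__ builds an EynollahModelZoo(basedir=...) and setup_models() applies the overrides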

src/eynollah/eynollah.py
View file

@@ -2,12 +2,15 @@
# pylint: disable=too-many-locals,wrong-import-position,too-many-lines,too-many-statements,chained-comparison,fixme,broad-except,c-extension-no-member
# pylint: disable=too-many-public-methods,too-many-arguments,too-many-instance-attributes,too-many-public-methods,
# pylint: disable=consider-using-enumerate
# pyright: reportUnnecessaryTypeIgnoreComment=true
# pyright: reportPossiblyUnboundVariable=false
"""
document layout analysis (segmentation) with output in PAGE-XML
"""
# cannot use importlib.resources until we move to 3.9+ for importlib.resources.files
import sys
if sys.version_info < (3, 10):
import importlib_resources
else:
@@ -19,7 +22,7 @@ import math
import os
import sys
import time
from typing import Dict, List, Optional, Tuple
from typing import Dict, Union, List, Optional, Tuple
import atexit
import warnings
from functools import partial
@@ -58,8 +61,7 @@ except ImportError:
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf_disable_interactive_logs()
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.keras.models import load_model
from keras.models import load_model
tf.get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore")
# use tf1 compatibility for keras backend
@@ -67,6 +69,7 @@ from tensorflow.compat.v1.keras.backend import set_session
from tensorflow.keras import layers
from tensorflow.keras.layers import StringLookup
from .model_zoo import EynollahModelZoo
from .utils.contour import (
filter_contours_area_of_image,
filter_contours_area_of_image_tables,
@@ -155,59 +158,12 @@ patch_size = 1
num_patches =21*21#14*14#28*28#14*14#28*28
class Patches(layers.Layer):
def __init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, **kwargs):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
class Eynollah:
def __init__(
self,
dir_models : str,
model_versions: List[Tuple[str, str]] = [],
model_overrides: List[Tuple[str, str, str]] = [],
extract_only_images : bool =False,
enable_plotting : bool = False,
allow_enhancement : bool = False,
@@ -232,6 +188,7 @@ class Eynollah:
skip_layout_and_reading_order : bool = False,
):
self.logger = getLogger('eynollah')
self.model_zoo = EynollahModelZoo(basedir=dir_models)
self.plotter = None
if skip_layout_and_reading_order:
@@ -297,93 +254,13 @@
self.logger.warning("no GPU device available")
self.logger.info("Loading models...")
self.setup_models(dir_models, model_versions)
self.setup_models(*model_overrides)
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
@staticmethod
def our_load_model(model_file, basedir=""):
if basedir:
model_file = os.path.join(basedir, model_file)
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except:
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []):
self.model_versions = {
"enhancement": "eynollah-enhancement_20210425",
"binarization": "eynollah-binarization_20210425",
"col_classifier": "eynollah-column-classifier_20210425",
"page": "model_eynollah_page_extraction_20250915",
#?: "eynollah-main-regions-aug-scaling_20210425",
"region": ( # early layout
"eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else
"eynollah-main-regions_20220314" if self.light_version else
"eynollah-main-regions-ensembled_20210425"),
"region_p2": ( # early layout, non-light, 2nd part
"eynollah-main-regions-aug-rotation_20210425"),
"region_1_2": ( # early layout, light, 1-or-2-column
#"modelens_12sp_elay_0_3_4__3_6_n"
#"modelens_earlylayout_12spaltige_2_3_5_6_7_8"
#"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
#"modelens_1_2_4_5_early_lay_1_2_spaltige"
#"model_3_eraly_layout_no_patches_1_2_spaltige"
"modelens_e_l_all_sp_0_1_2_3_4_171024"),
"region_fl_np": ( # full layout / no patches
#"modelens_full_lay_1_3_031124"
#"modelens_full_lay_13__3_19_241024"
#"model_full_lay_13_241024"
#"modelens_full_lay_13_17_231024"
#"modelens_full_lay_1_2_221024"
#"eynollah-full-regions-1column_20210425"
"modelens_full_lay_1__4_3_091124"),
"region_fl": ( # full layout / with patches
#"eynollah-full-regions-3+column_20210425"
##"model_2_full_layout_new_trans"
#"modelens_full_lay_1_3_031124"
#"modelens_full_lay_13__3_19_241024"
#"model_full_lay_13_241024"
#"modelens_full_lay_13_17_231024"
#"modelens_full_lay_1_2_221024"
#"modelens_full_layout_24_till_28"
#"model_2_full_layout_new_trans"
"modelens_full_lay_1__4_3_091124"),
"reading_order": (
#"model_mb_ro_aug_ens_11"
#"model_step_3200000_mb_ro"
#"model_ens_reading_order_machine_based"
#"model_mb_ro_aug_ens_8"
#"model_ens_reading_order_machine_based"
"model_eynollah_reading_order_20250824"),
"textline": (
#"modelens_textline_1_4_16092024"
#"model_textline_ens_3_4_5_6_artificial"
#"modelens_textline_1_3_4_20240915"
#"model_textline_ens_3_4_5_6_artificial"
#"modelens_textline_9_12_13_14_15"
#"eynollah-textline_light_20210425"
"modelens_textline_0_1__2_4_16092024" if self.textline_light else
#"eynollah-textline_20210425"
"modelens_textline_0_1__2_4_16092024"),
"table": (
None if not self.tables else
"modelens_table_0t4_201124" if self.light_version else
"eynollah-tables_20210319"),
"ocr": (
None if not self.ocr else
"model_eynollah_ocr_trocr_20250919" if self.tr else
"model_eynollah_ocr_cnnrnn_20250930")
}
def setup_models(self, *model_overrides: Tuple[str, str, str]):
# override defaults from CLI
for key, val in model_versions:
assert key in self.model_versions, "unknown model category '%s'" % key
self.logger.warning("overriding default model %s version %s to %s", key, self.model_versions[key], val)
self.model_versions[key] = val
self.model_zoo.override_models(*model_overrides)
# load models, depending on modes
# (note: loading too many models can cause OOM on GPU/CUDA,
# thus, we try to set up the minimal configuration for the current mode)
@@ -391,10 +268,10 @@ class Eynollah:
"col_classifier",
"binarization",
"page",
"region"
("region", 'extract_only_images' if self.extract_only_images else 'light' if self.light_version else '')
]
if not self.extract_only_images:
loadable.append("textline")
loadable.append(("textline", 'light' if self.light_version else ''))
if self.light_version:
loadable.append("region_1_2")
else:
@@ -407,38 +284,24 @@
if self.reading_order_machine_based:
loadable.append("reading_order")
if self.tables:
loadable.append("table")
self.models = {name: self.our_load_model(self.model_versions[name], basedir)
for name in loadable
}
loadable.append(("table", 'light' if self.light_version else ''))
if self.ocr:
ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"])
if self.tr:
self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
loadable.append(('ocr', 'tr'))
loadable.append('ocr_tr_processor')  # default ('') variant: trocr-base-printed, matching previous behaviour
# TODO why here and why only for tr?
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
self.device = torch.device("cuda:0")
else:
self.logger.info("Using CPU processing")
self.device = torch.device("cpu")
#self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
else:
ocr_model = load_model(ocr_model_dir, compile=False)
self.models["ocr"] = tf.keras.models.Model(
ocr_model.get_layer(name = "image").input,
ocr_model.get_layer(name = "dense2").output)
with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file:
characters = json.load(config_file)
# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
# Mapping integers back to original characters.
self.num_to_char = StringLookup(
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
loadable.append('ocr')
loadable.append('num_to_char')
self.models = self.model_zoo.load_models(*loadable)
def __del__(self):
if hasattr(self, 'executor') and getattr(self, 'executor'):
@@ -4261,7 +4124,7 @@ class Eynollah:
gc.collect()
ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)),
self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True)
self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], textline_light=True)
else:
ocr_all_textlines = None
@@ -4770,27 +4633,27 @@ class Eynollah:
if len(all_found_textline_polygons):
ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons, all_box_coord,
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
if len(all_found_textline_polygons_marginals_left):
ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left,
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
if len(all_found_textline_polygons_marginals_right):
ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right,
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
if self.full_layout and len(all_found_textline_polygons):
ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_h, all_box_coord_h,
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
if self.full_layout and len(polygons_of_drop_capitals):
ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines(
image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)),
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)
else:
if self.light_version:
@@ -4839,7 +4702,7 @@ class Eynollah:
img_croped = img_poly_on_img[y:y+h, x:x+w, :]
#cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
text_ocr = self.return_ocr_of_textline_without_common_section(
img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot)
img_croped, self.models["ocr"], self.models['ocr_tr_processor'], self.device, w, h2w_ratio, ind_tot)
ocr_textline_in_textregion.append(text_ocr)
ind_tot = ind_tot +1
ocr_all_textlines.append(ocr_textline_in_textregion)

src/eynollah/image_enhancer.py
View file

@@ -22,7 +22,7 @@ from .utils import (
is_image_filename,
crop_image_inside_box
)
from .eynollah import PatchEncoder, Patches
from .patch_encoder import PatchEncoder, Patches
DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)

src/eynollah/mb_ro_on_layout.py
View file

@@ -23,7 +23,7 @@ from .utils.contour import (
return_parent_contours,
)
from .utils import is_xml_filename
from .eynollah import PatchEncoder, Patches
from .patch_encoder import PatchEncoder, Patches
DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)

src/eynollah/model_zoo.py Normal file
View file

@@ -0,0 +1,260 @@
from dataclasses import dataclass
import json
import logging
from pathlib import Path
from types import MappingProxyType
from typing import Dict, Literal, Optional, Tuple, List, Union
from copy import deepcopy
from keras.layers import StringLookup
from keras.models import Model, load_model
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from eynollah.patch_encoder import PatchEncoder, Patches
# Dict mapping model_category to dict mapping variant (default is '') to Path
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
"enhancement": {
'': "eynollah-enhancement_20210425"
},
"binarization": {
'': "eynollah-binarization_20210425"
},
"col_classifier": {
'': "eynollah-column-classifier_20210425",
},
"page": {
'': "model_eynollah_page_extraction_20250915",
},
# TODO: What is this commented out model?
#?: "eynollah-main-regions-aug-scaling_20210425",
# early layout
"region": {
'': "eynollah-main-regions-ensembled_20210425",
'extract_only_images': "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
'light': "eynollah-main-regions_20220314",
},
# early layout, non-light, 2nd part
"region_p2": {
'': "eynollah-main-regions-aug-rotation_20210425",
},
# early layout, light, 1-or-2-column
"region_1_2": {
#'': "modelens_12sp_elay_0_3_4__3_6_n"
#'': "modelens_earlylayout_12spaltige_2_3_5_6_7_8"
#'': "modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
#'': "modelens_1_2_4_5_early_lay_1_2_spaltige"
#'': "model_3_eraly_layout_no_patches_1_2_spaltige"
'': "modelens_e_l_all_sp_0_1_2_3_4_171024"
},
# full layout / no patches
"region_fl_np": {
#'': "modelens_full_lay_1_3_031124"
#'': "modelens_full_lay_13__3_19_241024"
#'': "model_full_lay_13_241024"
#'': "modelens_full_lay_13_17_231024"
#'': "modelens_full_lay_1_2_221024"
#'': "eynollah-full-regions-1column_20210425"
'': "modelens_full_lay_1__4_3_091124"
},
# full layout / with patches
"region_fl": {
#'': "eynollah-full-regions-3+column_20210425"
#'': #"model_2_full_layout_new_trans"
#'': "modelens_full_lay_1_3_031124"
#'': "modelens_full_lay_13__3_19_241024"
#'': "model_full_lay_13_241024"
#'': "modelens_full_lay_13_17_231024"
#'': "modelens_full_lay_1_2_221024"
#'': "modelens_full_layout_24_till_28"
#'': "model_2_full_layout_new_trans"
'': "modelens_full_lay_1__4_3_091124",
},
"reading_order": {
#'': "model_mb_ro_aug_ens_11"
#'': "model_step_3200000_mb_ro"
#'': "model_ens_reading_order_machine_based"
#'': "model_mb_ro_aug_ens_8"
#'': "model_ens_reading_order_machine_based"
'': "model_eynollah_reading_order_20250824"
},
"textline": {
#'light': "eynollah-textline_light_20210425"
'light': "modelens_textline_0_1__2_4_16092024",
#'': "modelens_textline_1_4_16092024"
#'': "model_textline_ens_3_4_5_6_artificial"
#'': "modelens_textline_1_3_4_20240915"
#'': "model_textline_ens_3_4_5_6_artificial"
#'': "modelens_textline_9_12_13_14_15"
#'': "eynollah-textline_20210425"
'': "modelens_textline_0_1__2_4_16092024"
},
"table": {
'light': "modelens_table_0t4_201124",
'': "eynollah-tables_20210319",
},
"ocr": {
'tr': "model_eynollah_ocr_trocr_20250919",
'': "model_eynollah_ocr_cnnrnn_20250930",
},
'ocr_tr_processor': {
'': 'microsoft/trocr-base-printed',
'htr': "microsoft/trocr-base-handwritten",
},
'num_to_char': {
'': 'model_eynollah_ocr_cnnrnn_20250930/characters_org.txt'
},
}
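How the table reads: the outer key is the model category, the inner key is the variant (with '' as the default), and the value is a filename relative to the zoo's basedir (or an absolute path / Hugging Face model id). Illustrative look-ups against the defaults above:

DEFAULT_MODEL_VERSIONS["region"][""]            # "eynollah-main-regions-ensembled_20210425"
DEFAULT_MODEL_VERSIONS["region"]["light"]       # "eynollah-main-regions_20220314"
DEFAULT_MODEL_VERSIONS["ocr"]["tr"]             # "model_eynollah_ocr_trocr_20250919"
DEFAULT_MODEL_VERSIONS["ocr_tr_processor"][""]  # "microsoft/trocr-base-printed" (Hugging Face id)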
class EynollahModelZoo():
"""
Wrapper class that handles storage and loading of models for all eynollah runners.
"""
model_basedir: Path
model_versions: dict
def __init__(
self,
basedir: str,
model_overrides: List[Tuple[str, str, str]],
) -> None:
self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo')
self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
if model_overrides:
self.override_models(*model_overrides)
def override_models(self, *model_overrides: Tuple[str, str, str]):
"""
Override the default model versions
"""
for model_category, model_variant, model_filename in model_overrides:
if model_category not in DEFAULT_MODEL_VERSIONS:
raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
self.logger.warning(
"Overriding default model %s ('%s' variant) from %s to %s",
model_category,
model_variant,
DEFAULT_MODEL_VERSIONS[model_category][model_variant],
model_filename
)
self.model_versions[model_category][model_variant] = model_filename
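A sketch of how overrides behave (paths illustrative): known (category, variant) pairs are replaced with a warning, anything else raises ValueError:

zoo.override_models(
    ("region", "light", "/data/my_region_model"),
    ("ocr", "tr", "/data/my_trocr_model"),
)
# zoo.override_models(("regions", "", "x"))     # ValueError: Unknown model_category 'regions' ...
# zoo.override_models(("region", "dark", "x"))  # ValueError: Unknown variant dark for region ...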
def model_path(
self,
model_category: str,
model_variant: str = '',
model_filename: str = '',
absolute: bool = True,
) -> Path:
"""
Translate a (model_category, model_variant, model_filename) tuple into an absolute (or relative) Path
"""
if model_category not in DEFAULT_MODEL_VERSIONS:
raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
if not model_filename:
model_filename = DEFAULT_MODEL_VERSIONS[model_category][model_variant]
if not Path(model_filename).is_absolute() and absolute:
model_path = Path(self.model_basedir).joinpath(model_filename)
else:
model_path = Path(model_filename)
return model_path
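Illustrative resolutions, assuming a zoo created with basedir="/models" and no overrides:

zoo.model_path("region", "light")
#   -> Path("/models/eynollah-main-regions_20220314")
zoo.model_path("page")
#   -> Path("/models/model_eynollah_page_extraction_20250915")
zoo.model_path("ocr_tr_processor", absolute=False)
#   -> Path("microsoft/trocr-base-printed")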
def load_models(
self,
*all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
) -> Dict:
"""
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
"""
ret = {}
for load_args in all_load_args:
if isinstance(load_args, str):
ret[load_args] = self.load_model(load_args)
else:
ret[load_args[0]] = self.load_model(*load_args)
return ret
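This accepts bare category strings or (category, variant[, filename]) tuples, mirroring the `loadable` list that Eynollah.setup_models() builds above; a sketch:

models = zoo.load_models(
    "col_classifier",        # bare string: default ('') variant
    ("region", "light"),     # (category, variant)
    ("ocr", "tr"),           # TrOCR variant, dispatched to _load_ocr_model below
)
models["region"]  # the dict is keyed by category, regardless of variant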
def load_model(
self,
model_category: str,
model_variant: str = '',
model_filename: str = '',
) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
"""
Load any model
"""
model_path = self.model_path(model_category, model_variant, model_filename)
if model_path.suffix == '.h5' and model_path.with_suffix('').exists():
# prefer the SavedModel directory over the HDF5 file if it exists
model_path = model_path.with_suffix('')
if model_category == 'ocr':
model = self._load_ocr_model(variant=model_variant)
elif model_category == 'num_to_char':
model = self._load_num_to_char()
elif model_category == 'ocr_tr_processor':
return TrOCRProcessor.from_pretrained(self.model_path(...))
else:
try:
model = load_model(model_path, compile=False)
except Exception as e:
self.logger.exception(e)
model = load_model(model_path, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model # type: ignore
def _load_ocr_model(self, variant: str) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
"""
Load OCR model
"""
ocr_model_dir = Path(self.model_basedir, self.model_versions["ocr"][variant])
if variant == 'tr':
return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
else:
ocr_model = load_model(ocr_model_dir, compile=False)
assert isinstance(ocr_model, Model)
return Model(
ocr_model.get_layer(name = "image").input, # type: ignore
ocr_model.get_layer(name = "dense2").output) # type: ignore
def _load_num_to_char(self):
"""
Load decoder for OCR
"""
# the 'num_to_char' default path already points inside the cnn-rnn OCR model directory
with open(self.model_path('num_to_char'), "r") as config_file:
characters = json.load(config_file)
# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
# Mapping integers back to original characters.
return StringLookup(
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
def __str__(self):
return str(json.dumps({
'basedir': str(self.model_basedir),
'versions': self.model_versions,
}, indent=2))
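Condensed from the eynollah.py hunks above, this is how the Eynollah class itself now drives the zoo:

self.model_zoo = EynollahModelZoo(basedir=dir_models)   # __init__
self.model_zoo.override_models(*model_overrides)        # setup_models(): apply CLI overrides
self.models = self.model_zoo.load_models(*loadable)     # loadable mixes plain categories and (category, variant) tuples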

src/eynollah/patch_encoder.py Normal file
View file

@@ -0,0 +1,52 @@
from keras import layers
import tensorflow as tf
projection_dim = 64
patch_size = 1
num_patches = 21*21  # 14*14 # 28*28 # 14*14 # 28*28
class PatchEncoder(layers.Layer):
def __init__(self, **kwargs):
# accept (and drop) config kwargs so load_model(custom_objects=...) can re-instantiate the layer
super().__init__()
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
def call(self, patch):
positions = tf.range(start=0, limit=num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
class Patches(layers.Layer):
def __init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
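A quick shape check for the extracted layers (illustrative; with the module constants patch_size=1 and num_patches=21*21=441, the spatial input must be 21x21):

import tensorflow as tf
dummy = tf.zeros((1, 21, 21, 3))   # (batch, height, width, channels)
patches = Patches()(dummy)         # -> (1, 441, 3): one 1x1 patch per pixel position
encoded = PatchEncoder()(patches)  # -> (1, 441, 64): Dense projection + learned position embedding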