Mirror of https://github.com/qurator-spk/eynollah.git, synced 2025-10-26 23:34:13 +01:00
factor model loading in Eynollah to EynollahModelZoo
This commit is contained in: parent 38c028c6b5, commit a850ef39ea
6 changed files with 389 additions and 171 deletions
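In short, model discovery and loading move out of the Eynollah class into a new EynollahModelZoo. A minimal sketch of the new interface, based only on the code added in this diff; the model directory and override path are placeholders:

from eynollah.model_zoo import EynollahModelZoo

# Categories resolve to (variant -> filename) entries under a base directory;
# overrides use the same (category, variant, path) triples as the new CLI option.
zoo = EynollahModelZoo(
    basedir="/path/to/models",                                        # placeholder
    model_overrides=[("region", "light", "/path/to/region_model")],   # placeholder
)
models = zoo.load_models(
    "col_classifier",       # bare string: default variant ''
    ("region", "light"),    # (category, variant) tuple
)
print(zoo)                  # JSON summary of basedir and the effective model versions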
@@ -1,16 +1,57 @@
from dataclasses import dataclass
import sys
import os
import click
import logging
from typing import Tuple, List
from ocrd_utils import initLogging, getLevelName, getLogger
from eynollah.eynollah import Eynollah, Eynollah_ocr
from eynollah.sbb_binarize import SbbBinarizer
from eynollah.image_enhancer import Enhancer
from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout
from eynollah.model_zoo import EynollahModelZoo

@dataclass
class EynollahCliCtx():
    model_basedir: str
    model_overrides: List[Tuple[str, str, str]]

@click.group()
def main():
    pass

@main.command('list-models')
@click.option(
    "--model",
    "-m",
    'model_basedir',
    help="directory of models",
    type=click.Path(exists=True, file_okay=False),
    # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
    required=True,
)
@click.option(
    "--model-overrides",
    "-mv",
    help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g. 'region light /path/to/model'. See eynollah list-models for the full list",
    type=(str, str, str),
    multiple=True,
)
@click.pass_context
def list_models(
    ctx,
    model_basedir: str,
    model_overrides: List[Tuple[str, str, str]],
):
    """
    List all the models in the zoo
    """
    ctx.obj = EynollahCliCtx(
        model_basedir=model_basedir,
        model_overrides=model_overrides
    )
    print(EynollahModelZoo(basedir=ctx.obj.model_basedir, model_overrides=ctx.obj.model_overrides))

@main.command()
@click.option(
    "--input",

@@ -198,15 +239,17 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
@click.option(
    "--model",
    "-m",
    'model_basedir',
    help="directory of models",
    type=click.Path(exists=True, file_okay=False),
    # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
    required=True,
)
@click.option(
    "--model_version",
    "-mv",
    help="override default versions of model categories",
    type=(str, str),
    help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g. 'region light /path/to/model'. See eynollah list-models for the full list",
    type=(str, str, str),
    multiple=True,
)
@click.option(

@@ -411,7 +454,7 @@ def layout(image, out, overwrite, dir_in, model, model_version, save_images, sav
    assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
    eynollah = Eynollah(
        model,
        model_versions=model_version,
        model_overrides=model_version,
        extract_only_images=extract_only_images,
        enable_plotting=enable_plotting,
        allow_enhancement=allow_enhancement,
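The reworked -mv/--model-overrides option takes three values per occurrence (CATEGORY VARIANT PATH) and may be repeated. A hedged sketch of exercising the new CLI surface with click's test runner, assuming the click group above is importable as eynollah.cli.main; the paths are placeholders:

from click.testing import CliRunner
from eynollah.cli import main

runner = CliRunner()
result = runner.invoke(main, [
    "list-models",
    "-m", "/path/to/models",                                  # placeholder model directory
    "-mv", "region", "light", "/path/to/custom_region_model", # CATEGORY VARIANT PATH
])
print(result.output)  # JSON dump of the zoo's basedir and resolved model versions
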
@@ -2,12 +2,15 @@
# pylint: disable=too-many-locals,wrong-import-position,too-many-lines,too-many-statements,chained-comparison,fixme,broad-except,c-extension-no-member
# pylint: disable=too-many-public-methods,too-many-arguments,too-many-instance-attributes,too-many-public-methods,
# pylint: disable=consider-using-enumerate
# pyright: reportUnnecessaryTypeIgnoreComment=true
# pyright: reportPossiblyUnboundVariable=false
"""
document layout analysis (segmentation) with output in PAGE-XML
"""

# cannot use importlib.resources until we move to 3.9+ for importlib.resources.files
import sys

if sys.version_info < (3, 10):
    import importlib_resources
else:

@@ -19,7 +22,7 @@ import math
import os
import sys
import time
from typing import Dict, List, Optional, Tuple
from typing import Dict, Union, List, Optional, Tuple
import atexit
import warnings
from functools import partial

@@ -58,8 +61,7 @@ except ImportError:
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf_disable_interactive_logs()
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.keras.models import load_model
from keras.models import load_model
tf.get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore")
# use tf1 compatibility for keras backend

@@ -67,6 +69,7 @@ from tensorflow.compat.v1.keras.backend import set_session
from tensorflow.keras import layers
from tensorflow.keras.layers import StringLookup

from .model_zoo import EynollahModelZoo
from .utils.contour import (
    filter_contours_area_of_image,
    filter_contours_area_of_image_tables,

@@ -155,59 +158,12 @@ patch_size = 1
num_patches =21*21#14*14#28*28#14*14#28*28


class Patches(layers.Layer):
    def __init__(self, **kwargs):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
    def get_config(self):

        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config

class PatchEncoder(layers.Layer):
    def __init__(self, **kwargs):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
    def get_config(self):

        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection': self.projection,
            'position_embedding': self.position_embedding,
        })
        return config

class Eynollah:
    def __init__(
        self,
        dir_models : str,
        model_versions: List[Tuple[str, str]] = [],
        model_overrides: List[Tuple[str, str, str]] = [],
        extract_only_images : bool =False,
        enable_plotting : bool = False,
        allow_enhancement : bool = False,

@@ -232,6 +188,7 @@ class Eynollah:
        skip_layout_and_reading_order : bool = False,
    ):
        self.logger = getLogger('eynollah')
        self.model_zoo = EynollahModelZoo(basedir=dir_models)
        self.plotter = None

        if skip_layout_and_reading_order:

@@ -297,93 +254,13 @@ class Eynollah:
            self.logger.warning("no GPU device available")

        self.logger.info("Loading models...")
        self.setup_models(dir_models, model_versions)
        self.setup_models(*model_overrides)
        self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")

    @staticmethod
    def our_load_model(model_file, basedir=""):
        if basedir:
            model_file = os.path.join(basedir, model_file)
        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
            # prefer SavedModel over HDF5 format if it exists
            model_file = model_file[:-3]
        try:
            model = load_model(model_file, compile=False)
        except:
            model = load_model(model_file, compile=False, custom_objects={
                "PatchEncoder": PatchEncoder, "Patches": Patches})
        return model

    def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []):
        self.model_versions = {
            "enhancement": "eynollah-enhancement_20210425",
            "binarization": "eynollah-binarization_20210425",
            "col_classifier": "eynollah-column-classifier_20210425",
            "page": "model_eynollah_page_extraction_20250915",
            #?: "eynollah-main-regions-aug-scaling_20210425",
            "region": ( # early layout
                "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else
                "eynollah-main-regions_20220314" if self.light_version else
                "eynollah-main-regions-ensembled_20210425"),
            "region_p2": ( # early layout, non-light, 2nd part
                "eynollah-main-regions-aug-rotation_20210425"),
            "region_1_2": ( # early layout, light, 1-or-2-column
                #"modelens_12sp_elay_0_3_4__3_6_n"
                #"modelens_earlylayout_12spaltige_2_3_5_6_7_8"
                #"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
                #"modelens_1_2_4_5_early_lay_1_2_spaltige"
                #"model_3_eraly_layout_no_patches_1_2_spaltige"
                "modelens_e_l_all_sp_0_1_2_3_4_171024"),
            "region_fl_np": ( # full layout / no patches
                #"modelens_full_lay_1_3_031124"
                #"modelens_full_lay_13__3_19_241024"
                #"model_full_lay_13_241024"
                #"modelens_full_lay_13_17_231024"
                #"modelens_full_lay_1_2_221024"
                #"eynollah-full-regions-1column_20210425"
                "modelens_full_lay_1__4_3_091124"),
            "region_fl": ( # full layout / with patches
                #"eynollah-full-regions-3+column_20210425"
                ##"model_2_full_layout_new_trans"
                #"modelens_full_lay_1_3_031124"
                #"modelens_full_lay_13__3_19_241024"
                #"model_full_lay_13_241024"
                #"modelens_full_lay_13_17_231024"
                #"modelens_full_lay_1_2_221024"
                #"modelens_full_layout_24_till_28"
                #"model_2_full_layout_new_trans"
                "modelens_full_lay_1__4_3_091124"),
            "reading_order": (
                #"model_mb_ro_aug_ens_11"
                #"model_step_3200000_mb_ro"
                #"model_ens_reading_order_machine_based"
                #"model_mb_ro_aug_ens_8"
                #"model_ens_reading_order_machine_based"
                "model_eynollah_reading_order_20250824"),
            "textline": (
                #"modelens_textline_1_4_16092024"
                #"model_textline_ens_3_4_5_6_artificial"
                #"modelens_textline_1_3_4_20240915"
                #"model_textline_ens_3_4_5_6_artificial"
                #"modelens_textline_9_12_13_14_15"
                #"eynollah-textline_light_20210425"
                "modelens_textline_0_1__2_4_16092024" if self.textline_light else
                #"eynollah-textline_20210425"
                "modelens_textline_0_1__2_4_16092024"),
            "table": (
                None if not self.tables else
                "modelens_table_0t4_201124" if self.light_version else
                "eynollah-tables_20210319"),
            "ocr": (
                None if not self.ocr else
                "model_eynollah_ocr_trocr_20250919" if self.tr else
                "model_eynollah_ocr_cnnrnn_20250930")
        }
    def setup_models(self, *model_overrides: Tuple[str, str, str]):
        # override defaults from CLI
        for key, val in model_versions:
            assert key in self.model_versions, "unknown model category '%s'" % key
            self.logger.warning("overriding default model %s version %s to %s", key, self.model_versions[key], val)
            self.model_versions[key] = val
        self.model_zoo.override_models(*model_overrides)

        # load models, depending on modes
        # (note: loading too many models can cause OOM on GPU/CUDA,
        # thus, we try to set up the minimal configuration for the current mode)

@@ -391,10 +268,10 @@ class Eynollah:
            "col_classifier",
            "binarization",
            "page",
            "region"
            ("region", 'extract_only_images' if self.extract_only_images else 'light' if self.light_version else '')
        ]
        if not self.extract_only_images:
            loadable.append("textline")
            loadable.append(("textline", 'light' if self.light_version else ''))
            if self.light_version:
                loadable.append("region_1_2")
            else:

@@ -407,38 +284,24 @@ class Eynollah:
        if self.reading_order_machine_based:
            loadable.append("reading_order")
        if self.tables:
            loadable.append("table")

        self.models = {name: self.our_load_model(self.model_versions[name], basedir)
                       for name in loadable
        }
            loadable.append(("table", 'light' if self.light_version else ''))

        if self.ocr:
            ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"])
            if self.tr:
                self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
            loadable.append(('ocr', 'tr'))
            loadable.append(('ocr_tr_processor', 'tr'))
            # TODO why here and why only for tr?
            if torch.cuda.is_available():
                self.logger.info("Using GPU acceleration")
                self.device = torch.device("cuda:0")
            else:
                self.logger.info("Using CPU processing")
                self.device = torch.device("cpu")
            #self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
            self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
            else:
                ocr_model = load_model(ocr_model_dir, compile=False)
                self.models["ocr"] = tf.keras.models.Model(
                    ocr_model.get_layer(name = "image").input,
                    ocr_model.get_layer(name = "dense2").output)

                with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file:
                    characters = json.load(config_file)
                # Mapping characters to integers.
                char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
                # Mapping integers back to original characters.
                self.num_to_char = StringLookup(
                    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
                )
            loadable.append('ocr')
            loadable.append('num_to_char')

        self.models = self.model_zoo.load_models(*loadable)

    def __del__(self):
        if hasattr(self, 'executor') and getattr(self, 'executor'):
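After the refactor, setup_models only decides which (category, variant) combinations are needed for the current mode and defers path resolution, deserialization, and custom objects to the zoo. A hedged sketch of the shape of that list, with flags hard-coded for illustration:

# Illustrative only: mirrors how the hunk above builds `loadable` for a light,
# TrOCR-enabled run; the real code derives the variants from instance flags.
light_version = True
loadable = [
    "col_classifier",
    "binarization",
    "page",
    ("region", "light" if light_version else ""),
    ("textline", "light" if light_version else ""),
    ("table", "light" if light_version else ""),
    ("ocr", "tr"),
    ("ocr_tr_processor", "tr"),
]
# self.models = self.model_zoo.load_models(*loadable)
# -> dict mapping each category name to its loaded model
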
@@ -4261,7 +4124,7 @@ class Eynollah:
                gc.collect()
                ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                    image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)),
                    self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True)
                    self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], textline_light=True)
            else:
                ocr_all_textlines = None

@@ -4770,27 +4633,27 @@ class Eynollah:
            if len(all_found_textline_polygons):
                ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                    image_page, all_found_textline_polygons, all_box_coord,
                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
                    self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)

            if len(all_found_textline_polygons_marginals_left):
                ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines(
                    image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left,
                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
                    self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)

            if len(all_found_textline_polygons_marginals_right):
                ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines(
                    image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right,
                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
                    self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)

            if self.full_layout and len(all_found_textline_polygons):
                ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines(
                    image_page, all_found_textline_polygons_h, all_box_coord_h,
                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
                    self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)

            if self.full_layout and len(polygons_of_drop_capitals):
                ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines(
                    image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)),
                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
                    self.models["ocr"], self.b_s_ocr, self.models["num_to_char"], self.textline_light, self.curved_line)

        else:
            if self.light_version:

@@ -4839,7 +4702,7 @@ class Eynollah:
                    img_croped = img_poly_on_img[y:y+h, x:x+w, :]
                    #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
                    text_ocr = self.return_ocr_of_textline_without_common_section(
                        img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot)
                        img_croped, self.models["ocr"], self.models['ocr_tr_processor'], self.device, w, h2w_ratio, ind_tot)
                    ocr_textline_in_textregion.append(text_ocr)
                    ind_tot = ind_tot +1
                ocr_all_textlines.append(ocr_textline_in_textregion)
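These hunks replace the ad-hoc self.num_to_char and self.processor attributes with zoo-managed entries (models["num_to_char"], models['ocr_tr_processor']). For orientation, a hedged sketch of what the num_to_char lookup does with decoded CNN/RNN OCR label indices; the tiny vocabulary and index values are made up:

import tensorflow as tf
from tensorflow.keras.layers import StringLookup

# Illustrative only: a 3-character vocabulary standing in for characters_org.txt.
char_to_num = StringLookup(vocabulary=list("abc"), mask_token=None)
num_to_char = StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)

indices = tf.constant([[1, 2, 3]])                  # made-up decoded label indices
text = tf.strings.reduce_join(num_to_char(indices), axis=-1)
print(text.numpy())                                 # e.g. [b'abc']
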
@@ -22,7 +22,7 @@ from .utils import (
    is_image_filename,
    crop_image_inside_box
)
from .eynollah import PatchEncoder, Patches
from .patch_encoder import PatchEncoder, Patches

DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)

@@ -23,7 +23,7 @@ from .utils.contour import (
    return_parent_contours,
)
from .utils import is_xml_filename
from .eynollah import PatchEncoder, Patches
from .patch_encoder import PatchEncoder, Patches

DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)
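Both helper modules now take Patches and PatchEncoder from the new eynollah.patch_encoder module instead of importing them from eynollah.eynollah. A hedged sketch of why the classes are needed at load time, mirroring the fallback used in the new model zoo; the model path is a placeholder:

from keras.models import load_model
from eynollah.patch_encoder import PatchEncoder, Patches

# Models whose graphs contain these custom layers need them registered as
# custom_objects when deserializing; '/path/to/model' is a placeholder.
model = load_model("/path/to/model", compile=False,
                   custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches})
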
src/eynollah/model_zoo.py (new file, 260 lines)
@@ -0,0 +1,260 @@
from dataclasses import dataclass
import json
import logging
from pathlib import Path
from types import MappingProxyType
from typing import Dict, Literal, Optional, Tuple, List, Union
from copy import deepcopy

from keras.layers import StringLookup
from keras.models import Model, load_model
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from eynollah.patch_encoder import PatchEncoder, Patches


# Dict mapping model_category to dict mapping variant (default is '') to Path
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {

    "enhancement": {
        '': "eynollah-enhancement_20210425"
    },

    "binarization": {
        '': "eynollah-binarization_20210425"
    },

    "col_classifier": {
        '': "eynollah-column-classifier_20210425",
    },

    "page": {
        '': "model_eynollah_page_extraction_20250915",
    },

    # TODO: What is this commented out model?
    #?: "eynollah-main-regions-aug-scaling_20210425",

    # early layout
    "region": {
        '': "eynollah-main-regions-ensembled_20210425",
        'extract_only_images': "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
        'light': "eynollah-main-regions_20220314",
    },

    # early layout, non-light, 2nd part
    "region_p2": {
        '': "eynollah-main-regions-aug-rotation_20210425",
    },

    # early layout, light, 1-or-2-column
    "region_1_2": {
        #'': "modelens_12sp_elay_0_3_4__3_6_n"
        #'': "modelens_earlylayout_12spaltige_2_3_5_6_7_8"
        #'': "modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
        #'': "modelens_1_2_4_5_early_lay_1_2_spaltige"
        #'': "model_3_eraly_layout_no_patches_1_2_spaltige"
        '': "modelens_e_l_all_sp_0_1_2_3_4_171024"
    },

    # full layout / no patches
    "region_fl_np": {
        #'': "modelens_full_lay_1_3_031124"
        #'': "modelens_full_lay_13__3_19_241024"
        #'': "model_full_lay_13_241024"
        #'': "modelens_full_lay_13_17_231024"
        #'': "modelens_full_lay_1_2_221024"
        #'': "eynollah-full-regions-1column_20210425"
        '': "modelens_full_lay_1__4_3_091124"
    },

    # full layout / with patches
    "region_fl": {
        #'': "eynollah-full-regions-3+column_20210425"
        #'': #"model_2_full_layout_new_trans"
        #'': "modelens_full_lay_1_3_031124"
        #'': "modelens_full_lay_13__3_19_241024"
        #'': "model_full_lay_13_241024"
        #'': "modelens_full_lay_13_17_231024"
        #'': "modelens_full_lay_1_2_221024"
        #'': "modelens_full_layout_24_till_28"
        #'': "model_2_full_layout_new_trans"
        '': "modelens_full_lay_1__4_3_091124",
    },

    "reading_order": {
        #'': "model_mb_ro_aug_ens_11"
        #'': "model_step_3200000_mb_ro"
        #'': "model_ens_reading_order_machine_based"
        #'': "model_mb_ro_aug_ens_8"
        #'': "model_ens_reading_order_machine_based"
        '': "model_eynollah_reading_order_20250824"
    },

    "textline": {
        #'light': "eynollah-textline_light_20210425"
        'light': "modelens_textline_0_1__2_4_16092024",
        #'': "modelens_textline_1_4_16092024"
        #'': "model_textline_ens_3_4_5_6_artificial"
        #'': "modelens_textline_1_3_4_20240915"
        #'': "model_textline_ens_3_4_5_6_artificial"
        #'': "modelens_textline_9_12_13_14_15"
        #'': "eynollah-textline_20210425"
        '': "modelens_textline_0_1__2_4_16092024"
    },

    "table": {
        'light': "modelens_table_0t4_201124",
        '': "eynollah-tables_20210319",
    },

    "ocr": {
        'tr': "model_eynollah_ocr_trocr_20250919",
        '': "model_eynollah_ocr_cnnrnn_20250930",
    },

    'ocr_tr_processor': {
        '': 'microsoft/trocr-base-printed',
        'htr': "microsoft/trocr-base-handwritten",
    },

    'num_to_char': {
        '': 'model_eynollah_ocr_cnnrnn_20250930/characters_org.txt'
    },
}


class EynollahModelZoo():
    """
    Wrapper class that handles storage and loading of models for all eynollah runners.
    """
    model_basedir: Path
    model_versions: dict

    def __init__(
        self,
        basedir: str,
        model_overrides: List[Tuple[str, str, str]],
    ) -> None:
        self.model_basedir = Path(basedir)
        self.logger = logging.getLogger('eynollah.model_zoo')
        self.model_versions = deepcopy(DEFAULT_MODEL_VERSIONS)
        if model_overrides:
            self.override_models(*model_overrides)

    def override_models(self, *model_overrides: Tuple[str, str, str]):
        """
        Override the default model versions
        """
        for model_category, model_variant, model_filename in model_overrides:
            if model_category not in DEFAULT_MODEL_VERSIONS:
                raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
            if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
                raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
            self.logger.warning(
                "Overriding default model %s ('%s' variant) from %s to %s",
                model_category,
                model_variant,
                DEFAULT_MODEL_VERSIONS[model_category][model_variant],
                model_filename
            )
            self.model_versions[model_category][model_variant] = model_filename

    def model_path(
        self,
        model_category: str,
        model_variant: str = '',
        model_filename: str = '',
        absolute: bool = True,
    ) -> Path:
        """
        Translate model_{type,variant,filename} tuple into an absolute (or relative) Path
        """
        if model_category not in DEFAULT_MODEL_VERSIONS:
            raise ValueError(f"Unknown model_category '{model_category}', must be one of {DEFAULT_MODEL_VERSIONS.keys()}")
        if model_variant not in DEFAULT_MODEL_VERSIONS[model_category]:
            raise ValueError(f"Unknown variant {model_variant} for {model_category}. Known variants: {DEFAULT_MODEL_VERSIONS[model_category].keys()}")
        if not model_filename:
            model_filename = DEFAULT_MODEL_VERSIONS[model_category][model_variant]
        if not Path(model_filename).is_absolute() and absolute:
            model_path = Path(self.model_basedir).joinpath(model_filename)
        else:
            model_path = Path(model_filename)
        return model_path

    def load_models(
        self,
        *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
    ) -> Dict:
        """
        Load all models by calling load_model and return a dictionary mapping model_category to loaded model
        """
        ret = {}
        for load_args in all_load_args:
            if isinstance(load_args, str):
                ret[load_args] = self.load_model(load_args)
            else:
                ret[load_args[0]] = self.load_model(*load_args)
        return ret

    def load_model(
        self,
        model_category: str,
        model_variant: str = '',
        model_filename: str = '',
    ) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
        """
        Load any model
        """
        model_path = self.model_path(model_category, model_variant, model_filename)
        if model_path.suffix == '.h5' and Path(model_path.stem).exists():
            # prefer SavedModel over HDF5 format if it exists
            model_path = Path(model_path.stem)
        if model_category == 'ocr':
            model = self._load_ocr_model(variant=model_variant)
        elif model_category == 'num_to_char':
            model = self._load_num_to_char()
        elif model_category == 'tr_processor':
            return TrOCRProcessor.from_pretrained(self.model_path(...))
        else:
            try:
                model = load_model(model_path, compile=False)
            except Exception as e:
                self.logger.exception(e)
                model = load_model(model_path, compile=False, custom_objects={
                    "PatchEncoder": PatchEncoder, "Patches": Patches})
        return model  # type: ignore

    def _load_ocr_model(self, variant: str) -> Union[VisionEncoderDecoderModel, TrOCRProcessor, Model]:
        """
        Load OCR model
        """
        ocr_model_dir = Path(self.model_basedir, self.model_versions["ocr"][variant])
        if variant == 'tr':
            return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
        else:
            ocr_model = load_model(ocr_model_dir, compile=False)
            assert isinstance(ocr_model, Model)
            return Model(
                ocr_model.get_layer(name = "image").input,  # type: ignore
                ocr_model.get_layer(name = "dense2").output)  # type: ignore

    def _load_num_to_char(self):
        """
        Load decoder for OCR
        """
        with open(self.model_path('ocr') / self.model_path('ocr', 'num_to_char', absolute=False), "r") as config_file:
            characters = json.load(config_file)
        # Mapping characters to integers.
        char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
        # Mapping integers back to original characters.
        return StringLookup(
            vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
        )

    def __str__(self):
        return str(json.dumps({
            'basedir': str(self.model_basedir),
            'versions': self.model_versions,
        }, indent=2))
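A hedged usage sketch of the new class, based only on the methods above; the directory and override paths are placeholders:

from eynollah.model_zoo import EynollahModelZoo

zoo = EynollahModelZoo(basedir="/path/to/models", model_overrides=[])   # placeholder path
print(zoo.model_path("region", "light"))       # -> /path/to/models/eynollah-main-regions_20220314
zoo.override_models(("ocr", "tr", "/path/to/my_trocr_model"))           # placeholder path
try:
    zoo.model_path("no_such_category")
except ValueError as err:
    print(err)                                  # unknown categories fail fast
print(zoo)                                      # JSON dump of basedir and effective versions
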
src/eynollah/patch_encoder.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from keras import layers
import tensorflow as tf

projection_dim = 64
patch_size = 1
num_patches =21*21#14*14#28*28#14*14#28*28

class PatchEncoder(layers.Layer):

    def __init__(self):
        super().__init__()
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)

    def call(self, patch):
        positions = tf.range(start=0, limit=num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_patches': num_patches,
            'projection': self.projection,
            'position_embedding': self.position_embedding,
        })
        return config

class Patches(layers.Layer):
    def __init__(self, **kwargs):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
    def get_config(self):

        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config
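For orientation, a hedged sketch of how the two layers compose, using the module constants above (patch_size=1, num_patches=21*21, projection_dim=64); the input tensor is made up:

import tensorflow as tf
from eynollah.patch_encoder import PatchEncoder, Patches

images = tf.random.uniform((1, 21, 21, 3))   # 21x21 input -> 21*21 patches of size 1
patches = Patches()(images)                  # shape (1, 441, 3)
encoded = PatchEncoder()(patches)            # shape (1, 441, 64): dense projection + position embedding
print(patches.shape, encoded.shape)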