diff --git a/Makefile b/Makefile index 29dd877..b1cbcc4 100644 --- a/Makefile +++ b/Makefile @@ -6,21 +6,23 @@ EXTRAS ?= DOCKER_BASE_IMAGE ?= docker.io/ocrd/core-cuda-tf2:latest DOCKER_TAG ?= ocrd/eynollah DOCKER ?= docker +WGET = wget -O #SEG_MODEL := https://qurator-data.de/eynollah/2021-04-25/models_eynollah.tar.gz #SEG_MODEL := https://qurator-data.de/eynollah/2022-04-05/models_eynollah_renamed.tar.gz # SEG_MODEL := https://qurator-data.de/eynollah/2022-04-05/models_eynollah.tar.gz #SEG_MODEL := https://github.com/qurator-spk/eynollah/releases/download/v0.3.0/models_eynollah.tar.gz #SEG_MODEL := https://github.com/qurator-spk/eynollah/releases/download/v0.3.1/models_eynollah.tar.gz -SEG_MODEL := https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1 +#SEG_MODEL := https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1 +SEG_MODEL := https://zenodo.org/records/17295988/files/models_layout_v0_6_0.tar.gz?download=1 SEG_MODELFILE = $(notdir $(patsubst %?download=1,%,$(SEG_MODEL))) SEG_MODELNAME = $(SEG_MODELFILE:%.tar.gz=%) -BIN_MODEL := https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2021_03_09.zip +BIN_MODEL := https://zenodo.org/records/17295988/files/models_binarization_v0_6_0.tar.gz?download=1 BIN_MODELFILE = $(notdir $(BIN_MODEL)) BIN_MODELNAME := default-2021-03-09 -OCR_MODEL := https://zenodo.org/records/17236998/files/models_ocr_v0_5_1.tar.gz?download=1 +OCR_MODEL := https://zenodo.org/records/17295988/files/models_ocr_v0_6_0.tar.gz?download=1 OCR_MODELFILE = $(notdir $(patsubst %?download=1,%,$(OCR_MODEL))) OCR_MODELNAME = $(OCR_MODELFILE:%.tar.gz=%) @@ -55,22 +57,21 @@ help: # END-EVAL -# Download and extract models to $(PWD)/models_layout_v0_5_0 +# Download and extract models to $(PWD)/models_layout_v0_6_0 models: $(BIN_MODELNAME) $(SEG_MODELNAME) $(OCR_MODELNAME) # do not download these files if we already have the directories .INTERMEDIATE: $(BIN_MODELFILE) $(SEG_MODELFILE) $(OCR_MODELFILE) $(BIN_MODELFILE): - wget -O $@ $(BIN_MODEL) + $(WGET) $@ $(BIN_MODEL) $(SEG_MODELFILE): - wget -O $@ $(SEG_MODEL) + $(WGET) $@ $(SEG_MODEL) $(OCR_MODELFILE): - wget -O $@ $(OCR_MODEL) + $(WGET) $@ $(OCR_MODEL) $(BIN_MODELNAME): $(BIN_MODELFILE) - mkdir $@ - unzip -d $@ $< + tar zxf $< $(SEG_MODELNAME): $(SEG_MODELFILE) tar zxf $< $(OCR_MODELNAME): $(OCR_MODELFILE) diff --git a/README.md b/README.md index 3ba5086..3ecb3d7 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ make install EXTRAS=OCR ## Models -Pretrained models can be downloaded from [zenodo](https://zenodo.org/records/17194824) or [huggingface](https://huggingface.co/SBB?search_models=eynollah). +Pretrained models can be downloaded from [zenodo](https://doi.org/10.5281/zenodo.17194823) or [huggingface](https://huggingface.co/SBB?search_models=eynollah). For documentation on models, have a look at [`models.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/models.md). Model cards are also provided for our trained models. 
@@ -162,7 +162,7 @@ formally described in [`ocrd-tool.json`](https://github.com/qurator-spk/eynollah In this case, the source image file group with (preferably) RGB images should be used as input like this: - ocrd-eynollah-segment -I OCR-D-IMG -O OCR-D-SEG -P models eynollah_layout_v0_5_0 + ocrd-eynollah-segment -I OCR-D-IMG -O OCR-D-SEG -P models eynollah_layout_v0_6_0 If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynollah behaves as follows: - existing regions are kept and ignored (i.e. in effect they might overlap segments from Eynollah results) @@ -174,7 +174,7 @@ If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynol (because some other preprocessing step was in effect like `denoised`), then the output PAGE-XML will be based on that as new top-level (`@imageFilename`) - ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_5_0 + ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_6_0 In general, it makes more sense to add other workflow steps **after** Eynollah. diff --git a/requirements.txt b/requirements.txt index db1d7df..bbacd48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ tensorflow < 2.13 numba <= 0.58.1 scikit-image biopython +tabulate diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py index c9bad52..595f0ee 100644 --- a/src/eynollah/cli.py +++ b/src/eynollah/cli.py @@ -1,16 +1,24 @@ +from dataclasses import dataclass import sys import click import logging +from typing import Tuple, List from ocrd_utils import initLogging, getLevelName, getLogger -from eynollah.eynollah import Eynollah, Eynollah_ocr +from eynollah.eynollah import Eynollah +from eynollah.eynollah_ocr import Eynollah_ocr from eynollah.sbb_binarize import SbbBinarizer from eynollah.image_enhancer import Enhancer from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout +from eynollah.model_zoo import EynollahModelZoo + +from .cli_models import models_cli @click.group() def main(): pass +main.add_command(models_cli, 'models') + @main.command() @click.option( "--input", @@ -79,18 +87,38 @@ def machine_based_reading_order(input, dir_in, out, model, log_level): type=click.Path(file_okay=True, dir_okay=True), required=True, ) +@click.option( + '-M', + '--mode', + type=click.Choice(['single', 'multi']), + default='single', + help="Whether to use the (faster) single-model binarization or the (slightly better) multi-model binarization" +) @click.option( "--log_level", "-l", type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), help="Override log level globally to this", ) -def binarization(patches, model_dir, input_image, dir_in, output, log_level): +def binarization( + patches, + model_dir, + input_image, + mode, + dir_in, + output, + log_level, +): assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." 
- binarizer = SbbBinarizer(model_dir) + binarizer = SbbBinarizer(model_dir, mode=mode) if log_level: binarizer.log.setLevel(getLevelName(log_level)) - binarizer.run(image_path=input_image, use_patches=patches, output=output, dir_in=dir_in) + binarizer.run( + image_path=input_image, + use_patches=patches, + output=output, + dir_in=dir_in + ) @main.command() @@ -198,15 +226,17 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low @click.option( "--model", "-m", + 'model_basedir', help="directory of models", type=click.Path(exists=True, file_okay=False), + # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment", required=True, ) @click.option( "--model_version", "-mv", - help="override default versions of model categories", - type=(str, str), + help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g. 'region light /path/to/model'. See 'eynollah models list' for the full list", + type=(str, str, str), multiple=True, ) @click.option( @@ -380,7 +410,43 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low help="Setup a basic console logger", ) -def layout(image, out, overwrite, dir_in, model, model_version, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging): +def layout( + image, + out, + overwrite, + dir_in, + model_basedir, + model_version, + save_images, + save_layout, + save_deskewed, + save_all, + extract_only_images, + save_page, + enable_plotting, + allow_enhancement, + curved_line, + textline_light, + full_layout, + tables, + right2left, + input_binary, + allow_scaling, + headers_off, + light_version, + reading_order_machine_based, + do_ocr, + transformer_ocr, + batch_size_ocr, + num_col_upper, + num_col_lower, + threshold_art_class_textline, + threshold_art_class_layout, + skip_layout_and_reading_order, + ignore_page_extraction, + log_level, + setup_logging, +): if setup_logging: console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(logging.INFO) @@ -410,8 +476,8 @@ def layout(image, out, overwrite, dir_in, model, model_version, save_images, sav assert not extract_only_images or not headers_off, "Image extraction -eoi can not be set alongside headers_off -ho" assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
eynollah = Eynollah( - model, - model_versions=model_version, + model_basedir, + model_overrides=model_version, extract_only_images=extract_only_images, enable_plotting=enable_plotting, allow_enhancement=allow_enhancement, diff --git a/src/eynollah/cli_models.py b/src/eynollah/cli_models.py new file mode 100644 index 0000000..595c499 --- /dev/null +++ b/src/eynollah/cli_models.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import List, Set, Tuple +import click + +from eynollah.model_zoo.default_specs import MODELS_VERSION +from .model_zoo import EynollahModelZoo + + +@dataclass() +class EynollahCliCtx: + model_zoo: EynollahModelZoo + + +@click.group() +@click.pass_context +@click.option( + "--model", + "-m", + 'model_basedir', + help="directory of models", + type=click.Path(exists=True, file_okay=False), + # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment", + required=True, +) +@click.option( + "--model-overrides", + "-mv", + help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g. 'region light /path/to/model'. See 'eynollah models list' for the full list", + type=(str, str, str), + multiple=True, +) +def models_cli( + ctx, + model_basedir: str, + model_overrides: List[Tuple[str, str, str]], +): + """ + Organize models for the various runners in eynollah. + """ + ctx.obj = EynollahCliCtx(model_zoo=EynollahModelZoo(basedir=model_basedir, model_overrides=model_overrides)) + + +@models_cli.command('list') +@click.pass_context +def list_models( + ctx, +): + """ + List all the models in the zoo. + """ + print(ctx.obj.model_zoo) + + +@models_cli.command('package') +@click.option( + '--set-version', '-V', 'version', help="Version to use for packaging", default=MODELS_VERSION, show_default=True +) +@click.argument('output_dir') +@click.pass_context +def package( + ctx, + version, + output_dir, +): + """ + Generate shell code to copy all the models in the zoo into properly named folders in OUTPUT_DIR for distribution. + + eynollah models -m SRC package OUTPUT_DIR + + SRC should contain a directory "models_eynollah" containing all the models.
+ """ + mkdirs: Set[Path] = set([]) + copies: Set[Tuple[Path, Path]] = set([]) + for spec in ctx.obj.model_zoo.specs.specs: + # skip these as they are dependent on the ocr model + if spec.category in ('num_to_char', 'characters'): + continue + src: Path = ctx.obj.model_zoo.model_path(spec.category, spec.variant) + # Only copy the top-most directory relative to models_eynollah + while src.parent.name != 'models_eynollah': + src = src.parent + for dist in spec.dists: + dist_dir = Path(f"{output_dir}/models_{dist}_{version}/models_eynollah") + copies.add((src, dist_dir)) + mkdirs.add(dist_dir) + for dir in mkdirs: + print(f"mkdir -p {dir}") + for (src, dst) in copies: + print(f"cp -r {src} {dst}") + for dir in mkdirs: + zip_path = Path(f'../{dir.parent.name}.zip') + print(f"(cd {dir}/..; zip -r {zip_path} models_eynollah)") diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 13acba6..232631a 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -2,71 +2,50 @@ # pylint: disable=too-many-locals,wrong-import-position,too-many-lines,too-many-statements,chained-comparison,fixme,broad-except,c-extension-no-member # pylint: disable=too-many-public-methods,too-many-arguments,too-many-instance-attributes,too-many-public-methods, # pylint: disable=consider-using-enumerate +# pyright: reportUnnecessaryTypeIgnoreComment=true +# pyright: reportPossiblyUnboundVariable=false """ document layout analysis (segmentation) with output in PAGE-XML """ -# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files -import sys -if sys.version_info < (3, 10): - import importlib_resources -else: - import importlib.resources as importlib_resources - from difflib import SequenceMatcher as sq -from PIL import Image, ImageDraw, ImageFont import math import os -import sys import time -from typing import Dict, List, Optional, Tuple -import atexit +from typing import List, Optional, Tuple import warnings from functools import partial from pathlib import Path from multiprocessing import cpu_count import gc import copy -import json from concurrent.futures import ProcessPoolExecutor -import xml.etree.ElementTree as ET import cv2 import numpy as np import shapely.affinity from scipy.signal import find_peaks from scipy.ndimage import gaussian_filter1d -from numba import cuda from skimage.morphology import skeletonize -from ocrd import OcrdPage from ocrd_utils import getLogger, tf_disable_interactive_logs import statistics try: - import torch + import torch # type: ignore except ImportError: torch = None try: import matplotlib.pyplot as plt except ImportError: plt = None -try: - from transformers import TrOCRProcessor, VisionEncoderDecoderModel -except ImportError: - TrOCRProcessor = VisionEncoderDecoderModel = None #os.environ['CUDA_VISIBLE_DEVICES'] = '-1' tf_disable_interactive_logs() import tensorflow as tf -from tensorflow.python.keras import backend as K -from tensorflow.keras.models import load_model tf.get_logger().setLevel("ERROR") warnings.filterwarnings("ignore") -# use tf1 compatibility for keras backend -from tensorflow.compat.v1.keras.backend import set_session -from tensorflow.keras import layers -from tensorflow.keras.layers import StringLookup +from .model_zoo import EynollahModelZoo from .utils.contour import ( filter_contours_area_of_image, filter_contours_area_of_image_tables, @@ -155,59 +134,12 @@ patch_size = 1 num_patches =21*21#14*14#28*28#14*14#28*28 -class Patches(layers.Layer): - def __init__(self, **kwargs): - super(Patches, self).__init__() - 
self.patch_size = patch_size - - def call(self, images): - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=[1, self.patch_size, self.patch_size, 1], - strides=[1, self.patch_size, self.patch_size, 1], - rates=[1, 1, 1, 1], - padding="VALID", - ) - patch_dims = patches.shape[-1] - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) - return patches - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'patch_size': self.patch_size, - }) - return config - -class PatchEncoder(layers.Layer): - def __init__(self, **kwargs): - super(PatchEncoder, self).__init__() - self.num_patches = num_patches - self.projection = layers.Dense(units=projection_dim) - self.position_embedding = layers.Embedding( - input_dim=num_patches, output_dim=projection_dim - ) - - def call(self, patch): - positions = tf.range(start=0, limit=self.num_patches, delta=1) - encoded = self.projection(patch) + self.position_embedding(positions) - return encoded - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'num_patches': self.num_patches, - 'projection': self.projection, - 'position_embedding': self.position_embedding, - }) - return config class Eynollah: def __init__( self, dir_models : str, - model_versions: List[Tuple[str, str]] = [], + model_overrides: List[Tuple[str, str, str]] = [], extract_only_images : bool =False, enable_plotting : bool = False, allow_enhancement : bool = False, @@ -232,6 +164,7 @@ class Eynollah: skip_layout_and_reading_order : bool = False, ): self.logger = getLogger('eynollah') + self.model_zoo = EynollahModelZoo(basedir=dir_models) self.plotter = None if skip_layout_and_reading_order: @@ -297,93 +230,13 @@ class Eynollah: self.logger.warning("no GPU device available") self.logger.info("Loading models...") - self.setup_models(dir_models, model_versions) + self.setup_models(*model_overrides) self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)") - @staticmethod - def our_load_model(model_file, basedir=""): - if basedir: - model_file = os.path.join(basedir, model_file) - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model - - def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []): - self.model_versions = { - "enhancement": "eynollah-enhancement_20210425", - "binarization": "eynollah-binarization_20210425", - "col_classifier": "eynollah-column-classifier_20210425", - "page": "model_eynollah_page_extraction_20250915", - #?: "eynollah-main-regions-aug-scaling_20210425", - "region": ( # early layout - "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else - "eynollah-main-regions_20220314" if self.light_version else - "eynollah-main-regions-ensembled_20210425"), - "region_p2": ( # early layout, non-light, 2nd part - "eynollah-main-regions-aug-rotation_20210425"), - "region_1_2": ( # early layout, light, 1-or-2-column - #"modelens_12sp_elay_0_3_4__3_6_n" - #"modelens_earlylayout_12spaltige_2_3_5_6_7_8" - #"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18" - #"modelens_1_2_4_5_early_lay_1_2_spaltige" - #"model_3_eraly_layout_no_patches_1_2_spaltige" - "modelens_e_l_all_sp_0_1_2_3_4_171024"), - 
"region_fl_np": ( # full layout / no patches - #"modelens_full_lay_1_3_031124" - #"modelens_full_lay_13__3_19_241024" - #"model_full_lay_13_241024" - #"modelens_full_lay_13_17_231024" - #"modelens_full_lay_1_2_221024" - #"eynollah-full-regions-1column_20210425" - "modelens_full_lay_1__4_3_091124"), - "region_fl": ( # full layout / with patches - #"eynollah-full-regions-3+column_20210425" - ##"model_2_full_layout_new_trans" - #"modelens_full_lay_1_3_031124" - #"modelens_full_lay_13__3_19_241024" - #"model_full_lay_13_241024" - #"modelens_full_lay_13_17_231024" - #"modelens_full_lay_1_2_221024" - #"modelens_full_layout_24_till_28" - #"model_2_full_layout_new_trans" - "modelens_full_lay_1__4_3_091124"), - "reading_order": ( - #"model_mb_ro_aug_ens_11" - #"model_step_3200000_mb_ro" - #"model_ens_reading_order_machine_based" - #"model_mb_ro_aug_ens_8" - #"model_ens_reading_order_machine_based" - "model_eynollah_reading_order_20250824"), - "textline": ( - #"modelens_textline_1_4_16092024" - #"model_textline_ens_3_4_5_6_artificial" - #"modelens_textline_1_3_4_20240915" - #"model_textline_ens_3_4_5_6_artificial" - #"modelens_textline_9_12_13_14_15" - #"eynollah-textline_light_20210425" - "modelens_textline_0_1__2_4_16092024" if self.textline_light else - #"eynollah-textline_20210425" - "modelens_textline_0_1__2_4_16092024"), - "table": ( - None if not self.tables else - "modelens_table_0t4_201124" if self.light_version else - "eynollah-tables_20210319"), - "ocr": ( - None if not self.ocr else - "model_eynollah_ocr_trocr_20250919" if self.tr else - "model_eynollah_ocr_cnnrnn_20250930") - } + def setup_models(self, *model_overrides: Tuple[str, str, str]): # override defaults from CLI - for key, val in model_versions: - assert key in self.model_versions, "unknown model category '%s'" % key - self.logger.warning("overriding default model %s version %s to %s", key, self.model_versions[key], val) - self.model_versions[key] = val + self.model_zoo.override_models(*model_overrides) + # load models, depending on modes # (note: loading too many models can cause OOM on GPU/CUDA, # thus, we try set up the minimal configuration for the current mode) @@ -391,10 +244,10 @@ class Eynollah: "col_classifier", "binarization", "page", - "region" + ("region", 'extract_only_images' if self.extract_only_images else 'light' if self.light_version else '') ] if not self.extract_only_images: - loadable.append("textline") + loadable.append(("textline", 'light' if self.light_version else '')) if self.light_version: loadable.append("region_1_2") else: @@ -407,47 +260,34 @@ class Eynollah: if self.reading_order_machine_based: loadable.append("reading_order") if self.tables: - loadable.append("table") - - self.models = {name: self.our_load_model(self.model_versions[name], basedir) - for name in loadable - } + loadable.append(("table", 'light' if self.light_version else '')) if self.ocr: - ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"]) if self.tr: - self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) - if torch.cuda.is_available(): - self.logger.info("Using GPU acceleration") - self.device = torch.device("cuda:0") - else: - self.logger.info("Using CPU processing") - self.device = torch.device("cpu") - #self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") - self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") + loadable.append(('ocr', 'tr')) + loadable.append(('trocr_processor', '')) else: - ocr_model = load_model(ocr_model_dir, 
compile=False) - self.models["ocr"] = tf.keras.models.Model( - ocr_model.get_layer(name = "image").input, - ocr_model.get_layer(name = "dense2").output) - - with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file: - characters = json.load(config_file) - # Mapping characters to integers. - char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) - # Mapping integers back to original characters. - self.num_to_char = StringLookup( - vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True - ) + loadable.append('ocr') + loadable.append('num_to_char') + + self.model_zoo.load_models(*loadable) def __del__(self): if hasattr(self, 'executor') and getattr(self, 'executor'): + assert self.executor self.executor.shutdown() self.executor = None - if hasattr(self, 'models') and getattr(self, 'models'): - for model_name in list(self.models): - if self.models[model_name]: - del self.models[model_name] + self.model_zoo.shutdown() + + @property + def device(self): + # TODO why here and why only for tr? + assert torch + if torch.cuda.is_available(): + self.logger.info("Using GPU acceleration") + return torch.device("cuda:0") + self.logger.info("Using CPU processing") + return torch.device("cpu") def cache_images(self, image_filename=None, image_pil=None, dpi=None): ret = {} @@ -494,8 +334,8 @@ class Eynollah: def predict_enhancement(self, img): self.logger.debug("enter predict_enhancement") - img_height_model = self.models["enhancement"].layers[-1].output_shape[1] - img_width_model = self.models["enhancement"].layers[-1].output_shape[2] + img_height_model = self.model_zoo.get("enhancement").layers[-1].output_shape[1] + img_width_model = self.model_zoo.get("enhancement").layers[-1].output_shape[2] if img.shape[0] < img_height_model: img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) if img.shape[1] < img_width_model: @@ -536,7 +376,7 @@ class Eynollah: index_y_d = img_h - img_height_model img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :] - label_p_pred = self.models["enhancement"].predict(img_patch, verbose=0) + label_p_pred = self.model_zoo.get("enhancement").predict(img_patch, verbose=0) seg = label_p_pred[0, :, :, :] * 255 if i == 0 and j == 0: @@ -711,7 +551,7 @@ class Eynollah: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 self.logger.info("Found %s columns (%s)", num_col, label_p_pred) @@ -729,7 +569,7 @@ class Eynollah: self.logger.info("Detected %s DPI", dpi) if self.input_binary: img = self.imread() - prediction_bin = self.do_prediction(True, img, self.models["binarization"], n_batch_inference=5) + prediction_bin = self.do_prediction(True, img, self.model_zoo.get("binarization"), n_batch_inference=5) prediction_bin = 255 * (prediction_bin[:,:,0] == 0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8) img= np.copy(prediction_bin) @@ -769,7 +609,7 @@ class Eynollah: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): @@ -790,7 
+630,7 @@ class Eynollah: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 if num_col > self.num_col_upper: @@ -845,8 +685,8 @@ class Eynollah: self.img_hight_int = int(self.image.shape[0] * scale) self.img_width_int = int(self.image.shape[1] * scale) - self.scale_y = self.img_hight_int / float(self.image.shape[0]) - self.scale_x = self.img_width_int / float(self.image.shape[1]) + self.scale_y: float = self.img_hight_int / float(self.image.shape[0]) + self.scale_x: float = self.img_width_int / float(self.image.shape[1]) self.image = resize_image(self.image, self.img_hight_int, self.img_width_int) @@ -1642,7 +1482,7 @@ class Eynollah: cont_page = [] if not self.ignore_page_extraction: img = np.copy(self.image)#cv2.GaussianBlur(self.image, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.models["page"]) + img_page_prediction = self.do_prediction(False, img, self.model_zoo.get("page")) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) ##thresh = cv2.dilate(thresh, KERNEL, iterations=3) @@ -1690,7 +1530,7 @@ class Eynollah: else: img = self.imread() img = cv2.GaussianBlur(img, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.models["page"]) + img_page_prediction = self.do_prediction(False, img, self.model_zoo.get("page")) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) @@ -1716,7 +1556,7 @@ class Eynollah: self.logger.debug("enter extract_text_regions") img_height_h = img.shape[0] img_width_h = img.shape[1] - model_region = self.models["region_fl"] if patches else self.models["region_fl_np"] + model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np") if self.light_version: thresholding_for_fl_light_version = True @@ -1751,7 +1591,7 @@ class Eynollah: self.logger.debug("enter extract_text_regions") img_height_h = img.shape[0] img_width_h = img.shape[1] - model_region = self.models["region_fl"] if patches else self.models["region_fl_np"] + model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np") if not patches: img = otsu_copy_binary(img) @@ -1911,6 +1751,7 @@ class Eynollah: return [], [], [] self.logger.debug("enter get_slopes_and_deskew_new_light") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new_light, textline_mask_tot_ea=textline_mask_tot_shared, slope_deskew=slope_deskew, @@ -1927,6 +1768,7 @@ class Eynollah: return [], [], [] self.logger.debug("enter get_slopes_and_deskew_new") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new, textline_mask_tot_ea=textline_mask_tot_shared, slope_deskew=slope_deskew, @@ -1947,6 +1789,7 @@ class Eynollah: self.logger.debug("enter get_slopes_and_deskew_new_curved") with share_ndarray(textline_mask_tot) as textline_mask_tot_shared: with share_ndarray(mask_texts_only) as mask_texts_only_shared: + assert self.executor results = self.executor.map(partial(do_work_of_slopes_new_curved, textline_mask_tot_ea=textline_mask_tot_shared, mask_texts_only=mask_texts_only_shared, @@ -1972,14 +1815,14 @@ class Eynollah: img_w = 
img_org.shape[1] img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w)) - prediction_textline = self.do_prediction(use_patches, img, self.models["textline"], + prediction_textline = self.do_prediction(use_patches, img, self.model_zoo.get("textline"), marginal_of_patch_percent=0.15, n_batch_inference=3, thresholding_for_artificial_class_in_light_version=self.textline_light, threshold_art_class_textline=self.threshold_art_class_textline) #if not self.textline_light: #if num_col_classifier==1: - #prediction_textline_nopatch = self.do_prediction(False, img, self.models["textline"]) + #prediction_textline_nopatch = self.do_prediction(False, img, self.model_zoo.get_model("textline")) #prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0 prediction_textline = resize_image(prediction_textline, img_h, img_w) @@ -2050,7 +1893,7 @@ class Eynollah: #cv2.imwrite('prediction_textline2.png', prediction_textline[:,:,0]) - prediction_textline_longshot = self.do_prediction(False, img, self.models["textline"]) + prediction_textline_longshot = self.do_prediction(False, img, self.model_zoo.get("textline")) prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w) @@ -2083,7 +1926,7 @@ class Eynollah: img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new) img_resized = resize_image(img,img_h_new, img_w_new ) - prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.models["region"]) + prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.model_zoo.get("region")) prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h ) image_page, page_coord, cont_page = self.extract_page() @@ -2199,7 +2042,7 @@ class Eynollah: #if self.input_binary: #img_bin = np.copy(img_resized) ###if (not self.input_binary and self.full_layout) or (not self.input_binary and num_col_classifier >= 30): - ###prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5) + ###prediction_bin = self.do_prediction(True, img_resized, self.model_zoo.get_model("binarization"), n_batch_inference=5) ####print("inside bin ", time.time()-t_bin) ###prediction_bin=prediction_bin[:,:,0] @@ -2214,7 +2057,7 @@ class Eynollah: ###else: ###img_bin = np.copy(img_resized) if (self.ocr and self.tr) and not self.input_binary: - prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5) + prediction_bin = self.do_prediction(True, img_resized, self.model_zoo.get("binarization"), n_batch_inference=5) prediction_bin = 255 * (prediction_bin[:,:,0] == 0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2) prediction_bin = prediction_bin.astype(np.uint16) @@ -2246,14 +2089,14 @@ class Eynollah: self.logger.debug("resized to %dx%d for %d cols", img_resized.shape[1], img_resized.shape[0], num_col_classifier) prediction_regions_org, confidence_matrix = self.do_prediction_new_concept( - True, img_resized, self.models["region_1_2"], n_batch_inference=1, + True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=1, thresholding_for_some_classes_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout) else: prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3)) confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1])) prediction_regions_page, confidence_matrix_page = 
self.do_prediction_new_concept( - False, self.image_page_org_size, self.models["region_1_2"], n_batch_inference=1, + False, self.image_page_org_size, self.model_zoo.get("region_1_2"), n_batch_inference=1, thresholding_for_artificial_class_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout) ys = slice(*self.page_coord[0:2]) @@ -2267,10 +2110,10 @@ class Eynollah: self.logger.debug("resized to %dx%d (new_h=%d) for %d cols", img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier) prediction_regions_org, confidence_matrix = self.do_prediction_new_concept( - True, img_resized, self.models["region_1_2"], n_batch_inference=2, + True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=2, thresholding_for_some_classes_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout) - ###prediction_regions_org = self.do_prediction(True, img_bin, self.models["region"], + ###prediction_regions_org = self.do_prediction(True, img_bin, self.model_zoo.get_model("region"), ###n_batch_inference=3, ###thresholding_for_some_classes_in_light_version=True) #print("inside 3 ", time.time()-t_in) @@ -2350,7 +2193,7 @@ class Eynollah: ratio_x=1 img = resize_image(img_org, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x)) - prediction_regions_org_y = self.do_prediction(True, img, self.models["region"]) + prediction_regions_org_y = self.do_prediction(True, img, self.model_zoo.get("region")) prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h ) #plt.imshow(prediction_regions_org_y[:,:,0]) @@ -2365,7 +2208,7 @@ class Eynollah: _, _ = find_num_col(img_only_regions, num_col_classifier, self.tables, multiplier=6.0) img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]*(1.2 if is_image_enhanced else 1))) - prediction_regions_org = self.do_prediction(True, img, self.models["region"]) + prediction_regions_org = self.do_prediction(True, img, self.model_zoo.get("region")) prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) prediction_regions_org=prediction_regions_org[:,:,0] @@ -2373,7 +2216,7 @@ class Eynollah: img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1])) - prediction_regions_org2 = self.do_prediction(True, img, self.models["region_p2"], marginal_of_patch_percent=0.2) + prediction_regions_org2 = self.do_prediction(True, img, self.model_zoo.get("region_p2"), marginal_of_patch_percent=0.2) prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h ) mask_zeros2 = (prediction_regions_org2[:,:,0] == 0) @@ -2397,7 +2240,7 @@ class Eynollah: if self.input_binary: prediction_bin = np.copy(img_org) else: - prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5) + prediction_bin = self.do_prediction(True, img_org, self.model_zoo.get("binarization"), n_batch_inference=5) prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h ) prediction_bin = 255 * (prediction_bin[:,:,0]==0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2) @@ -2407,7 +2250,7 @@ class Eynollah: img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x)) - prediction_regions_org = self.do_prediction(True, img, self.models["region"]) + prediction_regions_org = self.do_prediction(True, img, self.model_zoo.get("region")) prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) 
prediction_regions_org=prediction_regions_org[:,:,0] @@ -2434,7 +2277,7 @@ class Eynollah: except: if self.input_binary: prediction_bin = np.copy(img_org) - prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5) + prediction_bin = self.do_prediction(True, img_org, self.model_zoo.get("binarization"), n_batch_inference=5) prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h ) prediction_bin = 255 * (prediction_bin[:,:,0]==0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2) @@ -2445,14 +2288,14 @@ class Eynollah: img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x)) - prediction_regions_org = self.do_prediction(True, img, self.models["region"]) + prediction_regions_org = self.do_prediction(True, img, self.model_zoo.get("region")) prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) prediction_regions_org=prediction_regions_org[:,:,0] #mask_lines_only=(prediction_regions_org[:,:]==3)*1 #img = resize_image(img_org, int(img_org.shape[0]*1), int(img_org.shape[1]*1)) - #prediction_regions_org = self.do_prediction(True, img, self.models["region"]) + #prediction_regions_org = self.do_prediction(True, img, self.model_zoo.get_model("region")) #prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) #prediction_regions_org = prediction_regions_org[:,:,0] #prediction_regions_org[(prediction_regions_org[:,:] == 1) & (mask_zeros_y[:,:] == 1)]=0 @@ -2823,13 +2666,13 @@ class Eynollah: img_width_h = img_org.shape[1] patches = False if self.light_version: - prediction_table, _ = self.do_prediction_new_concept(patches, img, self.models["table"]) + prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_zoo.get("table")) prediction_table = prediction_table.astype(np.int16) return prediction_table[:,:,0] else: if num_col_classifier < 4 and num_col_classifier > 2: - prediction_table = self.do_prediction(patches, img, self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"]) + prediction_table = self.do_prediction(patches, img, self.model_zoo.get("table")) + pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_zoo.get("table")) pre_updown = cv2.flip(pre_updown, -1) prediction_table[:,:,0][pre_updown[:,:,0]==1]=1 @@ -2848,8 +2691,8 @@ class Eynollah: xs = slice(w_start, w_start + img.shape[1]) img_new[ys, xs] = img - prediction_ext = self.do_prediction(patches, img_new, self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"]) + prediction_ext = self.do_prediction(patches, img_new, self.model_zoo.get("table")) + pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_zoo.get("table")) pre_updown = cv2.flip(pre_updown, -1) prediction_table = prediction_ext[ys, xs] @@ -2870,8 +2713,8 @@ class Eynollah: xs = slice(w_start, w_start + img.shape[1]) img_new[ys, xs] = img - prediction_ext = self.do_prediction(patches, img_new, self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"]) + prediction_ext = self.do_prediction(patches, img_new, self.model_zoo.get("table")) + pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_zoo.get("table")) pre_updown = cv2.flip(pre_updown, -1) prediction_table = prediction_ext[ys, xs] @@ -2883,10 +2726,10 @@ class 
Eynollah: prediction_table = np.zeros(img.shape) img_w_half = img.shape[1] // 2 - pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.models["table"]) - pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.models["table"]) - pre_full = self.do_prediction(patches, img[:,:,:], self.models["table"]) - pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"]) + pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.model_zoo.get("table")) + pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.model_zoo.get("table")) + pre_full = self.do_prediction(patches, img[:,:,:], self.model_zoo.get("table")) + pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_zoo.get("table")) pre_updown = cv2.flip(pre_updown, -1) prediction_table_full_erode = cv2.erode(pre_full[:,:,0], KERNEL, iterations=4) @@ -3678,7 +3521,7 @@ class Eynollah: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.models["reading_order"].predict(input_1 , verbose=0) + y_pr = self.model_zoo.get("reading_order").predict(input_1 , verbose=0) for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: post_list.append(j) @@ -4261,7 +4104,7 @@ class Eynollah: gc.collect() ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)), - self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), textline_light=True) else: ocr_all_textlines = None @@ -4770,27 +4613,27 @@ class Eynollah: if len(all_found_textline_polygons): ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons, all_box_coord, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) if len(all_found_textline_polygons_marginals_left): ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) if len(all_found_textline_polygons_marginals_right): ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) if self.full_layout and len(all_found_textline_polygons): ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons_h, all_box_coord_h, - self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) if self.full_layout and len(polygons_of_drop_capitals): ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines( image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)), - self.models["ocr"], self.b_s_ocr, self.num_to_char, 
self.textline_light, self.curved_line) + self.model_zoo.get("ocr"), self.b_s_ocr, self.model_zoo.get("num_to_char"), self.textline_light, self.curved_line) else: if self.light_version: @@ -4802,7 +4645,7 @@ class Eynollah: gc.collect() torch.cuda.empty_cache() - self.models["ocr"].to(self.device) + self.model_zoo.get("ocr").to(self.device) ind_tot = 0 #cv2.imwrite('./img_out.png', image_page) @@ -4839,7 +4682,7 @@ class Eynollah: img_croped = img_poly_on_img[y:y+h, x:x+w, :] #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped) text_ocr = self.return_ocr_of_textline_without_common_section( - img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot) + img_croped, self.model_zoo.get("ocr"), self.model_zoo.get("trocr_processor"), self.device, w, h2w_ratio, ind_tot) ocr_textline_in_textregion.append(text_ocr) ind_tot = ind_tot +1 ocr_all_textlines.append(ocr_textline_in_textregion) @@ -4876,964 +4719,3 @@ class Eynollah: return pcgts -class Eynollah_ocr: - def __init__( - self, - dir_models, - model_name=None, - dir_xmls=None, - tr_ocr=False, - batch_size=None, - export_textline_images_and_text=False, - do_not_mask_with_textline_contour=False, - pref_of_dataset=None, - min_conf_value_of_textline_text : Optional[float]=None, - logger=None, - ): - self.model_name = model_name - self.tr_ocr = tr_ocr - self.export_textline_images_and_text = export_textline_images_and_text - self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour - self.pref_of_dataset = pref_of_dataset - self.logger = logger if logger else getLogger('eynollah') - - if not export_textline_images_and_text: - if min_conf_value_of_textline_text: - self.min_conf_value_of_textline_text = float(min_conf_value_of_textline_text) - else: - self.min_conf_value_of_textline_text = 0.3 - if tr_ocr: - self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - if self.model_name: - self.model_ocr_dir = self.model_name - else: - self.model_ocr_dir = dir_models + "/model_eynollah_ocr_trocr_20250919" - self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) - self.model_ocr.to(self.device) - if not batch_size: - self.b_s = 2 - else: - self.b_s = int(batch_size) - - else: - if self.model_name: - self.model_ocr_dir = self.model_name - else: - self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250930" - model_ocr = load_model(self.model_ocr_dir , compile=False) - - self.prediction_model = tf.keras.models.Model( - model_ocr.get_layer(name = "image").input, - model_ocr.get_layer(name = "dense2").output) - if not batch_size: - self.b_s = 8 - else: - self.b_s = int(batch_size) - - with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file: - characters = json.load(config_file) - - AUTOTUNE = tf.data.AUTOTUNE - - # Mapping characters to integers. - char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) - - # Mapping integers back to original characters. 
- self.num_to_char = StringLookup( - vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True - ) - self.end_character = len(characters) + 2 - - def run(self, overwrite: bool = False, - dir_in: Optional[str] = None, - dir_in_bin: Optional[str] = None, - image_filename: Optional[str] = None, - dir_xmls: Optional[str] = None, - dir_out_image_text: Optional[str] = None, - dir_out: Optional[str] = None, - ): - if dir_in: - ls_imgs = [os.path.join(dir_in, image_filename) - for image_filename in filter(is_image_filename, - os.listdir(dir_in))] - else: - ls_imgs = [image_filename] - - if self.tr_ocr: - tr_ocr_input_height_and_width = 384 - for dir_img in ls_imgs: - file_name = Path(dir_img).stem - dir_xml = os.path.join(dir_xmls, file_name+'.xml') - out_file_ocr = os.path.join(dir_out, file_name+'.xml') - - if os.path.exists(out_file_ocr): - if overwrite: - self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) - else: - self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) - continue - - img = cv2.imread(dir_img) - - if dir_out_image_text: - out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') - image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") - draw = ImageDraw.Draw(image_text) - total_bb_coordinates = [] - - ##file_name = Path(dir_xmls).stem - tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] - - region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) - - - - cropped_lines = [] - cropped_lines_region_indexer = [] - cropped_lines_meging_indexing = [] - - extracted_texts = [] - - indexer_text_region = 0 - indexer_b_s = 0 - - for nn in root1.iter(region_tags): - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - x,y,w,h = cv2.boundingRect(textline_coords) - - if dir_out_image_text: - total_bb_coordinates.append([x,y,w,h]) - - h2w_ratio = h/float(w) - - img_poly_on_img = np.copy(img) - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - img_crop[mask_poly==0] = 255 - - self.logger.debug("processing %d lines for '%s'", - len(cropped_lines), nn.attrib['id']) - if h2w_ratio > 0.1: - cropped_lines.append(resize_image(img_crop, - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width) ) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - splited_images, _ = return_textlines_split_if_needed(img_crop, None) - #print(splited_images) - if splited_images: - 
cropped_lines.append(resize_image(splited_images[0], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - cropped_lines.append(resize_image(splited_images[1], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(-1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - cropped_lines.append(img_crop) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - - indexer_text_region = indexer_text_region +1 - - if indexer_b_s!=0: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) - generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - ####extracted_texts = [] - ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) - - ####for i in range(n_iterations): - ####if i==(n_iterations-1): - ####n_start = i*self.b_s - ####imgs = cropped_lines[n_start:] - ####else: - ####n_start = i*self.b_s - ####n_end = (i+1)*self.b_s - ####imgs = cropped_lines[n_start:n_end] - ####pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values - ####generated_ids_merged = self.model_ocr.generate( - #### pixel_values_merged.to(self.device)) - ####generated_text_merged = self.processor.batch_decode( - #### generated_ids_merged, skip_special_tokens=True) - - ####extracted_texts = extracted_texts + generated_text_merged - - del cropped_lines - gc.collect() - - extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - #print(extracted_texts_merged, len(extracted_texts_merged)) - - unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) - - if dir_out_image_text: - - #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this 
file exists! - font = importlib_resources.files(__package__) / "Charis-Regular.ttf" - with importlib_resources.as_file(font) as font: - font = ImageFont.truetype(font=font, size=40) - - for indexer_text, bb_ind in enumerate(total_bb_coordinates): - - - x_bb = bb_ind[0] - y_bb = bb_ind[1] - w_bb = bb_ind[2] - h_bb = bb_ind[3] - - font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], - font.path, w_bb, int(h_bb*0.4) ) - - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] - - text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally - text_y = y_bb + (h_bb - text_height) // 2 # Center vertically - - # Draw the text - draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) - image_text.save(out_image_with_text) - - #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer') - #######text_by_textregion = [] - #######for ind in unique_cropped_lines_region_indexer: - #######ind = np.array(cropped_lines_region_indexer)==ind - #######extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - #######text_by_textregion.append(" ".join(extracted_texts_merged_un)) - - text_by_textregion = [] - for ind in unique_cropped_lines_region_indexer: - ind = np.array(cropped_lines_region_indexer) == ind - extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - if len(extracted_texts_merged_un)>1: - text_by_textregion_ind = "" - next_glue = "" - for indt in range(len(extracted_texts_merged_un)): - if (extracted_texts_merged_un[indt].endswith('⸗') or - extracted_texts_merged_un[indt].endswith('-') or - extracted_texts_merged_un[indt].endswith('¬')): - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] - next_glue = "" - else: - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] - next_glue = " " - text_by_textregion.append(text_by_textregion_ind) - else: - text_by_textregion.append(" ".join(extracted_texts_merged_un)) - - - indexer = 0 - indexer_textregion = 0 - for nn in root1.iter(region_tags): - #id_textregion = nn.attrib['id'] - #id_textregions.append(id_textregion) - #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) - - is_textregion_text = False - for childtest in nn: - if childtest.tag.endswith("TextEquiv"): - is_textregion_text = True - - if not is_textregion_text: - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') - - - has_textline = False - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - is_textline_text = False - for childtest2 in child_textregion: - if childtest2.tag.endswith("TextEquiv"): - is_textline_text = True - - - if not is_textline_text: - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - ##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] - else: - for childtest3 in child_textregion: - if childtest3.tag.endswith("TextEquiv"): - for child_uc in childtest3: - if child_uc.tag.endswith("Unicode"): - ##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - child_uc.text = extracted_texts_merged[indexer] - - indexer = indexer + 1 - 
has_textline = True - if has_textline: - if is_textregion_text: - for child4 in nn: - if child4.tag.endswith("TextEquiv"): - for childtr_uc in child4: - if childtr_uc.tag.endswith("Unicode"): - childtr_uc.text = text_by_textregion[indexer_textregion] - else: - unicode_textregion.text = text_by_textregion[indexer_textregion] - indexer_textregion = indexer_textregion + 1 - - ###sample_order = [(id_to_order[tid], text) - ### for tid, text in zip(id_textregions, textregions_by_existing_ids) - ### if tid in id_to_order] - - ##ordered_texts_sample = [text for _, text in sorted(sample_order)] - ##tot_page_text = ' '.join(ordered_texts_sample) - - ##for page_element in root1.iter(link+'Page'): - ##text_page = ET.SubElement(page_element, 'TextEquiv') - ##unicode_textpage = ET.SubElement(text_page, 'Unicode') - ##unicode_textpage.text = tot_page_text - - ET.register_namespace("",name_space) - tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) - else: - ###max_len = 280#512#280#512 - ###padding_token = 1500#299#1500#299 - image_width = 512#max_len * 4 - image_height = 32 - - - img_size=(image_width, image_height) - - for dir_img in ls_imgs: - file_name = Path(dir_img).stem - dir_xml = os.path.join(dir_xmls, file_name+'.xml') - out_file_ocr = os.path.join(dir_out, file_name+'.xml') - - if os.path.exists(out_file_ocr): - if overwrite: - self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) - else: - self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) - continue - - img = cv2.imread(dir_img) - if dir_in_bin is not None: - cropped_lines_bin = [] - dir_img_bin = os.path.join(dir_in_bin, file_name+'.png') - img_bin = cv2.imread(dir_img_bin) - - if dir_out_image_text: - out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') - image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") - draw = ImageDraw.Draw(image_text) - total_bb_coordinates = [] - - tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] - - region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) - - cropped_lines = [] - cropped_lines_ver_index = [] - cropped_lines_region_indexer = [] - cropped_lines_meging_indexing = [] - - tinl = time.time() - indexer_text_region = 0 - indexer_textlines = 0 - for nn in root1.iter(region_tags): - try: - type_textregion = nn.attrib['type'] - except: - type_textregion = 'paragraph' - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - - x,y,w,h = cv2.boundingRect(textline_coords) - - angle_radians = math.atan2(h, w) - # Convert to degrees - angle_degrees = math.degrees(angle_radians) - if type_textregion=='drop-capital': - angle_degrees = 0 - - if dir_out_image_text: - total_bb_coordinates.append([x,y,w,h]) - - w_scaled = w * image_height/float(h) - - img_poly_on_img = np.copy(img) - if dir_in_bin is not None: - img_poly_on_img_bin = np.copy(img_bin) - img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :] - - mask_poly = np.zeros(img.shape) - mask_poly = 
cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - - if self.export_textline_images_and_text: - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - - else: - # print(file_name, angle_degrees, w*h, - # mask_poly[:,:,0].sum(), - # mask_poly[:,:,0].sum() /float(w*h) , - # 'didi') - - if angle_degrees > 3: - better_des_slope = get_orientation_moments(textline_coords) - - img_crop = rotate_image_with_padding(img_crop, better_des_slope) - if dir_in_bin is not None: - img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope) - - mask_poly = rotate_image_with_padding(mask_poly, better_des_slope) - mask_poly = mask_poly.astype('uint8') - - #new bounding box - x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0]) - - mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :] - img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :] - - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - if dir_in_bin is not None: - img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] - if not self.do_not_mask_with_textline_contour: - img_crop_bin[mask_poly==0] = 255 - - if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: - if dir_in_bin is not None: - img_crop, img_crop_bin = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly, img_crop_bin) - else: - img_crop, _ = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly) - - else: - better_des_slope = 0 - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - if dir_in_bin is not None: - if not self.do_not_mask_with_textline_contour: - img_crop_bin[mask_poly==0] = 255 - if type_textregion=='drop-capital': - pass - else: - if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90: - if dir_in_bin is not None: - img_crop, img_crop_bin = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly, img_crop_bin) - else: - img_crop, _ = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly) - - if not self.export_textline_images_and_text: - if w_scaled < 750:#1.5*image_width: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop, image_height, image_width) - cropped_lines.append(img_fin) - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - cropped_lines_meging_indexing.append(0) - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop_bin, image_height, image_width) - cropped_lines_bin.append(img_fin) - else: - splited_images, splited_images_bin = return_textlines_split_if_needed( - img_crop, img_crop_bin if dir_in_bin is not None else None) - if splited_images: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images[0], image_height, image_width) - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(1) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images[1], image_height, image_width) - - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(-1) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - 
splited_images_bin[0], image_height, image_width) - cropped_lines_bin.append(img_fin) - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images_bin[1], image_height, image_width) - cropped_lines_bin.append(img_fin) - - else: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop, image_height, image_width) - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(0) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop_bin, image_height, image_width) - cropped_lines_bin.append(img_fin) - - if self.export_textline_images_and_text: - if img_crop.shape[0]==0 or img_crop.shape[1]==0: - pass - else: - if child_textlines.tag.endswith("TextEquiv"): - for cheild_text in child_textlines: - if cheild_text.tag.endswith("Unicode"): - textline_text = cheild_text.text - if textline_text: - base_name = os.path.join( - dir_out, file_name + '_line_' + str(indexer_textlines)) - if self.pref_of_dataset: - base_name += '_' + self.pref_of_dataset - if not self.do_not_mask_with_textline_contour: - base_name += '_masked' - - with open(base_name + '.txt', 'w') as text_file: - text_file.write(textline_text) - cv2.imwrite(base_name + '.png', img_crop) - indexer_textlines+=1 - - if not self.export_textline_images_and_text: - indexer_text_region = indexer_text_region +1 - - if not self.export_textline_images_and_text: - extracted_texts = [] - extracted_conf_value = [] - - n_iterations = math.ceil(len(cropped_lines) / self.b_s) - - for i in range(n_iterations): - if i==(n_iterations-1): - n_start = i*self.b_s - imgs = cropped_lines[n_start:] - imgs = np.array(imgs) - imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3) - - ver_imgs = np.array( cropped_lines_ver_index[n_start:] ) - indices_ver = np.where(ver_imgs == 1)[0] - - #print(indices_ver, 'indices_ver') - if len(indices_ver)>0: - imgs_ver_flipped = imgs[indices_ver, : ,: ,:] - imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - - else: - imgs_ver_flipped = None - - if dir_in_bin is not None: - imgs_bin = cropped_lines_bin[n_start:] - imgs_bin = np.array(imgs_bin) - imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3) - - if len(indices_ver)>0: - imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] - imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - - else: - imgs_bin_ver_flipped = None - else: - n_start = i*self.b_s - n_end = (i+1)*self.b_s - imgs = cropped_lines[n_start:n_end] - imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3) - - ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] ) - indices_ver = np.where(ver_imgs == 1)[0] - #print(indices_ver, 'indices_ver') - - if len(indices_ver)>0: - imgs_ver_flipped = imgs[indices_ver, : ,: ,:] - imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - else: - imgs_ver_flipped = None - - - if dir_in_bin is not None: - imgs_bin = cropped_lines_bin[n_start:n_end] - imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3) - - - if len(indices_ver)>0: - imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] - imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - else: - imgs_bin_ver_flipped = None - - - self.logger.debug("processing next 
%d lines", len(imgs)) - preds = self.prediction_model.predict(imgs, verbose=0) - - if len(indices_ver)>0: - preds_flipped = self.prediction_model.predict(imgs_ver_flipped, verbose=0) - preds_max_fliped = np.max(preds_flipped, axis=2 ) - preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) - pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character - masked_means_flipped = \ - np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) - masked_means_flipped[np.isnan(masked_means_flipped)] = 0 - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - masked_means[np.isnan(masked_means)] = 0 - - masked_means_ver = masked_means[indices_ver] - #print(masked_means_ver, 'pred_max_not_unk') - - indices_where_flipped_conf_value_is_higher = \ - np.where(masked_means_flipped > masked_means_ver)[0] - - #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') - if len(indices_where_flipped_conf_value_is_higher)>0: - indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] - preds[indices_to_be_replaced,:,:] = \ - preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] - if dir_in_bin is not None: - preds_bin = self.prediction_model.predict(imgs_bin, verbose=0) - - if len(indices_ver)>0: - preds_flipped = self.prediction_model.predict(imgs_bin_ver_flipped, verbose=0) - preds_max_fliped = np.max(preds_flipped, axis=2 ) - preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) - pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character - masked_means_flipped = \ - np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) - masked_means_flipped[np.isnan(masked_means_flipped)] = 0 - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - masked_means[np.isnan(masked_means)] = 0 - - masked_means_ver = masked_means[indices_ver] - #print(masked_means_ver, 'pred_max_not_unk') - - indices_where_flipped_conf_value_is_higher = \ - np.where(masked_means_flipped > masked_means_ver)[0] - - #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') - if len(indices_where_flipped_conf_value_is_higher)>0: - indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] - preds_bin[indices_to_be_replaced,:,:] = \ - preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] - - preds = (preds + preds_bin) / 2. 
- - pred_texts = decode_batch_predictions(preds, self.num_to_char) - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - - for ib in range(imgs.shape[0]): - pred_texts_ib = pred_texts[ib].replace("[UNK]", "") - if masked_means[ib] >= self.min_conf_value_of_textline_text: - extracted_texts.append(pred_texts_ib) - extracted_conf_value.append(masked_means[ib]) - else: - extracted_texts.append("") - extracted_conf_value.append(0) - del cropped_lines - if dir_in_bin is not None: - del cropped_lines_bin - gc.collect() - - extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_conf_value_merged = [extracted_conf_value[ind] - if cropped_lines_meging_indexing[ind]==0 - else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2. - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_conf_value_merged = [extracted_conf_value_merged[ind_cfm] - for ind_cfm in range(len(extracted_texts_merged)) - if extracted_texts_merged[ind_cfm] is not None] - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) - - if dir_out_image_text: - #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! - font = importlib_resources.files(__package__) / "Charis-Regular.ttf" - with importlib_resources.as_file(font) as font: - font = ImageFont.truetype(font=font, size=40) - - for indexer_text, bb_ind in enumerate(total_bb_coordinates): - x_bb = bb_ind[0] - y_bb = bb_ind[1] - w_bb = bb_ind[2] - h_bb = bb_ind[3] - - font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], - font.path, w_bb, int(h_bb*0.4) ) - - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] - - text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally - text_y = y_bb + (h_bb - text_height) // 2 # Center vertically - - # Draw the text - draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) - image_text.save(out_image_with_text) - - text_by_textregion = [] - for ind in unique_cropped_lines_region_indexer: - ind = np.array(cropped_lines_region_indexer)==ind - extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - if len(extracted_texts_merged_un)>1: - text_by_textregion_ind = "" - next_glue = "" - for indt in range(len(extracted_texts_merged_un)): - if (extracted_texts_merged_un[indt].endswith('⸗') or - extracted_texts_merged_un[indt].endswith('-') or - extracted_texts_merged_un[indt].endswith('¬')): - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] - next_glue = "" - else: - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] - next_glue = " " - text_by_textregion.append(text_by_textregion_ind) - else: - text_by_textregion.append(" ".join(extracted_texts_merged_un)) - #print(text_by_textregion, 
'text_by_textregiontext_by_textregiontext_by_textregiontext_by_textregiontext_by_textregion') - - ###index_tot_regions = [] - ###tot_region_ref = [] - - ###for jj in root1.iter(link+'RegionRefIndexed'): - ###index_tot_regions.append(jj.attrib['index']) - ###tot_region_ref.append(jj.attrib['regionRef']) - - ###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)} - - #id_textregions = [] - #textregions_by_existing_ids = [] - indexer = 0 - indexer_textregion = 0 - for nn in root1.iter(region_tags): - #id_textregion = nn.attrib['id'] - #id_textregions.append(id_textregion) - #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) - - is_textregion_text = False - for childtest in nn: - if childtest.tag.endswith("TextEquiv"): - is_textregion_text = True - - if not is_textregion_text: - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') - - - has_textline = False - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - is_textline_text = False - for childtest2 in child_textregion: - if childtest2.tag.endswith("TextEquiv"): - is_textline_text = True - - - if not is_textline_text: - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] - else: - for childtest3 in child_textregion: - if childtest3.tag.endswith("TextEquiv"): - for child_uc in childtest3: - if child_uc.tag.endswith("Unicode"): - childtest3.set('conf', - f"{extracted_conf_value_merged[indexer]:.2f}") - child_uc.text = extracted_texts_merged[indexer] - - indexer = indexer + 1 - has_textline = True - if has_textline: - if is_textregion_text: - for child4 in nn: - if child4.tag.endswith("TextEquiv"): - for childtr_uc in child4: - if childtr_uc.tag.endswith("Unicode"): - childtr_uc.text = text_by_textregion[indexer_textregion] - else: - unicode_textregion.text = text_by_textregion[indexer_textregion] - indexer_textregion = indexer_textregion + 1 - - ###sample_order = [(id_to_order[tid], text) - ### for tid, text in zip(id_textregions, textregions_by_existing_ids) - ### if tid in id_to_order] - - ##ordered_texts_sample = [text for _, text in sorted(sample_order)] - ##tot_page_text = ' '.join(ordered_texts_sample) - - ##for page_element in root1.iter(link+'Page'): - ##text_page = ET.SubElement(page_element, 'TextEquiv') - ##unicode_textpage = ET.SubElement(text_page, 'Unicode') - ##unicode_textpage.text = tot_page_text - - ET.register_namespace("",name_space) - tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) - #print("Job done in %.1fs", time.time() - t0) diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py new file mode 100644 index 0000000..cfd410c --- /dev/null +++ b/src/eynollah/eynollah_ocr.py @@ -0,0 +1,998 @@ +# pyright: reportPossiblyUnboundVariable=false + +from logging import Logger, getLogger +from typing import Optional +from pathlib import Path +import os +import json +import gc +import sys +import math +import time + +from keras.layers import StringLookup +import cv2 +import xml.etree.ElementTree as ET +import tensorflow as tf +from keras.models import load_model +from PIL import Image, ImageDraw, ImageFont +import numpy as np +from eynollah.model_zoo import EynollahModelZoo +try: + import torch 
+except ImportError: + torch = None + + +from .utils import is_image_filename +from .utils.resize import resize_image +from .utils.utils_ocr import ( + break_curved_line_into_small_pieces_and_then_merge, + decode_batch_predictions, + fit_text_single_line, + get_contours_and_bounding_boxes, + get_orientation_moments, + preprocess_and_resize_image_for_ocrcnn_model, + return_textlines_split_if_needed, + rotate_image_with_padding, +) + +# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files +if sys.version_info < (3, 10): + import importlib_resources +else: + import importlib.resources as importlib_resources + +try: + from transformers import TrOCRProcessor, VisionEncoderDecoderModel +except ImportError: + TrOCRProcessor = VisionEncoderDecoderModel = None + +class Eynollah_ocr: + def __init__( + self, + dir_models, + model_name=None, + dir_xmls=None, + tr_ocr=False, + batch_size: Optional[int]=None, + export_textline_images_and_text: bool=False, + do_not_mask_with_textline_contour: bool=False, + pref_of_dataset=None, + min_conf_value_of_textline_text : float=0.3, + logger: Optional[Logger]=None, + ): + self.tr_ocr = tr_ocr + self.export_textline_images_and_text = export_textline_images_and_text + self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour + self.pref_of_dataset = pref_of_dataset + self.logger = logger if logger else getLogger('eynollah') + self.model_zoo = EynollahModelZoo(basedir=dir_models) + + # TODO: Properly document what 'export_textline_images_and_text' is about + if export_textline_images_and_text: + self.logger.info("export_textline_images_and_text was set, so no actual models are loaded") + return + + self.min_conf_value_of_textline_text = min_conf_value_of_textline_text + self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size + + if tr_ocr: + self.model_zoo.load_model('trocr_processor', '') + if model_name: + self.model_zoo.load_model('ocr', 'tr', model_name) + else: + self.model_zoo.load_model('ocr', 'tr') + self.model_zoo.get('ocr').to(self.device) + else: + if model_name: + self.model_zoo.load_model('ocr', '', model_name) + else: + self.model_zoo.load_model('ocr', '') + self.model_zoo.load_model('num_to_char') + self.end_character = len(self.model_zoo.load_model('characters')) + 2 + + @property + def device(self): + if torch.cuda.is_available(): + self.logger.info("Using GPU acceleration") + return torch.device("cuda:0") + else: + self.logger.info("Using CPU processing") + return torch.device("cpu") + + def run(self, overwrite: bool = False, + dir_in: Optional[str] = None, + dir_in_bin: Optional[str] = None, + image_filename: Optional[str] = None, + dir_xmls: Optional[str] = None, + dir_out_image_text: Optional[str] = None, + dir_out: Optional[str] = None, + ): + if dir_in: + ls_imgs = [os.path.join(dir_in, image_filename) + for image_filename in filter(is_image_filename, + os.listdir(dir_in))] + else: + assert image_filename + ls_imgs = [image_filename] + + if self.tr_ocr: + tr_ocr_input_height_and_width = 384 + for dir_img in ls_imgs: + file_name = Path(dir_img).stem + assert dir_xmls # FIXME: check the logic + dir_xml = os.path.join(dir_xmls, file_name+'.xml') + assert dir_out # FIXME: check the logic + out_file_ocr = os.path.join(dir_out, file_name+'.xml') + + if os.path.exists(out_file_ocr): + if overwrite: + self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) + else: + self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) + 
continue + + img = cv2.imread(dir_img) + + if dir_out_image_text: + out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') + image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") + draw = ImageDraw.Draw(image_text) + total_bb_coordinates = [] + + ##file_name = Path(dir_xmls).stem + tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) + + + + cropped_lines = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + extracted_texts = [] + + indexer_text_region = 0 + indexer_b_s = 0 + + for nn in root1.iter(region_tags): + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( [ [int(x.split(',')[0]), + int(x.split(',')[1]) ] + for x in p_h] ) + x,y,w,h = cv2.boundingRect(textline_coords) + + if dir_out_image_text: + total_bb_coordinates.append([x,y,w,h]) + + h2w_ratio = h/float(w) + + img_poly_on_img = np.copy(img) + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + img_crop[mask_poly==0] = 255 + + self.logger.debug("processing %d lines for '%s'", + len(cropped_lines), nn.attrib['id']) + if h2w_ratio > 0.1: + cropped_lines.append(resize_image(img_crop, + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width) ) + cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + else: + splited_images, _ = return_textlines_split_if_needed(img_crop, None) + #print(splited_images) + if splited_images: + cropped_lines.append(resize_image(splited_images[0], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + + cropped_lines.append(resize_image(splited_images[1], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(-1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('processor')(imgs, 
return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + else: + cropped_lines.append(img_crop) + cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + + + indexer_text_region = indexer_text_region +1 + + if indexer_b_s!=0: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('processor').batch_decode(generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + ####extracted_texts = [] + ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) + + ####for i in range(n_iterations): + ####if i==(n_iterations-1): + ####n_start = i*self.b_s + ####imgs = cropped_lines[n_start:] + ####else: + ####n_start = i*self.b_s + ####n_end = (i+1)*self.b_s + ####imgs = cropped_lines[n_start:n_end] + ####pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values + ####generated_ids_merged = self.model_ocr.generate( + #### pixel_values_merged.to(self.device)) + ####generated_text_merged = self.model_zoo.get('processor').batch_decode( + #### generated_ids_merged, skip_special_tokens=True) + + ####extracted_texts = extracted_texts + generated_text_merged + + del cropped_lines + gc.collect() + + extracted_texts_merged = [extracted_texts[ind] + if cropped_lines_meging_indexing[ind]==0 + else extracted_texts[ind]+" "+extracted_texts[ind+1] + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] + #print(extracted_texts_merged, len(extracted_texts_merged)) + + unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) + + if dir_out_image_text: + + #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! 
+ font = importlib_resources.files(__package__) / "Charis-Regular.ttf" + with importlib_resources.as_file(font) as font: + font = ImageFont.truetype(font=font, size=40) + + for indexer_text, bb_ind in enumerate(total_bb_coordinates): + + + x_bb = bb_ind[0] + y_bb = bb_ind[1] + w_bb = bb_ind[2] + h_bb = bb_ind[3] + + font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], + font.path, w_bb, int(h_bb*0.4) ) + + ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) + + text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally + text_y = y_bb + (h_bb - text_height) // 2 # Center vertically + + # Draw the text + draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) + image_text.save(out_image_with_text) + + #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer') + #######text_by_textregion = [] + #######for ind in unique_cropped_lines_region_indexer: + #######ind = np.array(cropped_lines_region_indexer)==ind + #######extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] + #######text_by_textregion.append(" ".join(extracted_texts_merged_un)) + + text_by_textregion = [] + for ind in unique_cropped_lines_region_indexer: + ind = np.array(cropped_lines_region_indexer) == ind + extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] + if len(extracted_texts_merged_un)>1: + text_by_textregion_ind = "" + next_glue = "" + for indt in range(len(extracted_texts_merged_un)): + if (extracted_texts_merged_un[indt].endswith('⸗') or + extracted_texts_merged_un[indt].endswith('-') or + extracted_texts_merged_un[indt].endswith('¬')): + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] + next_glue = "" + else: + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] + next_glue = " " + text_by_textregion.append(text_by_textregion_ind) + else: + text_by_textregion.append(" ".join(extracted_texts_merged_un)) + + + indexer = 0 + indexer_textregion = 0 + for nn in root1.iter(region_tags): + #id_textregion = nn.attrib['id'] + #id_textregions.append(id_textregion) + #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) + + is_textregion_text = False + for childtest in nn: + if childtest.tag.endswith("TextEquiv"): + is_textregion_text = True + + if not is_textregion_text: + text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') + unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') + + + has_textline = False + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + is_textline_text = False + for childtest2 in child_textregion: + if childtest2.tag.endswith("TextEquiv"): + is_textline_text = True + + + if not is_textline_text: + text_subelement = ET.SubElement(child_textregion, 'TextEquiv') + ##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + unicode_textline = ET.SubElement(text_subelement, 'Unicode') + unicode_textline.text = extracted_texts_merged[indexer] + else: + for childtest3 in child_textregion: + if childtest3.tag.endswith("TextEquiv"): + for child_uc in childtest3: + if child_uc.tag.endswith("Unicode"): + ##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + child_uc.text = extracted_texts_merged[indexer] + + indexer = indexer + 1 + has_textline = 
True + if has_textline: + if is_textregion_text: + for child4 in nn: + if child4.tag.endswith("TextEquiv"): + for childtr_uc in child4: + if childtr_uc.tag.endswith("Unicode"): + childtr_uc.text = text_by_textregion[indexer_textregion] + else: + unicode_textregion.text = text_by_textregion[indexer_textregion] + indexer_textregion = indexer_textregion + 1 + + ###sample_order = [(id_to_order[tid], text) + ### for tid, text in zip(id_textregions, textregions_by_existing_ids) + ### if tid in id_to_order] + + ##ordered_texts_sample = [text for _, text in sorted(sample_order)] + ##tot_page_text = ' '.join(ordered_texts_sample) + + ##for page_element in root1.iter(link+'Page'): + ##text_page = ET.SubElement(page_element, 'TextEquiv') + ##unicode_textpage = ET.SubElement(text_page, 'Unicode') + ##unicode_textpage.text = tot_page_text + + ET.register_namespace("",name_space) + tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) + else: + ###max_len = 280#512#280#512 + ###padding_token = 1500#299#1500#299 + image_width = 512#max_len * 4 + image_height = 32 + + + img_size=(image_width, image_height) + + for dir_img in ls_imgs: + file_name = Path(dir_img).stem + dir_xml = os.path.join(dir_xmls, file_name+'.xml') + out_file_ocr = os.path.join(dir_out, file_name+'.xml') + + if os.path.exists(out_file_ocr): + if overwrite: + self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) + else: + self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) + continue + + img = cv2.imread(dir_img) + if dir_in_bin is not None: + cropped_lines_bin = [] + dir_img_bin = os.path.join(dir_in_bin, file_name+'.png') + img_bin = cv2.imread(dir_img_bin) + + if dir_out_image_text: + out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') + image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") + draw = ImageDraw.Draw(image_text) + total_bb_coordinates = [] + + tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) + + cropped_lines = [] + cropped_lines_ver_index = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + tinl = time.time() + indexer_text_region = 0 + indexer_textlines = 0 + for nn in root1.iter(region_tags): + try: + type_textregion = nn.attrib['type'] + except: + type_textregion = 'paragraph' + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( [ [int(x.split(',')[0]), + int(x.split(',')[1]) ] + for x in p_h] ) + + x,y,w,h = cv2.boundingRect(textline_coords) + + angle_radians = math.atan2(h, w) + # Convert to degrees + angle_degrees = math.degrees(angle_radians) + if type_textregion=='drop-capital': + angle_degrees = 0 + + if dir_out_image_text: + total_bb_coordinates.append([x,y,w,h]) + + w_scaled = w * image_height/float(h) + + img_poly_on_img = np.copy(img) + if dir_in_bin is not None: + img_poly_on_img_bin = np.copy(img_bin) + img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :] + + mask_poly = np.zeros(img.shape) + mask_poly = 
cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + + if self.export_textline_images_and_text: + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + + else: + # print(file_name, angle_degrees, w*h, + # mask_poly[:,:,0].sum(), + # mask_poly[:,:,0].sum() /float(w*h) , + # 'didi') + + if angle_degrees > 3: + better_des_slope = get_orientation_moments(textline_coords) + + img_crop = rotate_image_with_padding(img_crop, better_des_slope) + if dir_in_bin is not None: + img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope) + + mask_poly = rotate_image_with_padding(mask_poly, better_des_slope) + mask_poly = mask_poly.astype('uint8') + + #new bounding box + x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0]) + + mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :] + img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :] + + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + if dir_in_bin is not None: + img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] + if not self.do_not_mask_with_textline_contour: + img_crop_bin[mask_poly==0] = 255 + + if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: + if dir_in_bin is not None: + img_crop, img_crop_bin = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly, img_crop_bin) + else: + img_crop, _ = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly) + + else: + better_des_slope = 0 + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + if dir_in_bin is not None: + if not self.do_not_mask_with_textline_contour: + img_crop_bin[mask_poly==0] = 255 + if type_textregion=='drop-capital': + pass + else: + if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90: + if dir_in_bin is not None: + img_crop, img_crop_bin = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly, img_crop_bin) + else: + img_crop, _ = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly) + + if not self.export_textline_images_and_text: + if w_scaled < 750:#1.5*image_width: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop, image_height, image_width) + cropped_lines.append(img_fin) + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + cropped_lines_meging_indexing.append(0) + if dir_in_bin is not None: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop_bin, image_height, image_width) + cropped_lines_bin.append(img_fin) + else: + splited_images, splited_images_bin = return_textlines_split_if_needed( + img_crop, img_crop_bin if dir_in_bin is not None else None) + if splited_images: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images[0], image_height, image_width) + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(1) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images[1], image_height, image_width) + + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(-1) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + if dir_in_bin is not None: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + 
splited_images_bin[0], image_height, image_width) + cropped_lines_bin.append(img_fin) + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images_bin[1], image_height, image_width) + cropped_lines_bin.append(img_fin) + + else: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop, image_height, image_width) + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(0) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + if dir_in_bin is not None: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop_bin, image_height, image_width) + cropped_lines_bin.append(img_fin) + + if self.export_textline_images_and_text: + if img_crop.shape[0]==0 or img_crop.shape[1]==0: + pass + else: + if child_textlines.tag.endswith("TextEquiv"): + for cheild_text in child_textlines: + if cheild_text.tag.endswith("Unicode"): + textline_text = cheild_text.text + if textline_text: + base_name = os.path.join( + dir_out, file_name + '_line_' + str(indexer_textlines)) + if self.pref_of_dataset: + base_name += '_' + self.pref_of_dataset + if not self.do_not_mask_with_textline_contour: + base_name += '_masked' + + with open(base_name + '.txt', 'w') as text_file: + text_file.write(textline_text) + cv2.imwrite(base_name + '.png', img_crop) + indexer_textlines+=1 + + if not self.export_textline_images_and_text: + indexer_text_region = indexer_text_region +1 + + if not self.export_textline_images_and_text: + extracted_texts = [] + extracted_conf_value = [] + + n_iterations = math.ceil(len(cropped_lines) / self.b_s) + + for i in range(n_iterations): + if i==(n_iterations-1): + n_start = i*self.b_s + imgs = cropped_lines[n_start:] + imgs = np.array(imgs) + imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3) + + ver_imgs = np.array( cropped_lines_ver_index[n_start:] ) + indices_ver = np.where(ver_imgs == 1)[0] + + #print(indices_ver, 'indices_ver') + if len(indices_ver)>0: + imgs_ver_flipped = imgs[indices_ver, : ,: ,:] + imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + + else: + imgs_ver_flipped = None + + if dir_in_bin is not None: + imgs_bin = cropped_lines_bin[n_start:] + imgs_bin = np.array(imgs_bin) + imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3) + + if len(indices_ver)>0: + imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] + imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + + else: + imgs_bin_ver_flipped = None + else: + n_start = i*self.b_s + n_end = (i+1)*self.b_s + imgs = cropped_lines[n_start:n_end] + imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3) + + ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] ) + indices_ver = np.where(ver_imgs == 1)[0] + #print(indices_ver, 'indices_ver') + + if len(indices_ver)>0: + imgs_ver_flipped = imgs[indices_ver, : ,: ,:] + imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + else: + imgs_ver_flipped = None + + + if dir_in_bin is not None: + imgs_bin = cropped_lines_bin[n_start:n_end] + imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3) + + + if len(indices_ver)>0: + imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] + imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + else: + imgs_bin_ver_flipped = None + + + self.logger.debug("processing next 
%d lines", len(imgs)) + preds = self.model_zoo.get('ocr').predict(imgs, verbose=0) + + if len(indices_ver)>0: + preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0) + preds_max_fliped = np.max(preds_flipped, axis=2 ) + preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) + pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character + masked_means_flipped = \ + np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) + masked_means_flipped[np.isnan(masked_means_flipped)] = 0 + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + masked_means[np.isnan(masked_means)] = 0 + + masked_means_ver = masked_means[indices_ver] + #print(masked_means_ver, 'pred_max_not_unk') + + indices_where_flipped_conf_value_is_higher = \ + np.where(masked_means_flipped > masked_means_ver)[0] + + #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') + if len(indices_where_flipped_conf_value_is_higher)>0: + indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] + preds[indices_to_be_replaced,:,:] = \ + preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] + if dir_in_bin is not None: + preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0) + + if len(indices_ver)>0: + preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0) + preds_max_fliped = np.max(preds_flipped, axis=2 ) + preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) + pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character + masked_means_flipped = \ + np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) + masked_means_flipped[np.isnan(masked_means_flipped)] = 0 + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + masked_means[np.isnan(masked_means)] = 0 + + masked_means_ver = masked_means[indices_ver] + #print(masked_means_ver, 'pred_max_not_unk') + + indices_where_flipped_conf_value_is_higher = \ + np.where(masked_means_flipped > masked_means_ver)[0] + + #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') + if len(indices_where_flipped_conf_value_is_higher)>0: + indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] + preds_bin[indices_to_be_replaced,:,:] = \ + preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] + + preds = (preds + preds_bin) / 2. 
+ + pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char')) + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + + for ib in range(imgs.shape[0]): + pred_texts_ib = pred_texts[ib].replace("[UNK]", "") + if masked_means[ib] >= self.min_conf_value_of_textline_text: + extracted_texts.append(pred_texts_ib) + extracted_conf_value.append(masked_means[ib]) + else: + extracted_texts.append("") + extracted_conf_value.append(0) + del cropped_lines + if dir_in_bin is not None: + del cropped_lines_bin + gc.collect() + + extracted_texts_merged = [extracted_texts[ind] + if cropped_lines_meging_indexing[ind]==0 + else extracted_texts[ind]+" "+extracted_texts[ind+1] + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_conf_value_merged = [extracted_conf_value[ind] + if cropped_lines_meging_indexing[ind]==0 + else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2. + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_conf_value_merged = [extracted_conf_value_merged[ind_cfm] + for ind_cfm in range(len(extracted_texts_merged)) + if extracted_texts_merged[ind_cfm] is not None] + extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] + unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) + + if dir_out_image_text: + #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! + font = importlib_resources.files(__package__) / "Charis-Regular.ttf" + with importlib_resources.as_file(font) as font: + font = ImageFont.truetype(font=font, size=40) + + for indexer_text, bb_ind in enumerate(total_bb_coordinates): + x_bb = bb_ind[0] + y_bb = bb_ind[1] + w_bb = bb_ind[2] + h_bb = bb_ind[3] + + font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], + font.path, w_bb, int(h_bb*0.4) ) + + ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) + + text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally + text_y = y_bb + (h_bb - text_height) // 2 # Center vertically + + # Draw the text + draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) + image_text.save(out_image_with_text) + + text_by_textregion = [] + for ind in unique_cropped_lines_region_indexer: + ind = np.array(cropped_lines_region_indexer)==ind + extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] + if len(extracted_texts_merged_un)>1: + text_by_textregion_ind = "" + next_glue = "" + for indt in range(len(extracted_texts_merged_un)): + if (extracted_texts_merged_un[indt].endswith('⸗') or + extracted_texts_merged_un[indt].endswith('-') or + extracted_texts_merged_un[indt].endswith('¬')): + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] + next_glue = "" + else: + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] + next_glue = " " + text_by_textregion.append(text_by_textregion_ind) + else: + text_by_textregion.append(" ".join(extracted_texts_merged_un)) + #print(text_by_textregion, 
'text_by_textregiontext_by_textregiontext_by_textregiontext_by_textregiontext_by_textregion') + + ###index_tot_regions = [] + ###tot_region_ref = [] + + ###for jj in root1.iter(link+'RegionRefIndexed'): + ###index_tot_regions.append(jj.attrib['index']) + ###tot_region_ref.append(jj.attrib['regionRef']) + + ###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)} + + #id_textregions = [] + #textregions_by_existing_ids = [] + indexer = 0 + indexer_textregion = 0 + for nn in root1.iter(region_tags): + #id_textregion = nn.attrib['id'] + #id_textregions.append(id_textregion) + #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) + + is_textregion_text = False + for childtest in nn: + if childtest.tag.endswith("TextEquiv"): + is_textregion_text = True + + if not is_textregion_text: + text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') + unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') + + + has_textline = False + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + is_textline_text = False + for childtest2 in child_textregion: + if childtest2.tag.endswith("TextEquiv"): + is_textline_text = True + + + if not is_textline_text: + text_subelement = ET.SubElement(child_textregion, 'TextEquiv') + text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + unicode_textline = ET.SubElement(text_subelement, 'Unicode') + unicode_textline.text = extracted_texts_merged[indexer] + else: + for childtest3 in child_textregion: + if childtest3.tag.endswith("TextEquiv"): + for child_uc in childtest3: + if child_uc.tag.endswith("Unicode"): + childtest3.set('conf', + f"{extracted_conf_value_merged[indexer]:.2f}") + child_uc.text = extracted_texts_merged[indexer] + + indexer = indexer + 1 + has_textline = True + if has_textline: + if is_textregion_text: + for child4 in nn: + if child4.tag.endswith("TextEquiv"): + for childtr_uc in child4: + if childtr_uc.tag.endswith("Unicode"): + childtr_uc.text = text_by_textregion[indexer_textregion] + else: + unicode_textregion.text = text_by_textregion[indexer_textregion] + indexer_textregion = indexer_textregion + 1 + + ###sample_order = [(id_to_order[tid], text) + ### for tid, text in zip(id_textregions, textregions_by_existing_ids) + ### if tid in id_to_order] + + ##ordered_texts_sample = [text for _, text in sorted(sample_order)] + ##tot_page_text = ' '.join(ordered_texts_sample) + + ##for page_element in root1.iter(link+'Page'): + ##text_page = ET.SubElement(page_element, 'TextEquiv') + ##unicode_textpage = ET.SubElement(text_page, 'Unicode') + ##unicode_textpage.text = tot_page_text + + ET.register_namespace("",name_space) + tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) + #print("Job done in %.1fs", time.time() - t0) diff --git a/src/eynollah/image_enhancer.py b/src/eynollah/image_enhancer.py index 9247efe..cec8877 100644 --- a/src/eynollah/image_enhancer.py +++ b/src/eynollah/image_enhancer.py @@ -5,24 +5,25 @@ Image enhancer. 
The output can be written as same scale of input or in new predi from logging import Logger import os import time -from typing import Optional +from typing import Dict, Optional from pathlib import Path import gc import cv2 +from keras.models import Model import numpy as np from ocrd_utils import getLogger, tf_disable_interactive_logs import tensorflow as tf from skimage.morphology import skeletonize -from tensorflow.keras.models import load_model +from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.pil_cv2 import pil2cv from .utils import ( is_image_filename, crop_image_inside_box ) -from .eynollah import PatchEncoder, Patches +from .patch_encoder import PatchEncoder, Patches DPI_THRESHOLD = 298 KERNEL = np.ones((5, 5), np.uint8) @@ -50,11 +51,9 @@ class Enhancer: self.num_col_lower = num_col_lower self.logger = logger if logger else getLogger('enhancement') - self.dir_models = dir_models - self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425" - self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425" - self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425" - self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915" + self.model_zoo = EynollahModelZoo(basedir=dir_models) + for v in ['binarization', 'enhancement', 'col_classifier', 'page']: + self.model_zoo.load_model(v) try: for device in tf.config.list_physical_devices('GPU'): @@ -62,11 +61,6 @@ class Enhancer: except: self.logger.warning("no GPU device available") - self.model_page = self.our_load_model(self.model_page_dir) - self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) - self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) - self.model_bin = self.our_load_model(self.model_dir_of_binarization) - def cache_images(self, image_filename=None, image_pil=None, dpi=None): ret = {} if image_filename: @@ -102,24 +96,12 @@ class Enhancer: def isNaN(self, num): return num != num - - @staticmethod - def our_load_model(model_file): - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model def predict_enhancement(self, img): self.logger.debug("enter predict_enhancement") - img_height_model = self.model_enhancement.layers[-1].output_shape[1] - img_width_model = self.model_enhancement.layers[-1].output_shape[2] + img_height_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[1] + img_width_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[2] if img.shape[0] < img_height_model: img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) if img.shape[1] < img_width_model: @@ -160,7 +142,7 @@ class Enhancer: index_y_d = img_h - img_height_model img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :] - label_p_pred = self.model_enhancement.predict(img_patch, verbose=0) + label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose=0) seg = label_p_pred[0, :, :, :] * 255 if i == 0 and j == 0: @@ -246,7 +228,7 @@ class Enhancer: else: img = self.imread() img = cv2.GaussianBlur(img, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.model_page) + 
img_page_prediction = self.do_prediction(False, img, self.model_zoo.get('page')) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) @@ -291,7 +273,7 @@ class Enhancer: self.logger.info("Detected %s DPI", dpi) if self.input_binary: img = self.imread() - prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5) + prediction_bin = self.do_prediction(True, img, self.model_zoo.get('binarization'), n_batch_inference=5) prediction_bin = 255 * (prediction_bin[:,:,0]==0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8) img= np.copy(prediction_bin) @@ -332,7 +314,7 @@ class Enhancer: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.model_classifier.predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): if self.input_binary: @@ -352,7 +334,7 @@ class Enhancer: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.model_classifier.predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 if num_col > self.num_col_upper: diff --git a/src/eynollah/mb_ro_on_layout.py b/src/eynollah/mb_ro_on_layout.py index 1b991ae..8338d35 100644 --- a/src/eynollah/mb_ro_on_layout.py +++ b/src/eynollah/mb_ro_on_layout.py @@ -10,12 +10,13 @@ from pathlib import Path import xml.etree.ElementTree as ET import cv2 +from keras.models import Model import numpy as np from ocrd_utils import getLogger import statistics import tensorflow as tf -from tensorflow.keras.models import load_model +from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.contour import ( find_new_features_of_contours, @@ -23,7 +24,6 @@ from .utils.contour import ( return_parent_contours, ) from .utils import is_xml_filename -from .eynollah import PatchEncoder, Patches DPI_THRESHOLD = 298 KERNEL = np.ones((5, 5), np.uint8) @@ -45,21 +45,11 @@ class machine_based_reading_order_on_layout: except: self.logger.warning("no GPU device available") - self.model_reading_order = self.our_load_model(self.model_reading_order_dir) + self.model_zoo = EynollahModelZoo(basedir=dir_models) + self.model_zoo.load_model('reading_order') + # FIXME: light_version is always true, no need for checks in the code self.light_version = True - @staticmethod - def our_load_model(model_file): - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model - def read_xml(self, xml_file): tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) root1=tree1.getroot() @@ -69,6 +59,7 @@ class machine_based_reading_order_on_layout: index_tot_regions = [] tot_region_ref = [] + y_len, x_len = 0, 0 for jj in root1.iter(link+'Page'): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) @@ -81,13 +72,13 @@ class machine_based_reading_order_on_layout: co_printspace = [] if link+'PrintSpace' in alltags: region_tags_printspace = np.unique([x for x in alltags if 
x.endswith('PrintSpace')]) - elif link+'Border' in alltags: + else: region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')]) for tag in region_tags_printspace: if link+'PrintSpace' in alltags: tag_endings_printspace = ['}PrintSpace','}printspace'] - elif link+'Border' in alltags: + else: tag_endings_printspace = ['}Border','}border'] if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]): @@ -683,7 +674,7 @@ class machine_based_reading_order_on_layout: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.model_reading_order.predict(input_1 , verbose=0) + y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0') for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: post_list.append(j) @@ -802,6 +793,7 @@ class machine_based_reading_order_on_layout: alltags=[elem.tag for elem in root_xml.iter()] ET.register_namespace("",name_space) + assert dir_out tree_xml.write(os.path.join(dir_out, file_name+'.xml'), xml_declaration=True, method='xml', diff --git a/src/eynollah/model_zoo/__init__.py b/src/eynollah/model_zoo/__init__.py new file mode 100644 index 0000000..e1dc985 --- /dev/null +++ b/src/eynollah/model_zoo/__init__.py @@ -0,0 +1,4 @@ +__all__ = [ + 'EynollahModelZoo', +] +from .model_zoo import EynollahModelZoo diff --git a/src/eynollah/model_zoo/default_specs.py b/src/eynollah/model_zoo/default_specs.py new file mode 100644 index 0000000..e06c829 --- /dev/null +++ b/src/eynollah/model_zoo/default_specs.py @@ -0,0 +1,314 @@ +from .specs import EynollahModelSpec, EynollahModelSpecSet +from .types import KerasModel, TrOCRProcessor, List + +# NOTE: This needs to change whenever models/versions change +ZENODO = "https://zenodo.org/records/17295988/files" +MODELS_VERSION = "v0_7_0" + +def dist_url(dist_name: str) -> str: + return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip' + +DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ + + EynollahModelSpec( + category="enhancement", + variant='', + filename="models_eynollah/eynollah-enhancement_20210425", + dists=['enhancement', 'layout'], + dist_url=dist_url("enhancement"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization", + variant='', + filename="models_eynollah/eynollah-binarization-hybrid_20230504", + dists=['layout', 'binarization'], + dist_url=dist_url("binarization"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization", + variant='20210309', + filename="models_eynollah/eynollah-binarization_20210309", + dists=['binarization'], + dist_url=dist_url("binarization"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization", + variant='augment', + filename="models_eynollah/eynollah-binarization_20210425", + dists=['binarization'], + dist_url=dist_url("binarization"), + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization_multi_1", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin1", + dist_url=dist_url("binarization"), + dists=['binarization'], + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization_multi_2", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin2", + dist_url=dist_url("binarization"), + dists=['binarization'], + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization_multi_3", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin3", + dist_url=dist_url("binarization"), + 
dists=['binarization'], + type=KerasModel, + ), + + EynollahModelSpec( + category="binarization_multi_4", + variant='', + filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin4", + dist_url=dist_url("binarization"), + dists=['binarization'], + type=KerasModel, + ), + + EynollahModelSpec( + category="col_classifier", + variant='', + filename="models_eynollah/eynollah-column-classifier_20210425", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="page", + variant='', + filename="models_eynollah/model_eynollah_page_extraction_20250915", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region", + variant='', + filename="models_eynollah/eynollah-main-regions-ensembled_20210425", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region", + variant='extract_only_images', + filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region", + variant='light', + filename="models_eynollah/eynollah-main-regions_20220314", + dist_url=dist_url("layout"), + help="early layout", + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region_p2", + variant='', + filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425", + dist_url=dist_url("layout"), + help="early layout, non-light, 2nd part", + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="region_1_2", + variant='', + #filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n", + #filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8", + #filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18", + #filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige", + #filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige", + filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024", + dist_url=dist_url("layout"), + dists=['layout'], + help="early layout, light, 1-or-2-column", + type=KerasModel, + ), + + EynollahModelSpec( + category="region_fl_np", + variant='', + #'filename="models_eynollah/modelens_full_lay_1_3_031124", + #'filename="models_eynollah/modelens_full_lay_13__3_19_241024", + #'filename="models_eynollah/model_full_lay_13_241024", + #'filename="models_eynollah/modelens_full_lay_13_17_231024", + #'filename="models_eynollah/modelens_full_lay_1_2_221024", + #'filename="models_eynollah/eynollah-full-regions-1column_20210425", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist_url=dist_url("layout"), + help="full layout / no patches", + dists=['layout'], + type=KerasModel, + ), + + # FIXME: Why is region_fl and region_fl_np the same model? 
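As a quick illustration of the URL scheme, `dist_url()` only interpolates the `ZENODO` and `MODELS_VERSION` constants defined at the top of this module:

```python
dist_url("binarization")
# -> "https://zenodo.org/records/17295988/files/models_binarization_v0_7_0.zip"
dist_url("layout")
# -> "https://zenodo.org/records/17295988/files/models_layout_v0_7_0.zip"
```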
+ EynollahModelSpec( + category="region_fl", + variant='', + # filename="models_eynollah/eynollah-full-regions-3+column_20210425", + # filename="models_eynollah/model_2_full_layout_new_trans", + # filename="models_eynollah/modelens_full_lay_1_3_031124", + # filename="models_eynollah/modelens_full_lay_13__3_19_241024", + # filename="models_eynollah/model_full_lay_13_241024", + # filename="models_eynollah/modelens_full_lay_13_17_231024", + # filename="models_eynollah/modelens_full_lay_1_2_221024", + # filename="models_eynollah/modelens_full_layout_24_till_28", + # filename="models_eynollah/model_2_full_layout_new_trans", + filename="models_eynollah/modelens_full_lay_1__4_3_091124", + dist_url=dist_url("layout"), + help="full layout / with patches", + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="reading_order", + variant='', + #filename="models_eynollah/model_mb_ro_aug_ens_11", + #filename="models_eynollah/model_step_3200000_mb_ro", + #filename="models_eynollah/model_ens_reading_order_machine_based", + #filename="models_eynollah/model_mb_ro_aug_ens_8", + #filename="models_eynollah/model_ens_reading_order_machine_based", + filename="models_eynollah/model_eynollah_reading_order_20250824", + dist_url=dist_url("reading_order"), + dists=['layout', 'reading_order'], + type=KerasModel, + ), + + EynollahModelSpec( + category="textline", + variant='', + #filename="models_eynollah/modelens_textline_1_4_16092024", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_1_3_4_20240915", + #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial", + #filename="models_eynollah/modelens_textline_9_12_13_14_15", + #filename="models_eynollah/eynollah-textline_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="textline", + variant='light', + #filename="models_eynollah/eynollah-textline_light_20210425", + filename="models_eynollah/modelens_textline_0_1__2_4_16092024", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="table", + variant='', + filename="models_eynollah/eynollah-tables_20210319", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="table", + variant='light', + filename="models_eynollah/modelens_table_0t4_201124", + dist_url=dist_url("layout"), + dists=['layout'], + type=KerasModel, + ), + + EynollahModelSpec( + category="ocr", + variant='', + filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930", + dist_url=dist_url("ocr"), + dists=['layout', 'ocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="ocr", + variant='degraded', + filename="models_eynollah/model_eynollah_ocr_cnnrnn__degraded_20250805/", + help="slightly better at degraded Fraktur", + dist_url=dist_url("ocr"), + dists=['ocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="num_to_char", + variant='', + filename="characters_org.txt", + dist_url=dist_url("ocr"), + dists=['ocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="characters", + variant='', + filename="characters_org.txt", + dist_url=dist_url("ocr"), + dists=['ocr'], + type=list, + ), + + EynollahModelSpec( + category="ocr", + variant='tr', + filename="models_eynollah/model_eynollah_ocr_trocr_20250919", + dist_url=dist_url("trocr"), + help='much slower transformer-based', + 
dists=['trocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="trocr_processor", + variant='', + filename="models_eynollah/microsoft/trocr-base-printed", + dist_url=dist_url("trocr"), + dists=['trocr'], + type=KerasModel, + ), + + EynollahModelSpec( + category="trocr_processor", + variant='htr', + filename="models_eynollah/microsoft/trocr-base-handwritten", + dist_url=dist_url("trocr"), + dists=['trocr'], + type=TrOCRProcessor, + ), + +]) diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py new file mode 100644 index 0000000..8948a1f --- /dev/null +++ b/src/eynollah/model_zoo/model_zoo.py @@ -0,0 +1,189 @@ +import json +import logging +from copy import deepcopy +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Type, Union + +from keras.layers import StringLookup +from keras.models import Model as KerasModel +from keras.models import load_model +from tabulate import tabulate +from transformers import TrOCRProcessor, VisionEncoderDecoderModel + +from ..patch_encoder import PatchEncoder, Patches +from .specs import EynollahModelSpecSet +from .default_specs import DEFAULT_MODEL_SPECS +from .types import AnyModel, T + + +class EynollahModelZoo: + """ + Wrapper class that handles storage and loading of models for all eynollah runners. + """ + + model_basedir: Path + specs: EynollahModelSpecSet + + def __init__( + self, + basedir: str, + model_overrides: Optional[List[Tuple[str, str, str]]] = None, + ) -> None: + self.model_basedir = Path(basedir) + self.logger = logging.getLogger('eynollah.model_zoo') + self.specs = deepcopy(DEFAULT_MODEL_SPECS) + if model_overrides: + self.override_models(*model_overrides) + self._loaded: Dict[str, AnyModel] = {} + + def override_models( + self, + *model_overrides: Tuple[str, str, str], + ): + """ + Override the default model versions + """ + for model_category, model_variant, model_filename in model_overrides: + spec = self.specs.get(model_category, model_variant) + self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename) + self.specs.get(model_category, model_variant).filename = model_filename + + def model_path( + self, + model_category: str, + model_variant: str = '', + absolute: bool = True, + ) -> Path: + """ + Translate model_{type,variant} tuple into an absolute (or relative) Path + """ + spec = self.specs.get(model_category, model_variant) + if spec.category in ('characters', 'num_to_char'): + return self.model_path('ocr') / spec.filename + if not Path(spec.filename).is_absolute() and absolute: + model_path = Path(self.model_basedir).joinpath(spec.filename) + else: + model_path = Path(spec.filename) + return model_path + + def load_models( + self, + *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]], + ) -> Dict: + """ + Load all models by calling load_model and return a dictionary mapping model_category to loaded model + """ + ret = {} + for load_args in all_load_args: + if isinstance(load_args, str): + ret[load_args] = self.load_model(load_args) + else: + ret[load_args[0]] = self.load_model(*load_args) + return ret + + def load_model( + self, + model_category: str, + model_variant: str = '', + ) -> AnyModel: + """ + Load any model + """ + model_path = self.model_path(model_category, model_variant) + if model_path.suffix == '.h5' and Path(model_path.stem).exists(): + # prefer SavedModel over HDF5 format if it exists + model_path = Path(model_path.stem) + if model_category == 'ocr': + model = 
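A sketch of batch loading with a filename override; the base directory and the override filename below are hypothetical and only illustrate the `(category, variant, filename)` tuple shape expected by `model_overrides`:

```python
from eynollah.model_zoo import EynollahModelZoo

zoo = EynollahModelZoo(
    basedir="/srv/models_eynollah",  # hypothetical path
    model_overrides=[("textline", "light", "models_eynollah/my_textline_experiment")],
)
loaded = zoo.load_models(
    "enhancement",          # plain category string
    ("textline", "light"),  # (category, variant) tuple
)
# load_models() returns a dict keyed by category; the textline entry would be
# resolved against the overridden filename.
```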
self._load_ocr_model(variant=model_variant) + elif model_category == 'num_to_char': + model = self._load_num_to_char() + elif model_category == 'characters': + model = self._load_characters() + elif model_category == 'trocr_processor': + return TrOCRProcessor.from_pretrained(self.model_path(...)) + else: + try: + model = load_model(model_path, compile=False) + except Exception as e: + self.logger.exception(e) + model = load_model( + model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches} + ) + self._loaded[model_category] = model + return model # type: ignore + + def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T: + if model_category not in self._loaded: + raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"') + ret = self._loaded[model_category] + if model_type: + assert isinstance(ret, model_type) + return ret # type: ignore # FIXME: convince typing that we're returning generic type + + def _load_ocr_model(self, variant: str) -> AnyModel: + """ + Load OCR model + """ + ocr_model_dir = self.model_path('ocr', variant) + if variant == 'tr': + return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) + else: + ocr_model = load_model(ocr_model_dir, compile=False) + assert isinstance(ocr_model, KerasModel) + return KerasModel( + ocr_model.get_layer(name="image").input, # type: ignore + ocr_model.get_layer(name="dense2").output, # type: ignore + ) + + def _load_characters(self) -> List[str]: + """ + Load encoding for OCR + """ + with open(self.model_path('num_to_char'), "r") as config_file: + return json.load(config_file) + + def _load_num_to_char(self) -> StringLookup: + """ + Load decoder for OCR + """ + characters = self._load_characters() + # Mapping characters to integers. + char_to_num = StringLookup(vocabulary=characters, mask_token=None) + # Mapping integers back to original characters. + return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True) + + def __str__(self): + return tabulate( + [ + [ + spec.type.__name__, + spec.category, + spec.variant, + spec.help, + ', '.join(spec.dists), + f'Yes, at {self.model_path(spec.category, spec.variant)}' + if self.model_path(spec.category, spec.variant).exists() + else f'No, download {spec.dist_url}', + # self.model_path(spec.category, spec.variant), + ] + for spec in self.specs.specs + ], + headers=[ + 'Type', + 'Category', + 'Variant', + 'Help', + 'Used in', + 'Installed', + ], + tablefmt='github', + ) + + def shutdown(self): + """ + Ensure that a loaded models is not referenced by ``self._loaded`` anymore + """ + if hasattr(self, '_loaded') and getattr(self, '_loaded'): + for needle in self._loaded: + if self._loaded[needle]: + del self._loaded[needle] diff --git a/src/eynollah/model_zoo/specs.py b/src/eynollah/model_zoo/specs.py new file mode 100644 index 0000000..322afa4 --- /dev/null +++ b/src/eynollah/model_zoo/specs.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import Dict, List, Set, Tuple, Type +from .types import AnyModel + + +@dataclass +class EynollahModelSpec(): + """ + Describing a single model abstractly. 
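For context, a sketch of how the inverted `StringLookup` behind `num_to_char` can turn predicted character indices back into text; the index values are made up and the base directory is illustrative, assuming the OCR distribution (including `characters_org.txt`) is installed:

```python
import tensorflow as tf
from eynollah.model_zoo import EynollahModelZoo

zoo = EynollahModelZoo(basedir="/srv/models_eynollah")  # illustrative path
num_to_char = zoo.load_model("num_to_char")             # inverted StringLookup

indices = tf.constant([5, 12, 3])                       # made-up index sequence
text = tf.strings.reduce_join(num_to_char(indices)).numpy().decode("utf-8")
```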
+ """ + category: str + # Relative filename to the models_eynollah directory in the dists + filename: str + # basename of the ZIP files that should contain this model + dists: List[str] + # URL to the smallest model distribution containing this model (link to Zenodo) + dist_url: str + type: Type[AnyModel] + variant: str = '' + help: str = '' + +class EynollahModelSpecSet(): + """ + List of all used models for eynollah. + """ + specs: List[EynollahModelSpec] + + def __init__(self, specs: List[EynollahModelSpec]) -> None: + self.specs = sorted(specs, key=lambda x: x.category + '0' + x.variant) + self.categories: Set[str] = set([spec.category for spec in self.specs]) + self.variants: Dict[str, Set[str]] = { + spec.category: set([x.variant for x in self.specs if x.category == spec.category]) + for spec in self.specs + } + self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = { + (spec.category, spec.variant): spec + for spec in self.specs + } + + def asdict(self) -> Dict[str, Dict[str, str]]: + return { + spec.category: { + spec.variant: spec.filename + } + for spec in self.specs + } + + def get(self, category: str, variant: str) -> EynollahModelSpec: + if category not in self.categories: + raise ValueError(f"Unknown category '{category}', must be one of {self.categories}") + if variant not in self.variants[category]: + raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}") + return self._index_category_variant[(category, variant)] + + diff --git a/src/eynollah/model_zoo/types.py b/src/eynollah/model_zoo/types.py new file mode 100644 index 0000000..5c3685e --- /dev/null +++ b/src/eynollah/model_zoo/types.py @@ -0,0 +1,6 @@ +from typing import List, TypeVar, Union +from keras.models import Model as KerasModel +from transformers import TrOCRProcessor, VisionEncoderDecoderModel + +AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List] +T = TypeVar('T') diff --git a/src/eynollah/ocrd-tool.json b/src/eynollah/ocrd-tool.json index dbbdc3b..3d1193d 100644 --- a/src/eynollah/ocrd-tool.json +++ b/src/eynollah/ocrd-tool.json @@ -83,10 +83,10 @@ }, "resources": [ { - "url": "https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1", - "name": "models_layout_v0_5_0", + "url": "https://zenodo.org/records/17295988/files/models_layout_v0_6_0.tar.gz?download=1", + "name": "models_layout_v0_6_0", "type": "archive", - "path_in_archive": "models_layout_v0_5_0", + "path_in_archive": "models_layout_v0_6_0", "size": 3525684179, "description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement", "version_range": ">= v0.5.0" diff --git a/src/eynollah/patch_encoder.py b/src/eynollah/patch_encoder.py new file mode 100644 index 0000000..939ad7b --- /dev/null +++ b/src/eynollah/patch_encoder.py @@ -0,0 +1,52 @@ +from keras import layers +import tensorflow as tf + +projection_dim = 64 +patch_size = 1 +num_patches =21*21#14*14#28*28#14*14#28*28 + +class PatchEncoder(layers.Layer): + + def __init__(self): + super().__init__() + self.projection = layers.Dense(units=projection_dim) + self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim) + + def call(self, patch): + positions = tf.range(start=0, limit=num_patches, delta=1) + encoded = self.projection(patch) + self.position_embedding(positions) + return encoded + + def get_config(self): + config = 
super().get_config().copy() + config.update({ + 'num_patches': num_patches, + 'projection': self.projection, + 'position_embedding': self.position_embedding, + }) + return config + +class Patches(layers.Layer): + def __init__(self, **kwargs): + super(Patches, self).__init__() + self.patch_size = patch_size + + def call(self, images): + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size, self.patch_size, 1], + strides=[1, self.patch_size, self.patch_size, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size': self.patch_size, + }) + return config diff --git a/src/eynollah/plot.py b/src/eynollah/plot.py index c026e94..b1b2359 100644 --- a/src/eynollah/plot.py +++ b/src/eynollah/plot.py @@ -40,8 +40,8 @@ class EynollahPlotter: self.image_filename_stem = image_filename_stem # XXX TODO hacky these cannot be set at init time self.image_org = image_org - self.scale_x = scale_x - self.scale_y = scale_y + self.scale_x : float = scale_x + self.scale_y : float = scale_y def save_plot_of_layout_main(self, text_regions_p, image_page): if self.dir_of_layout is not None: diff --git a/src/eynollah/sbb_binarize.py b/src/eynollah/sbb_binarize.py index 3716987..48dc7b1 100644 --- a/src/eynollah/sbb_binarize.py +++ b/src/eynollah/sbb_binarize.py @@ -2,18 +2,19 @@ Tool to load model and binarize a given image. """ -import sys -from glob import glob import os import logging +from pathlib import Path +from typing import Dict, List +from keras.models import Model import numpy as np -from PIL import Image import cv2 from ocrd_utils import tf_disable_interactive_logs + +from eynollah.model_zoo import EynollahModelZoo tf_disable_interactive_logs() import tensorflow as tf -from tensorflow.keras.models import load_model from tensorflow.python.keras import backend as tensorflow_backend from .utils import is_image_filename @@ -23,40 +24,37 @@ def resize_image(img_in, input_height, input_width): class SbbBinarizer: - def __init__(self, model_dir, logger=None): - self.model_dir = model_dir + def __init__(self, model_dir: str, mode: str, logger=None): + if mode not in ('single', 'multi'): + raise ValueError(f"'mode' must be either 'multi' or 'single', not {mode}") self.log = logger if logger else logging.getLogger('SbbBinarizer') - - self.start_new_session() - - self.model_files = glob(self.model_dir+"/*/", recursive = True) - - self.models = [] - for model_file in self.model_files: - self.models.append(self.load_model(model_file)) + self.model_zoo = EynollahModelZoo(basedir=model_dir) + self.models = self.setup_models(mode) + self.session = self.start_new_session() def start_new_session(self): config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True - self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() - tensorflow_backend.set_session(self.session) + session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() + tensorflow_backend.set_session(session) + return session + + def setup_models(self, mode: str) -> Dict[Path, Model]: + return { + self.model_zoo.model_path(v): self.model_zoo.load_model(v) + for v in (['binarization'] if mode == 'single' else [f'binarization_multi_{i}' for i in range(1, 5)]) + } def end_session(self): tensorflow_backend.clear_session() self.session.close() del self.session - def 
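A minimal sketch of the new constructor contract for `SbbBinarizer`; the model directory is illustrative:

```python
from eynollah.sbb_binarize import SbbBinarizer

# mode="single" loads only the hybrid "binarization" model,
# mode="multi" loads binarization_multi_1 .. binarization_multi_4;
# any other value raises ValueError.
binarizer = SbbBinarizer(model_dir="/srv/models_eynollah", mode="multi")
for model_path, model in binarizer.models.items():
    print(model_path, model.count_params())
```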
load_model(self, model_name): - model = load_model(os.path.join(self.model_dir, model_name), compile=False) + def predict(self, img, use_patches, n_batch_inference=5): + model = self.model_zoo.get('binarization', Model) model_height = model.layers[len(model.layers)-1].output_shape[1] model_width = model.layers[len(model.layers)-1].output_shape[2] - n_classes = model.layers[len(model.layers)-1].output_shape[3] - return model, model_height, model_width, n_classes - - def predict(self, model_in, img, use_patches, n_batch_inference=5): - tensorflow_backend.set_session(self.session) - model, model_height, model_width, n_classes = model_in img_org_h = img.shape[0] img_org_w = img.shape[1] @@ -324,8 +322,8 @@ class SbbBinarizer: if image_path is not None: image = cv2.imread(image_path) img_last = 0 - for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): - self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) + for n, (model_file, model) in enumerate(self.models.items()): + self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys()))) res = self.predict(model, image, use_patches) @@ -354,8 +352,8 @@ class SbbBinarizer: print(image_name,'image_name') image = cv2.imread(os.path.join(dir_in,image_name) ) img_last = 0 - for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): - self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) + for n, (model_file, model) in enumerate(self.models.items()): + self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys()))) res = self.predict(model, image, use_patches) diff --git a/src/eynollah/utils/__init__.py b/src/eynollah/utils/__init__.py index 5ccb2af..94f6983 100644 --- a/src/eynollah/utils/__init__.py +++ b/src/eynollah/utils/__init__.py @@ -393,7 +393,12 @@ def find_num_col_deskew(regions_without_separators, sigma_, multiplier=3.8): z = gaussian_filter1d(regions_without_separators_0, sigma_) return np.std(z) -def find_num_col(regions_without_separators, num_col_classifier, tables, multiplier=3.8): +def find_num_col( + regions_without_separators, + num_col_classifier, + tables, + multiplier=3.8, +): if not regions_without_separators.any(): return 0, [] #plt.imshow(regions_without_separators) diff --git a/src/eynollah/writer.py b/src/eynollah/writer.py index 9c3456a..a0ec077 100644 --- a/src/eynollah/writer.py +++ b/src/eynollah/writer.py @@ -2,7 +2,7 @@ # pylint: disable=import-error from pathlib import Path import os.path -import xml.etree.ElementTree as ET +from typing import Optional from .utils.xml import create_page_xml, xml_reading_order from .utils.counter import EynollahIdCounter @@ -10,7 +10,6 @@ from ocrd_utils import getLogger from ocrd_models.ocrd_page import ( BorderType, CoordsType, - PcGtsType, TextLineType, TextEquivType, TextRegionType, @@ -32,10 +31,10 @@ class EynollahXmlWriter: self.curved_line = curved_line self.textline_light = textline_light self.pcgts = pcgts - self.scale_x = None # XXX set outside __init__ - self.scale_y = None # XXX set outside __init__ - self.height_org = None # XXX set outside __init__ - self.width_org = None # XXX set outside __init__ + self.scale_x: Optional[float] = None # XXX set outside __init__ + self.scale_y: Optional[float] = None # XXX set outside __init__ + self.height_org: Optional[int] = None # XXX set outside __init__ + self.width_org: Optional[int] = None # XXX set outside __init__ @property def 
image_filename_stem(self): @@ -135,6 +134,7 @@ class EynollahXmlWriter: # create the file structure pcgts = self.pcgts if self.pcgts else create_page_xml(self.image_filename, self.height_org, self.width_org) page = pcgts.get_Page() + assert page page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_page_coords(cont_page)))) counter = EynollahIdCounter() @@ -152,6 +152,7 @@ class EynollahXmlWriter: Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord, skip_layout_reading_order)) ) + assert textregion.Coords if conf_contours_textregions: textregion.Coords.set_conf(conf_contours_textregions[mm]) page.add_TextRegion(textregion) @@ -168,6 +169,7 @@ class EynollahXmlWriter: id=counter.next_region_id, type_='heading', Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord)) ) + assert textregion.Coords if conf_contours_textregions_h: textregion.Coords.set_conf(conf_contours_textregions_h[mm]) page.add_TextRegion(textregion) diff --git a/tests/test_run.py b/tests/test_run.py index 79c64c2..a410d34 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -16,10 +16,13 @@ from ocrd_models.constants import NAMESPACES as NS testdir = Path(__file__).parent.resolve() -MODELS_LAYOUT = environ.get('MODELS_LAYOUT', str(testdir.joinpath('..', 'models_layout_v0_5_0').resolve())) -MODELS_OCR = environ.get('MODELS_OCR', str(testdir.joinpath('..', 'models_ocr_v0_5_1').resolve())) +MODELS_LAYOUT = environ.get('MODELS_LAYOUT', str(testdir.joinpath('..', 'models_layout_v0_6_0').resolve())) +MODELS_OCR = environ.get('MODELS_OCR', str(testdir.joinpath('..', 'models_ocr_v0_6_0').resolve())) MODELS_BIN = environ.get('MODELS_BIN', str(testdir.joinpath('..', 'default-2021-03-09').resolve())) +def only_eynollah(logrec): + return logrec.name.startswith('eynollah') + @pytest.mark.parametrize( "options", [ @@ -50,8 +53,6 @@ def test_run_eynollah_layout_filename(tmp_path, pytestconfig, caplog, options): if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(layout_cli, args + options, catch_exceptions=False) @@ -85,8 +86,6 @@ def test_run_eynollah_layout_filename2(tmp_path, pytestconfig, caplog, options): if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(layout_cli, args + options, catch_exceptions=False) @@ -116,8 +115,6 @@ def test_run_eynollah_layout_directory(tmp_path, pytestconfig, caplog): if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(layout_cli, args, catch_exceptions=False) @@ -144,8 +141,6 @@ def test_run_eynollah_binarization_filename(tmp_path, pytestconfig, caplog, opti if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'SbbBinarizer' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(binarization_cli, args + options, catch_exceptions=False) @@ -170,8 +165,6 @@ def test_run_eynollah_binarization_directory(tmp_path, pytestconfig, 
caplog): if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'SbbBinarizer' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(binarization_cli, args, catch_exceptions=False) @@ -197,8 +190,6 @@ def test_run_eynollah_enhancement_filename(tmp_path, pytestconfig, caplog, optio if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'enhancement' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(enhancement_cli, args + options, catch_exceptions=False) @@ -223,8 +214,6 @@ def test_run_eynollah_enhancement_directory(tmp_path, pytestconfig, caplog): if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'enhancement' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(enhancement_cli, args, catch_exceptions=False) @@ -244,8 +233,6 @@ def test_run_eynollah_mbreorder_filename(tmp_path, pytestconfig, caplog): if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'mbreorder' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(mbreorder_cli, args, catch_exceptions=False) @@ -273,8 +260,6 @@ def test_run_eynollah_mbreorder_directory(tmp_path, pytestconfig, caplog): if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'mbreorder' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(mbreorder_cli, args, catch_exceptions=False) @@ -306,8 +291,6 @@ def test_run_eynollah_ocr_filename(tmp_path, pytestconfig, caplog, options): if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.DEBUG) - def only_eynollah(logrec): - return logrec.name == 'eynollah' runner = CliRunner() if "-doit" in options: options.insert(options.index("-doit") + 1, str(outrenderfile.parent)) @@ -339,8 +322,6 @@ def test_run_eynollah_ocr_directory(tmp_path, pytestconfig, caplog): if pytestconfig.getoption('verbose') > 0: args.extend(['-l', 'DEBUG']) caplog.set_level(logging.INFO) - def only_eynollah(logrec): - return logrec.name == 'eynollah' runner = CliRunner() with caplog.filtering(only_eynollah): result = runner.invoke(ocr_cli, args, catch_exceptions=False) diff --git a/train/README.md b/train/README.md index 5f6d326..6aeea5d 100644 --- a/train/README.md +++ b/train/README.md @@ -22,14 +22,14 @@ Download our pretrained weights and add them to a `train/pretrained_model` folde ```sh cd train -wget -O pretrained_model.tar.gz https://zenodo.org/records/17243320/files/pretrained_model_v0_5_1.tar.gz?download=1 +wget -O pretrained_model.tar.gz "https://zenodo.org/records/17295988/files/pretrained_model_v0_6_0.tar.gz?download=1" tar xf pretrained_model.tar.gz ``` ### Binarization training data A small sample of training data for binarization experiment can be found [on -zenodo](https://zenodo.org/records/17243320/files/training_data_sample_binarization_v0_5_1.tar.gz?download=1), +zenodo](https://zenodo.org/records/17295988/files/training_data_sample_binarization_v0_6_0.tar.gz?download=1), which contains `images` and `labels` folders. ### Helpful tools