Mirror of https://github.com/qurator-spk/eynollah.git (synced 2025-10-26 23:34:13 +01:00)

Commit 7cf6ae1d7a: Merge 883546a6b8 into 38c028c6b5

22 changed files with 1962 additions and 1341 deletions
Makefile (19 changed lines)

@@ -6,21 +6,23 @@ EXTRAS ?=
DOCKER_BASE_IMAGE ?= docker.io/ocrd/core-cuda-tf2:latest
DOCKER_TAG ?= ocrd/eynollah
DOCKER ?= docker
WGET = wget -O

#SEG_MODEL := https://qurator-data.de/eynollah/2021-04-25/models_eynollah.tar.gz
#SEG_MODEL := https://qurator-data.de/eynollah/2022-04-05/models_eynollah_renamed.tar.gz
# SEG_MODEL := https://qurator-data.de/eynollah/2022-04-05/models_eynollah.tar.gz
#SEG_MODEL := https://github.com/qurator-spk/eynollah/releases/download/v0.3.0/models_eynollah.tar.gz
#SEG_MODEL := https://github.com/qurator-spk/eynollah/releases/download/v0.3.1/models_eynollah.tar.gz
SEG_MODEL := https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1
#SEG_MODEL := https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1
SEG_MODEL := https://zenodo.org/records/17295988/files/models_layout_v0_6_0.tar.gz?download=1
SEG_MODELFILE = $(notdir $(patsubst %?download=1,%,$(SEG_MODEL)))
SEG_MODELNAME = $(SEG_MODELFILE:%.tar.gz=%)

BIN_MODEL := https://github.com/qurator-spk/sbb_binarization/releases/download/v0.0.11/saved_model_2021_03_09.zip
BIN_MODEL := https://zenodo.org/records/17295988/files/models_binarization_v0_6_0.tar.gz?download=1
BIN_MODELFILE = $(notdir $(BIN_MODEL))
BIN_MODELNAME := default-2021-03-09

OCR_MODEL := https://zenodo.org/records/17236998/files/models_ocr_v0_5_1.tar.gz?download=1
OCR_MODEL := https://zenodo.org/records/17295988/files/models_ocr_v0_6_0.tar.gz?download=1
OCR_MODELFILE = $(notdir $(patsubst %?download=1,%,$(OCR_MODEL)))
OCR_MODELNAME = $(OCR_MODELFILE:%.tar.gz=%)

@@ -55,22 +57,21 @@ help:
# END-EVAL


# Download and extract models to $(PWD)/models_layout_v0_5_0
# Download and extract models to $(PWD)/models_layout_v0_6_0
models: $(BIN_MODELNAME) $(SEG_MODELNAME) $(OCR_MODELNAME)

# do not download these files if we already have the directories
.INTERMEDIATE: $(BIN_MODELFILE) $(SEG_MODELFILE) $(OCR_MODELFILE)

$(BIN_MODELFILE):
	wget -O $@ $(BIN_MODEL)
	$(WGET) $@ $(BIN_MODEL)
$(SEG_MODELFILE):
	wget -O $@ $(SEG_MODEL)
	$(WGET) $@ $(SEG_MODEL)
$(OCR_MODELFILE):
	wget -O $@ $(OCR_MODEL)
	$(WGET) $@ $(OCR_MODEL)

$(BIN_MODELNAME): $(BIN_MODELFILE)
	mkdir $@
	unzip -d $@ $<
	tar zxf $<
$(SEG_MODELNAME): $(SEG_MODELFILE)
	tar zxf $<
$(OCR_MODELNAME): $(OCR_MODELFILE)
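The model pinning above also determines local file and directory names: `$(notdir $(patsubst %?download=1,%,$(SEG_MODEL)))` strips the Zenodo `?download=1` query suffix and the URL path, and the `%.tar.gz=%` substitution drops the archive extension. A minimal Python sketch of the same derivation (illustrative only, not part of the repository):

    from posixpath import basename

    def model_names(url):
        # mirror of $(notdir $(patsubst %?download=1,%,$(SEG_MODEL)))
        modelfile = basename(url.removesuffix("?download=1"))
        # mirror of $(SEG_MODELFILE:%.tar.gz=%)
        modelname = modelfile.removesuffix(".tar.gz")
        return modelfile, modelname

    print(model_names("https://zenodo.org/records/17295988/files/models_layout_v0_6_0.tar.gz?download=1"))
    # ('models_layout_v0_6_0.tar.gz', 'models_layout_v0_6_0')
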
@@ -55,7 +55,7 @@ make install EXTRAS=OCR

## Models

Pretrained models can be downloaded from [zenodo](https://zenodo.org/records/17194824) or [huggingface](https://huggingface.co/SBB?search_models=eynollah).
Pretrained models can be downloaded from [zenodo](https://doi.org/10.5281/zenodo.17194823) or [huggingface](https://huggingface.co/SBB?search_models=eynollah).

For documentation on models, have a look at [`models.md`](https://github.com/qurator-spk/eynollah/tree/main/docs/models.md).
Model cards are also provided for our trained models.

@@ -162,7 +162,7 @@ formally described in [`ocrd-tool.json`](https://github.com/qurator-spk/eynollah

In this case, the source image file group with (preferably) RGB images should be used as input like this:

    ocrd-eynollah-segment -I OCR-D-IMG -O OCR-D-SEG -P models eynollah_layout_v0_5_0
    ocrd-eynollah-segment -I OCR-D-IMG -O OCR-D-SEG -P models eynollah_layout_v0_6_0

If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynollah behaves as follows:
- existing regions are kept and ignored (i.e. in effect they might overlap segments from Eynollah results)

@@ -174,7 +174,7 @@ If the input file group is PAGE-XML (from a previous OCR-D workflow step), Eynol
(because some other preprocessing step was in effect like `denoised`), then
the output PAGE-XML will be based on that as new top-level (`@imageFilename`)

    ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_5_0
    ocrd-eynollah-segment -I OCR-D-XYZ -O OCR-D-SEG -P models eynollah_layout_v0_6_0

In general, it makes more sense to add other workflow steps **after** Eynollah.
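For a manual download outside of `make models`, the following sketch fetches and unpacks one of the pinned archives; it assumes the Zenodo URL from the Makefile above and enough disk space for the extracted models:

    import tarfile
    import urllib.request

    URL = "https://zenodo.org/records/17295988/files/models_layout_v0_6_0.tar.gz?download=1"

    # Download the archive, then extract it into the current directory,
    # which is what the corresponding `make models` target does.
    filename, _ = urllib.request.urlretrieve(URL, "models_layout_v0_6_0.tar.gz")
    with tarfile.open(filename) as tar:
        tar.extractall(".")
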
@@ -6,3 +6,4 @@ tensorflow < 2.13
numba <= 0.58.1
scikit-image
biopython
tabulate
@@ -1,16 +1,24 @@
from dataclasses import dataclass
import sys
import click
import logging
from typing import Tuple, List
from ocrd_utils import initLogging, getLevelName, getLogger
from eynollah.eynollah import Eynollah, Eynollah_ocr
from eynollah.eynollah import Eynollah
from eynollah.eynollah_ocr import Eynollah_ocr
from eynollah.sbb_binarize import SbbBinarizer
from eynollah.image_enhancer import Enhancer
from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout
from eynollah.model_zoo import EynollahModelZoo

from .cli_models import models_cli

@click.group()
def main():
    pass

main.add_command(models_cli, 'models')

@main.command()
@click.option(
    "--input",
@@ -79,18 +87,38 @@ def machine_based_reading_order(input, dir_in, out, model, log_level):
    type=click.Path(file_okay=True, dir_okay=True),
    required=True,
)
@click.option(
    '-M',
    '--mode',
    type=click.Choice(['single', 'multi']),
    default='single',
    help="Whether to use the (faster) single-model binarization or the (slightly better) multi-model binarization"
)
@click.option(
    "--log_level",
    "-l",
    type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
    help="Override log level globally to this",
)
def binarization(patches, model_dir, input_image, dir_in, output, log_level):
def binarization(
    patches,
    model_dir,
    input_image,
    mode,
    dir_in,
    output,
    log_level,
):
    assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
    binarizer = SbbBinarizer(model_dir)
    binarizer = SbbBinarizer(model_dir, mode=mode)
    if log_level:
        binarizer.log.setLevel(getLevelName(log_level))
    binarizer.run(image_path=input_image, use_patches=patches, output=output, dir_in=dir_in)
    binarizer.run(
        image_path=input_image,
        use_patches=patches,
        output=output,
        dir_in=dir_in
    )


@main.command()
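With the new option, binarization can be switched between the two variants on the command line. An illustrative invocation (the `-M` flag is taken from the option above; the remaining flags are assumed from the surrounding CLI definitions, so check `eynollah binarization --help` for the authoritative set):

    eynollah binarization -m /path/to/models -M multi -i page.png -o page.bin.png
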
@@ -198,15 +226,17 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
@click.option(
    "--model",
    "-m",
    'model_basedir',
    help="directory of models",
    type=click.Path(exists=True, file_okay=False),
    # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
    required=True,
)
@click.option(
    "--model_version",
    "-mv",
    help="override default versions of model categories",
    type=(str, str),
    help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g. 'region light /path/to/model'. See eynollah list-models for the full list",
    type=(str, str, str),
    multiple=True,
)
@click.option(
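The override option now takes a triple instead of a pair. A hypothetical invocation (paths and the remaining flags invented for illustration):

    eynollah layout -m /path/to/models -mv region light /path/to/custom_region_model -i page.png -o out/
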
@@ -380,7 +410,43 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
    help="Setup a basic console logger",
)

def layout(image, out, overwrite, dir_in, model, model_version, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
def layout(
    image,
    out,
    overwrite,
    dir_in,
    model_basedir,
    model_version,
    save_images,
    save_layout,
    save_deskewed,
    save_all,
    extract_only_images,
    save_page,
    enable_plotting,
    allow_enhancement,
    curved_line,
    textline_light,
    full_layout,
    tables,
    right2left,
    input_binary,
    allow_scaling,
    headers_off,
    light_version,
    reading_order_machine_based,
    do_ocr,
    transformer_ocr,
    batch_size_ocr,
    num_col_upper,
    num_col_lower,
    threshold_art_class_textline,
    threshold_art_class_layout,
    skip_layout_and_reading_order,
    ignore_page_extraction,
    log_level,
    setup_logging,
):
    if setup_logging:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.INFO)
@@ -410,8 +476,8 @@ def layout(image, out, overwrite, dir_in, model, model_version, save_images, sav
    assert not extract_only_images or not headers_off, "Image extraction -eoi can not be set alongside headers_off -ho"
    assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
    eynollah = Eynollah(
        model,
        model_versions=model_version,
        model_basedir,
        model_overrides=model_version,
        extract_only_images=extract_only_images,
        enable_plotting=enable_plotting,
        allow_enhancement=allow_enhancement,
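Mirroring the constructor call above, programmatic use would look roughly like this; only the arguments visible in this diff are shown, and all paths are hypothetical:

    from eynollah.eynollah import Eynollah

    eynollah = Eynollah(
        "/path/to/models",  # model_basedir, first positional argument
        # one (CATEGORY, VARIANT, PATH) triple per -mv option
        model_overrides=[("region", "light", "/path/to/custom_region_model")],
        extract_only_images=False,
        enable_plotting=False,
        allow_enhancement=False,
    )
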
src/eynollah/cli_models.py (new file, 93 lines)

@@ -0,0 +1,93 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List, Set, Tuple
import click

from eynollah.model_zoo.default_specs import MODELS_VERSION
from .model_zoo import EynollahModelZoo


@dataclass()
class EynollahCliCtx:
    model_zoo: EynollahModelZoo


@click.group()
@click.pass_context
@click.option(
    "--model",
    "-m",
    'model_basedir',
    help="directory of models",
    type=click.Path(exists=True, file_okay=False),
    # default=f"{os.environ['HOME']}/.local/share/ocrd-resources/ocrd-eynollah-segment",
    required=True,
)
@click.option(
    "--model-overrides",
    "-mv",
    help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g. 'region light /path/to/model'. See eynollah list-models for the full list",
    type=(str, str, str),
    multiple=True,
)
def models_cli(
    ctx,
    model_basedir: str,
    model_overrides: List[Tuple[str, str, str]],
):
    """
    Organize models for the various runners in eynollah.
    """
    ctx.obj = EynollahCliCtx(model_zoo=EynollahModelZoo(basedir=model_basedir, model_overrides=model_overrides))


@models_cli.command('list')
@click.pass_context
def list_models(
    ctx,
):
    """
    List all the models in the zoo
    """
    print(ctx.obj.model_zoo)


@models_cli.command('package')
@click.option(
    '--set-version', '-V', 'version', help="Version to use for packaging", default=MODELS_VERSION, show_default=True
)
@click.argument('output_dir')
@click.pass_context
def package(
    ctx,
    version,
    output_dir,
):
    """
    Generate shell code to copy all the models in the zoo into properly named folders in OUTPUT_DIR for distribution.

        eynollah models -m SRC package OUTPUT_DIR

    SRC should contain a directory "models_eynollah" containing all the models.
    """
    mkdirs: Set[Path] = set([])
    copies: Set[Tuple[Path, Path]] = set([])
    for spec in ctx.obj.model_zoo.specs.specs:
        # skip these as they are dependent on the ocr model
        if spec.category in ('num_to_char', 'characters'):
            continue
        src: Path = ctx.obj.model_zoo.model_path(spec.category, spec.variant)
        # Only copy the top-most directory relative to models_eynollah
        while src.parent.name != 'models_eynollah':
            src = src.parent
        for dist in spec.dists:
            dist_dir = Path(f"{output_dir}/models_{dist}_{version}/models_eynollah")
            copies.add((src, dist_dir))
            mkdirs.add(dist_dir)
    for dir in mkdirs:
        print(f"mkdir -p {dir}")
    for (src, dst) in copies:
        print(f"cp -r {src} {dst}")
    for dir in mkdirs:
        zip_path = Path(f'../{dir.parent.name}.zip')
        print(f"(cd {dir}/..; zip -r {zip_path} models_eynollah)")
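Since `package` only prints shell code, its output can be reviewed before being piped to a shell. A hypothetical session (source directory, model directory and dist name invented for illustration, assuming the default version is v0_6_0):

    eynollah models -m /srv/models package /tmp/dist

which would print shell commands along the lines of:

    mkdir -p /tmp/dist/models_layout_v0_6_0/models_eynollah
    cp -r /srv/models/models_eynollah/model_region_light /tmp/dist/models_layout_v0_6_0/models_eynollah
    (cd /tmp/dist/models_layout_v0_6_0/models_eynollah/..; zip -r ../models_layout_v0_6_0.zip models_eynollah)
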
File diff suppressed because it is too large
src/eynollah/eynollah_ocr.py (new file, 998 lines)

@@ -0,0 +1,998 @@
# pyright: reportPossiblyUnboundVariable=false

from logging import Logger, getLogger
from typing import Optional
from pathlib import Path
import os
import json
import gc
import sys
import math
import time

from keras.layers import StringLookup
import cv2
import xml.etree.ElementTree as ET
import tensorflow as tf
from keras.models import load_model
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from eynollah.model_zoo import EynollahModelZoo
try:
    import torch
except ImportError:
    torch = None


from .utils import is_image_filename
from .utils.resize import resize_image
from .utils.utils_ocr import (
    break_curved_line_into_small_pieces_and_then_merge,
    decode_batch_predictions,
    fit_text_single_line,
    get_contours_and_bounding_boxes,
    get_orientation_moments,
    preprocess_and_resize_image_for_ocrcnn_model,
    return_textlines_split_if_needed,
    rotate_image_with_padding,
)

# cannot use importlib.resources until we move to 3.9+ for importlib.resources.files
if sys.version_info < (3, 10):
    import importlib_resources
else:
    import importlib.resources as importlib_resources

try:
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel
except ImportError:
    TrOCRProcessor = VisionEncoderDecoderModel = None

class Eynollah_ocr:
    def __init__(
        self,
        dir_models,
        model_name=None,
        dir_xmls=None,
        tr_ocr=False,
        batch_size: Optional[int]=None,
        export_textline_images_and_text: bool=False,
        do_not_mask_with_textline_contour: bool=False,
        pref_of_dataset=None,
        min_conf_value_of_textline_text : float=0.3,
        logger: Optional[Logger]=None,
    ):
        self.tr_ocr = tr_ocr
        self.export_textline_images_and_text = export_textline_images_and_text
        self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
        self.pref_of_dataset = pref_of_dataset
        self.logger = logger if logger else getLogger('eynollah')
        self.model_zoo = EynollahModelZoo(basedir=dir_models)

        # TODO: Properly document what 'export_textline_images_and_text' is about
        if export_textline_images_and_text:
            self.logger.info("export_textline_images_and_text was set, so no actual models are loaded")
            return

        self.min_conf_value_of_textline_text = min_conf_value_of_textline_text
        self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
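        # Batch-size default: 2 lines per batch for the (heavier) TrOCR
        # model, 8 for the lightweight CNN/RNN OCR model; an explicit
        # batch_size argument overrides both.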

        if tr_ocr:
            self.model_zoo.load_model('trocr_processor', '')
            if model_name:
                self.model_zoo.load_model('ocr', 'tr', model_name)
            else:
                self.model_zoo.load_model('ocr', 'tr')
            self.model_zoo.get('ocr').to(self.device)
        else:
            if model_name:
                self.model_zoo.load_model('ocr', '', model_name)
            else:
                self.model_zoo.load_model('ocr', '')
            self.model_zoo.load_model('num_to_char')
            self.end_character = len(self.model_zoo.load_model('characters')) + 2

    @property
    def device(self):
        if torch.cuda.is_available():
            self.logger.info("Using GPU acceleration")
            return torch.device("cuda:0")
        else:
            self.logger.info("Using CPU processing")
            return torch.device("cpu")
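
    # Note that torch is an optional dependency (imported above with an
    # ImportError fallback to None); this device property is only used
    # on the TrOCR code path, which requires torch to be installed.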
    def run(self, overwrite: bool = False,
            dir_in: Optional[str] = None,
            dir_in_bin: Optional[str] = None,
            image_filename: Optional[str] = None,
            dir_xmls: Optional[str] = None,
            dir_out_image_text: Optional[str] = None,
            dir_out: Optional[str] = None,
    ):
        if dir_in:
            ls_imgs = [os.path.join(dir_in, image_filename)
                       for image_filename in filter(is_image_filename,
                                                    os.listdir(dir_in))]
        else:
            assert image_filename
            ls_imgs = [image_filename]

        if self.tr_ocr:
            tr_ocr_input_height_and_width = 384
            for dir_img in ls_imgs:
                file_name = Path(dir_img).stem
                assert dir_xmls # FIXME: check the logic
                dir_xml = os.path.join(dir_xmls, file_name+'.xml')
                assert dir_out # FIXME: check the logic
                out_file_ocr = os.path.join(dir_out, file_name+'.xml')

                if os.path.exists(out_file_ocr):
                    if overwrite:
                        self.logger.warning("will overwrite existing output file '%s'", out_file_ocr)
                    else:
                        self.logger.warning("will skip input for existing output file '%s'", out_file_ocr)
                        continue

                img = cv2.imread(dir_img)

                if dir_out_image_text:
                    out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png')
                    image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
                    draw = ImageDraw.Draw(image_text)
                    total_bb_coordinates = []

                ##file_name = Path(dir_xmls).stem
                tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8"))
                root1=tree1.getroot()
                alltags=[elem.tag for elem in root1.iter()]
                link=alltags[0].split('}')[0]+'}'

                name_space = alltags[0].split('}')[0]
                name_space = name_space.split('{')[1]

                region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')])

                cropped_lines = []
                cropped_lines_region_indexer = []
                cropped_lines_meging_indexing = []

                extracted_texts = []

                indexer_text_region = 0
                indexer_b_s = 0

                for nn in root1.iter(region_tags):
                    for child_textregion in nn:
                        if child_textregion.tag.endswith("TextLine"):

                            for child_textlines in child_textregion:
                                if child_textlines.tag.endswith("Coords"):
                                    cropped_lines_region_indexer.append(indexer_text_region)
                                    p_h=child_textlines.attrib['points'].split(' ')
                                    textline_coords = np.array( [ [int(x.split(',')[0]),
                                                                   int(x.split(',')[1]) ]
                                                                  for x in p_h] )
                                    x,y,w,h = cv2.boundingRect(textline_coords)

                                    if dir_out_image_text:
                                        total_bb_coordinates.append([x,y,w,h])

                                    h2w_ratio = h/float(w)

                                    img_poly_on_img = np.copy(img)
                                    mask_poly = np.zeros(img.shape)
                                    mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))

                                    mask_poly = mask_poly[y:y+h, x:x+w, :]
                                    img_crop = img_poly_on_img[y:y+h, x:x+w, :]
                                    img_crop[mask_poly==0] = 255

                                    self.logger.debug("processing %d lines for '%s'",
                                                      len(cropped_lines), nn.attrib['id'])
                                    if h2w_ratio > 0.1:
                                        cropped_lines.append(resize_image(img_crop,
                                                                          tr_ocr_input_height_and_width,
                                                                          tr_ocr_input_height_and_width) )
                                        cropped_lines_meging_indexing.append(0)
                                        indexer_b_s+=1
                                        if indexer_b_s==self.b_s:
                                            imgs = cropped_lines[:]
                                            cropped_lines = []
                                            indexer_b_s = 0

                                            pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
                                            generated_ids_merged = self.model_zoo.get('ocr').generate(
                                                pixel_values_merged.to(self.device))
                                            generated_text_merged = self.model_zoo.get('processor').batch_decode(
                                                generated_ids_merged, skip_special_tokens=True)

                                            extracted_texts = extracted_texts + generated_text_merged

                                    else:
                                        splited_images, _ = return_textlines_split_if_needed(img_crop, None)
                                        #print(splited_images)
                                        if splited_images:
                                            cropped_lines.append(resize_image(splited_images[0],
                                                                              tr_ocr_input_height_and_width,
                                                                              tr_ocr_input_height_and_width))
                                            cropped_lines_meging_indexing.append(1)
                                            indexer_b_s+=1

                                            if indexer_b_s==self.b_s:
                                                imgs = cropped_lines[:]
                                                cropped_lines = []
                                                indexer_b_s = 0

                                                pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
                                                generated_ids_merged = self.model_zoo.get('ocr').generate(
                                                    pixel_values_merged.to(self.device))
                                                generated_text_merged = self.model_zoo.get('processor').batch_decode(
                                                    generated_ids_merged, skip_special_tokens=True)

                                                extracted_texts = extracted_texts + generated_text_merged

                                            cropped_lines.append(resize_image(splited_images[1],
                                                                              tr_ocr_input_height_and_width,
                                                                              tr_ocr_input_height_and_width))
                                            cropped_lines_meging_indexing.append(-1)
                                            indexer_b_s+=1

                                            if indexer_b_s==self.b_s:
                                                imgs = cropped_lines[:]
                                                cropped_lines = []
                                                indexer_b_s = 0

                                                pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
                                                generated_ids_merged = self.model_zoo.get('ocr').generate(
                                                    pixel_values_merged.to(self.device))
                                                generated_text_merged = self.model_zoo.get('processor').batch_decode(
                                                    generated_ids_merged, skip_special_tokens=True)

                                                extracted_texts = extracted_texts + generated_text_merged

                                        else:
                                            cropped_lines.append(img_crop)
                                            cropped_lines_meging_indexing.append(0)
                                            indexer_b_s+=1

                                            if indexer_b_s==self.b_s:
                                                imgs = cropped_lines[:]
                                                cropped_lines = []
                                                indexer_b_s = 0

                                                pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
                                                generated_ids_merged = self.model_zoo.get('ocr').generate(
                                                    pixel_values_merged.to(self.device))
                                                generated_text_merged = self.model_zoo.get('processor').batch_decode(
                                                    generated_ids_merged, skip_special_tokens=True)

                                                extracted_texts = extracted_texts + generated_text_merged

                    indexer_text_region = indexer_text_region +1

                if indexer_b_s!=0:
                    imgs = cropped_lines[:]
                    cropped_lines = []
                    indexer_b_s = 0

                    pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
                    generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device))
                    generated_text_merged = self.model_zoo.get('processor').batch_decode(generated_ids_merged, skip_special_tokens=True)

                    extracted_texts = extracted_texts + generated_text_merged

                ####extracted_texts = []
                ####n_iterations = math.ceil(len(cropped_lines) / self.b_s)

                ####for i in range(n_iterations):
                ####if i==(n_iterations-1):
                ####n_start = i*self.b_s
                ####imgs = cropped_lines[n_start:]
                ####else:
                ####n_start = i*self.b_s
                ####n_end = (i+1)*self.b_s
                ####imgs = cropped_lines[n_start:n_end]
                ####pixel_values_merged = self.model_zoo.get('processor')(imgs, return_tensors="pt").pixel_values
                ####generated_ids_merged = self.model_ocr.generate(
                ####    pixel_values_merged.to(self.device))
                ####generated_text_merged = self.model_zoo.get('processor').batch_decode(
                ####    generated_ids_merged, skip_special_tokens=True)

                ####extracted_texts = extracted_texts + generated_text_merged

                del cropped_lines
                gc.collect()

                extracted_texts_merged = [extracted_texts[ind]
                                          if cropped_lines_meging_indexing[ind]==0
                                          else extracted_texts[ind]+" "+extracted_texts[ind+1]
                                          if cropped_lines_meging_indexing[ind]==1
                                          else None
                                          for ind in range(len(cropped_lines_meging_indexing))]

                extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
                #print(extracted_texts_merged, len(extracted_texts_merged))

                unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)

                if dir_out_image_text:

                    #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
                    font = importlib_resources.files(__package__) / "Charis-Regular.ttf"
                    with importlib_resources.as_file(font) as font:
                        font = ImageFont.truetype(font=font, size=40)

                    for indexer_text, bb_ind in enumerate(total_bb_coordinates):

                        x_bb = bb_ind[0]
                        y_bb = bb_ind[1]
                        w_bb = bb_ind[2]
                        h_bb = bb_ind[3]

                        font = fit_text_single_line(draw, extracted_texts_merged[indexer_text],
                                                    font.path, w_bb, int(h_bb*0.4) )

                        ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2)

                        text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]

                        text_x = x_bb + (w_bb - text_width) // 2  # Center horizontally
                        text_y = y_bb + (h_bb - text_height) // 2  # Center vertically

                        # Draw the text
                        draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font)
                    image_text.save(out_image_with_text)

                #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer')
                #######text_by_textregion = []
                #######for ind in unique_cropped_lines_region_indexer:
                #######ind = np.array(cropped_lines_region_indexer)==ind
                #######extracted_texts_merged_un = np.array(extracted_texts_merged)[ind]
                #######text_by_textregion.append(" ".join(extracted_texts_merged_un))

                text_by_textregion = []
                for ind in unique_cropped_lines_region_indexer:
                    ind = np.array(cropped_lines_region_indexer) == ind
                    extracted_texts_merged_un = np.array(extracted_texts_merged)[ind]
                    if len(extracted_texts_merged_un)>1:
                        text_by_textregion_ind = ""
                        next_glue = ""
                        for indt in range(len(extracted_texts_merged_un)):
                            if (extracted_texts_merged_un[indt].endswith('⸗') or
                                extracted_texts_merged_un[indt].endswith('-') or
                                extracted_texts_merged_un[indt].endswith('¬')):
                                text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1]
                                next_glue = ""
                            else:
                                text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt]
                                next_glue = " "
                        text_by_textregion.append(text_by_textregion_ind)
                    else:
                        text_by_textregion.append(" ".join(extracted_texts_merged_un))
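
                # Dehyphenation while concatenating a region's lines: a line
                # ending in '⸗', '-' or '¬' is treated as a hyphenated break,
                # so the trailing mark is dropped and the next line is joined
                # without a space; otherwise lines are joined with a single
                # space.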

                indexer = 0
                indexer_textregion = 0
                for nn in root1.iter(region_tags):
                    #id_textregion = nn.attrib['id']
                    #id_textregions.append(id_textregion)
                    #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion])

                    is_textregion_text = False
                    for childtest in nn:
                        if childtest.tag.endswith("TextEquiv"):
                            is_textregion_text = True

                    if not is_textregion_text:
                        text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
                        unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')

                    has_textline = False
                    for child_textregion in nn:
                        if child_textregion.tag.endswith("TextLine"):

                            is_textline_text = False
                            for childtest2 in child_textregion:
                                if childtest2.tag.endswith("TextEquiv"):
                                    is_textline_text = True

                            if not is_textline_text:
                                text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
                                ##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
                                unicode_textline = ET.SubElement(text_subelement, 'Unicode')
                                unicode_textline.text = extracted_texts_merged[indexer]
                            else:
                                for childtest3 in child_textregion:
                                    if childtest3.tag.endswith("TextEquiv"):
                                        for child_uc in childtest3:
                                            if child_uc.tag.endswith("Unicode"):
                                                ##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
                                                child_uc.text = extracted_texts_merged[indexer]

                            indexer = indexer + 1
                            has_textline = True
                    if has_textline:
                        if is_textregion_text:
                            for child4 in nn:
                                if child4.tag.endswith("TextEquiv"):
                                    for childtr_uc in child4:
                                        if childtr_uc.tag.endswith("Unicode"):
                                            childtr_uc.text = text_by_textregion[indexer_textregion]
                        else:
                            unicode_textregion.text = text_by_textregion[indexer_textregion]
                    indexer_textregion = indexer_textregion + 1

                ###sample_order = [(id_to_order[tid], text)
                ### for tid, text in zip(id_textregions, textregions_by_existing_ids)
                ### if tid in id_to_order]

                ##ordered_texts_sample = [text for _, text in sorted(sample_order)]
                ##tot_page_text = ' '.join(ordered_texts_sample)

                ##for page_element in root1.iter(link+'Page'):
                ##text_page = ET.SubElement(page_element, 'TextEquiv')
                ##unicode_textpage = ET.SubElement(text_page, 'Unicode')
                ##unicode_textpage.text = tot_page_text

                ET.register_namespace("",name_space)
                tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None)
        else:
            ###max_len = 280#512#280#512
            ###padding_token = 1500#299#1500#299
            image_width = 512#max_len * 4
            image_height = 32

            img_size=(image_width, image_height)

            for dir_img in ls_imgs:
                file_name = Path(dir_img).stem
                dir_xml = os.path.join(dir_xmls, file_name+'.xml')
                out_file_ocr = os.path.join(dir_out, file_name+'.xml')

                if os.path.exists(out_file_ocr):
                    if overwrite:
                        self.logger.warning("will overwrite existing output file '%s'", out_file_ocr)
                    else:
                        self.logger.warning("will skip input for existing output file '%s'", out_file_ocr)
                        continue

                img = cv2.imread(dir_img)
                if dir_in_bin is not None:
                    cropped_lines_bin = []
                    dir_img_bin = os.path.join(dir_in_bin, file_name+'.png')
                    img_bin = cv2.imread(dir_img_bin)

                if dir_out_image_text:
                    out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png')
                    image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
                    draw = ImageDraw.Draw(image_text)
                    total_bb_coordinates = []

                tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8"))
                root1=tree1.getroot()
                alltags=[elem.tag for elem in root1.iter()]
                link=alltags[0].split('}')[0]+'}'

                name_space = alltags[0].split('}')[0]
                name_space = name_space.split('{')[1]

                region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')])

                cropped_lines = []
                cropped_lines_ver_index = []
                cropped_lines_region_indexer = []
                cropped_lines_meging_indexing = []

                tinl = time.time()
                indexer_text_region = 0
                indexer_textlines = 0
                for nn in root1.iter(region_tags):
                    try:
                        type_textregion = nn.attrib['type']
                    except:
                        type_textregion = 'paragraph'
                    for child_textregion in nn:
                        if child_textregion.tag.endswith("TextLine"):
                            for child_textlines in child_textregion:
                                if child_textlines.tag.endswith("Coords"):
                                    cropped_lines_region_indexer.append(indexer_text_region)
                                    p_h=child_textlines.attrib['points'].split(' ')
                                    textline_coords = np.array( [ [int(x.split(',')[0]),
                                                                   int(x.split(',')[1]) ]
                                                                  for x in p_h] )

                                    x,y,w,h = cv2.boundingRect(textline_coords)

                                    angle_radians = math.atan2(h, w)
                                    # Convert to degrees
                                    angle_degrees = math.degrees(angle_radians)
                                    if type_textregion=='drop-capital':
                                        angle_degrees = 0

                                    if dir_out_image_text:
                                        total_bb_coordinates.append([x,y,w,h])

                                    w_scaled = w * image_height/float(h)

                                    img_poly_on_img = np.copy(img)
                                    if dir_in_bin is not None:
                                        img_poly_on_img_bin = np.copy(img_bin)
                                        img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :]

                                    mask_poly = np.zeros(img.shape)
                                    mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))

                                    mask_poly = mask_poly[y:y+h, x:x+w, :]
                                    img_crop = img_poly_on_img[y:y+h, x:x+w, :]

                                    if self.export_textline_images_and_text:
                                        if not self.do_not_mask_with_textline_contour:
                                            img_crop[mask_poly==0] = 255

                                    else:
                                        # print(file_name, angle_degrees, w*h,
                                        #       mask_poly[:,:,0].sum(),
                                        #       mask_poly[:,:,0].sum() /float(w*h) ,
                                        #       'didi')

                                        if angle_degrees > 3:
                                            better_des_slope = get_orientation_moments(textline_coords)

                                            img_crop = rotate_image_with_padding(img_crop, better_des_slope)
                                            if dir_in_bin is not None:
                                                img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope)

                                            mask_poly = rotate_image_with_padding(mask_poly, better_des_slope)
                                            mask_poly = mask_poly.astype('uint8')

                                            #new bounding box
                                            x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0])

                                            mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :]
                                            img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :]

                                            if not self.do_not_mask_with_textline_contour:
                                                img_crop[mask_poly==0] = 255
                                            if dir_in_bin is not None:
                                                img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :]
                                                if not self.do_not_mask_with_textline_contour:
                                                    img_crop_bin[mask_poly==0] = 255

                                            if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90:
                                                if dir_in_bin is not None:
                                                    img_crop, img_crop_bin = \
                                                        break_curved_line_into_small_pieces_and_then_merge(
                                                            img_crop, mask_poly, img_crop_bin)
                                                else:
                                                    img_crop, _ = \
                                                        break_curved_line_into_small_pieces_and_then_merge(
                                                            img_crop, mask_poly)

                                        else:
                                            better_des_slope = 0
                                            if not self.do_not_mask_with_textline_contour:
                                                img_crop[mask_poly==0] = 255
                                            if dir_in_bin is not None:
                                                if not self.do_not_mask_with_textline_contour:
                                                    img_crop_bin[mask_poly==0] = 255
                                            if type_textregion=='drop-capital':
                                                pass
                                            else:
                                                if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90:
                                                    if dir_in_bin is not None:
                                                        img_crop, img_crop_bin = \
                                                            break_curved_line_into_small_pieces_and_then_merge(
                                                                img_crop, mask_poly, img_crop_bin)
                                                    else:
                                                        img_crop, _ = \
                                                            break_curved_line_into_small_pieces_and_then_merge(
                                                                img_crop, mask_poly)

                                    if not self.export_textline_images_and_text:
                                        if w_scaled < 750:#1.5*image_width:
                                            img_fin = preprocess_and_resize_image_for_ocrcnn_model(
                                                img_crop, image_height, image_width)
                                            cropped_lines.append(img_fin)
                                            if abs(better_des_slope) > 45:
                                                cropped_lines_ver_index.append(1)
                                            else:
                                                cropped_lines_ver_index.append(0)

                                            cropped_lines_meging_indexing.append(0)
                                            if dir_in_bin is not None:
                                                img_fin = preprocess_and_resize_image_for_ocrcnn_model(
                                                    img_crop_bin, image_height, image_width)
                                                cropped_lines_bin.append(img_fin)
                                        else:
                                            splited_images, splited_images_bin = return_textlines_split_if_needed(
                                                img_crop, img_crop_bin if dir_in_bin is not None else None)
                                            if splited_images:
                                                img_fin = preprocess_and_resize_image_for_ocrcnn_model(
                                                    splited_images[0], image_height, image_width)
                                                cropped_lines.append(img_fin)
                                                cropped_lines_meging_indexing.append(1)

                                                if abs(better_des_slope) > 45:
                                                    cropped_lines_ver_index.append(1)
                                                else:
                                                    cropped_lines_ver_index.append(0)

                                                img_fin = preprocess_and_resize_image_for_ocrcnn_model(
                                                    splited_images[1], image_height, image_width)

                                                cropped_lines.append(img_fin)
                                                cropped_lines_meging_indexing.append(-1)

                                                if abs(better_des_slope) > 45:
                                                    cropped_lines_ver_index.append(1)
                                                else:
                                                    cropped_lines_ver_index.append(0)

                                                if dir_in_bin is not None:
                                                    img_fin = preprocess_and_resize_image_for_ocrcnn_model(
                                                        splited_images_bin[0], image_height, image_width)
                                                    cropped_lines_bin.append(img_fin)
                                                    img_fin = preprocess_and_resize_image_for_ocrcnn_model(
                                                        splited_images_bin[1], image_height, image_width)
                                                    cropped_lines_bin.append(img_fin)

                                            else:
                                                img_fin = preprocess_and_resize_image_for_ocrcnn_model(
                                                    img_crop, image_height, image_width)
                                                cropped_lines.append(img_fin)
                                                cropped_lines_meging_indexing.append(0)

                                                if abs(better_des_slope) > 45:
                                                    cropped_lines_ver_index.append(1)
                                                else:
                                                    cropped_lines_ver_index.append(0)

                                                if dir_in_bin is not None:
                                                    img_fin = preprocess_and_resize_image_for_ocrcnn_model(
                                                        img_crop_bin, image_height, image_width)
                                                    cropped_lines_bin.append(img_fin)

                                if self.export_textline_images_and_text:
                                    if img_crop.shape[0]==0 or img_crop.shape[1]==0:
                                        pass
                                    else:
                                        if child_textlines.tag.endswith("TextEquiv"):
                                            for cheild_text in child_textlines:
                                                if cheild_text.tag.endswith("Unicode"):
                                                    textline_text = cheild_text.text
                                                    if textline_text:
                                                        base_name = os.path.join(
                                                            dir_out, file_name + '_line_' + str(indexer_textlines))
                                                        if self.pref_of_dataset:
                                                            base_name += '_' + self.pref_of_dataset
                                                        if not self.do_not_mask_with_textline_contour:
                                                            base_name += '_masked'

                                                        with open(base_name + '.txt', 'w') as text_file:
                                                            text_file.write(textline_text)
                                                        cv2.imwrite(base_name + '.png', img_crop)
                                                    indexer_textlines+=1

                    if not self.export_textline_images_and_text:
                        indexer_text_region = indexer_text_region +1

                if not self.export_textline_images_and_text:
                    extracted_texts = []
                    extracted_conf_value = []

                    n_iterations = math.ceil(len(cropped_lines) / self.b_s)

                    for i in range(n_iterations):
                        if i==(n_iterations-1):
                            n_start = i*self.b_s
                            imgs = cropped_lines[n_start:]
                            imgs = np.array(imgs)
                            imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)

                            ver_imgs = np.array( cropped_lines_ver_index[n_start:] )
                            indices_ver = np.where(ver_imgs == 1)[0]

                            #print(indices_ver, 'indices_ver')
                            if len(indices_ver)>0:
                                imgs_ver_flipped = imgs[indices_ver, : ,: ,:]
                                imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:]
                                #print(imgs_ver_flipped, 'imgs_ver_flipped')

                            else:
                                imgs_ver_flipped = None

                            if dir_in_bin is not None:
                                imgs_bin = cropped_lines_bin[n_start:]
                                imgs_bin = np.array(imgs_bin)
                                imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3)

                                if len(indices_ver)>0:
                                    imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:]
                                    imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:]
                                    #print(imgs_ver_flipped, 'imgs_ver_flipped')

                                else:
                                    imgs_bin_ver_flipped = None
                        else:
                            n_start = i*self.b_s
                            n_end = (i+1)*self.b_s
                            imgs = cropped_lines[n_start:n_end]
                            imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3)

                            ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] )
                            indices_ver = np.where(ver_imgs == 1)[0]
                            #print(indices_ver, 'indices_ver')

                            if len(indices_ver)>0:
                                imgs_ver_flipped = imgs[indices_ver, : ,: ,:]
                                imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:]
                                #print(imgs_ver_flipped, 'imgs_ver_flipped')
                            else:
                                imgs_ver_flipped = None

                            if dir_in_bin is not None:
                                imgs_bin = cropped_lines_bin[n_start:n_end]
                                imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3)

                                if len(indices_ver)>0:
                                    imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:]
                                    imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:]
                                    #print(imgs_ver_flipped, 'imgs_ver_flipped')
                                else:
                                    imgs_bin_ver_flipped = None

                        self.logger.debug("processing next %d lines", len(imgs))
                        preds = self.model_zoo.get('ocr').predict(imgs, verbose=0)

                        if len(indices_ver)>0:
                            preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0)
                            preds_max_fliped = np.max(preds_flipped, axis=2 )
                            preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
                            pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
                            masked_means_flipped = \
                                np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \
                                np.sum(pred_max_not_unk_mask_bool_flipped, axis=1)
                            masked_means_flipped[np.isnan(masked_means_flipped)] = 0

                            preds_max = np.max(preds, axis=2 )
                            preds_max_args = np.argmax(preds, axis=2 )
                            pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character

                            masked_means = \
                                np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
                                np.sum(pred_max_not_unk_mask_bool, axis=1)
                            masked_means[np.isnan(masked_means)] = 0

                            masked_means_ver = masked_means[indices_ver]
                            #print(masked_means_ver, 'pred_max_not_unk')

                            indices_where_flipped_conf_value_is_higher = \
                                np.where(masked_means_flipped > masked_means_ver)[0]

                            #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher')
                            if len(indices_where_flipped_conf_value_is_higher)>0:
                                indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher]
                                preds[indices_to_be_replaced,:,:] = \
                                    preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]
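
                        # Vertical-text heuristic: lines flagged as vertical
                        # (|slope| > 45°) are also predicted from a copy
                        # flipped in both axes (a 180° rotation); whichever
                        # orientation scores the higher mean confidence
                        # contributes the predictions that are kept.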
                        if dir_in_bin is not None:
                            preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0)

                            if len(indices_ver)>0:
                                preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0)
                                preds_max_fliped = np.max(preds_flipped, axis=2 )
                                preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
                                pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
                                masked_means_flipped = \
                                    np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \
                                    np.sum(pred_max_not_unk_mask_bool_flipped, axis=1)
                                masked_means_flipped[np.isnan(masked_means_flipped)] = 0

                                preds_max = np.max(preds, axis=2 )
                                preds_max_args = np.argmax(preds, axis=2 )
                                pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character

                                masked_means = \
                                    np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
                                    np.sum(pred_max_not_unk_mask_bool, axis=1)
                                masked_means[np.isnan(masked_means)] = 0

                                masked_means_ver = masked_means[indices_ver]
                                #print(masked_means_ver, 'pred_max_not_unk')

                                indices_where_flipped_conf_value_is_higher = \
                                    np.where(masked_means_flipped > masked_means_ver)[0]

                                #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher')
                                if len(indices_where_flipped_conf_value_is_higher)>0:
                                    indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher]
                                    preds_bin[indices_to_be_replaced,:,:] = \
                                        preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]

                            preds = (preds + preds_bin) / 2.

                        pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char'))

                        preds_max = np.max(preds, axis=2 )
                        preds_max_args = np.argmax(preds, axis=2 )
                        pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
                        masked_means = \
                            np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
                            np.sum(pred_max_not_unk_mask_bool, axis=1)

                        for ib in range(imgs.shape[0]):
                            pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
                            if masked_means[ib] >= self.min_conf_value_of_textline_text:
                                extracted_texts.append(pred_texts_ib)
                                extracted_conf_value.append(masked_means[ib])
                            else:
                                extracted_texts.append("")
                                extracted_conf_value.append(0)
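
                    # masked_means is the per-line confidence: the mean of the
                    # max softmax probability over all decoder positions that
                    # did not predict the end/padding character. Lines scoring
                    # below min_conf_value_of_textline_text were replaced by
                    # empty strings above.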
                    del cropped_lines
                    if dir_in_bin is not None:
                        del cropped_lines_bin
                    gc.collect()

                    extracted_texts_merged = [extracted_texts[ind]
                                              if cropped_lines_meging_indexing[ind]==0
                                              else extracted_texts[ind]+" "+extracted_texts[ind+1]
                                              if cropped_lines_meging_indexing[ind]==1
                                              else None
                                              for ind in range(len(cropped_lines_meging_indexing))]

                    extracted_conf_value_merged = [extracted_conf_value[ind]
                                                   if cropped_lines_meging_indexing[ind]==0
                                                   else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2.
                                                   if cropped_lines_meging_indexing[ind]==1
                                                   else None
                                                   for ind in range(len(cropped_lines_meging_indexing))]

                    extracted_conf_value_merged = [extracted_conf_value_merged[ind_cfm]
                                                   for ind_cfm in range(len(extracted_texts_merged))
                                                   if extracted_texts_merged[ind_cfm] is not None]
                    extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
                    unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)

                    if dir_out_image_text:
                        #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
                        font = importlib_resources.files(__package__) / "Charis-Regular.ttf"
                        with importlib_resources.as_file(font) as font:
                            font = ImageFont.truetype(font=font, size=40)

                        for indexer_text, bb_ind in enumerate(total_bb_coordinates):
                            x_bb = bb_ind[0]
                            y_bb = bb_ind[1]
                            w_bb = bb_ind[2]
                            h_bb = bb_ind[3]

                            font = fit_text_single_line(draw, extracted_texts_merged[indexer_text],
                                                        font.path, w_bb, int(h_bb*0.4) )

                            ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2)

                            text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font)
                            text_width = text_bbox[2] - text_bbox[0]
                            text_height = text_bbox[3] - text_bbox[1]

                            text_x = x_bb + (w_bb - text_width) // 2  # Center horizontally
                            text_y = y_bb + (h_bb - text_height) // 2  # Center vertically

                            # Draw the text
                            draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font)
                        image_text.save(out_image_with_text)

                    text_by_textregion = []
                    for ind in unique_cropped_lines_region_indexer:
                        ind = np.array(cropped_lines_region_indexer)==ind
                        extracted_texts_merged_un = np.array(extracted_texts_merged)[ind]
                        if len(extracted_texts_merged_un)>1:
                            text_by_textregion_ind = ""
                            next_glue = ""
                            for indt in range(len(extracted_texts_merged_un)):
                                if (extracted_texts_merged_un[indt].endswith('⸗') or
                                    extracted_texts_merged_un[indt].endswith('-') or
                                    extracted_texts_merged_un[indt].endswith('¬')):
                                    text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1]
                                    next_glue = ""
                                else:
                                    text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt]
                                    next_glue = " "
                            text_by_textregion.append(text_by_textregion_ind)
                        else:
                            text_by_textregion.append(" ".join(extracted_texts_merged_un))
                    #print(text_by_textregion, 'text_by_textregiontext_by_textregiontext_by_textregiontext_by_textregiontext_by_textregion')

                    ###index_tot_regions = []
                    ###tot_region_ref = []

                    ###for jj in root1.iter(link+'RegionRefIndexed'):
                    ###index_tot_regions.append(jj.attrib['index'])
                    ###tot_region_ref.append(jj.attrib['regionRef'])

                    ###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)}

                    #id_textregions = []
                    #textregions_by_existing_ids = []
                    indexer = 0
                    indexer_textregion = 0
                    for nn in root1.iter(region_tags):
                        #id_textregion = nn.attrib['id']
                        #id_textregions.append(id_textregion)
                        #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion])

                        is_textregion_text = False
                        for childtest in nn:
                            if childtest.tag.endswith("TextEquiv"):
                                is_textregion_text = True

                        if not is_textregion_text:
                            text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
                            unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')

                        has_textline = False
                        for child_textregion in nn:
                            if child_textregion.tag.endswith("TextLine"):

                                is_textline_text = False
                                for childtest2 in child_textregion:
                                    if childtest2.tag.endswith("TextEquiv"):
                                        is_textline_text = True

                                if not is_textline_text:
                                    text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
                                    text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
                                    unicode_textline = ET.SubElement(text_subelement, 'Unicode')
                                    unicode_textline.text = extracted_texts_merged[indexer]
                                else:
                                    for childtest3 in child_textregion:
                                        if childtest3.tag.endswith("TextEquiv"):
                                            for child_uc in childtest3:
                                                if child_uc.tag.endswith("Unicode"):
                                                    childtest3.set('conf',
                                                                   f"{extracted_conf_value_merged[indexer]:.2f}")
                                                    child_uc.text = extracted_texts_merged[indexer]

                                indexer = indexer + 1
                                has_textline = True
                        if has_textline:
                            if is_textregion_text:
                                for child4 in nn:
                                    if child4.tag.endswith("TextEquiv"):
                                        for childtr_uc in child4:
                                            if childtr_uc.tag.endswith("Unicode"):
                                                childtr_uc.text = text_by_textregion[indexer_textregion]
                            else:
                                unicode_textregion.text = text_by_textregion[indexer_textregion]
                        indexer_textregion = indexer_textregion + 1

                    ###sample_order = [(id_to_order[tid], text)
                    ### for tid, text in zip(id_textregions, textregions_by_existing_ids)
                    ### if tid in id_to_order]

                    ##ordered_texts_sample = [text for _, text in sorted(sample_order)]
                    ##tot_page_text = ' '.join(ordered_texts_sample)

                    ##for page_element in root1.iter(link+'Page'):
                    ##text_page = ET.SubElement(page_element, 'TextEquiv')
                    ##unicode_textpage = ET.SubElement(text_page, 'Unicode')
                    ##unicode_textpage.text = tot_page_text

                    ET.register_namespace("",name_space)
                    tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None)
            #print("Job done in %.1fs", time.time() - t0)
@ -5,24 +5,25 @@ Image enhancer. The output can be written as same scale of input or in new predi
|
|||
from logging import Logger
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
from pathlib import Path
|
||||
import gc
|
||||
|
||||
import cv2
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
from ocrd_utils import getLogger, tf_disable_interactive_logs
|
||||
import tensorflow as tf
|
||||
from skimage.morphology import skeletonize
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils.resize import resize_image
|
||||
from .utils.pil_cv2 import pil2cv
from .utils import (
    is_image_filename,
    crop_image_inside_box
)
from .eynollah import PatchEncoder, Patches
from .patch_encoder import PatchEncoder, Patches

DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)

@@ -50,11 +51,9 @@ class Enhancer:
        self.num_col_lower = num_col_lower

        self.logger = logger if logger else getLogger('enhancement')
        self.dir_models = dir_models
        self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425"
        self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
        self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
        self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915"
        self.model_zoo = EynollahModelZoo(basedir=dir_models)
        for v in ['binarization', 'enhancement', 'col_classifier', 'page']:
            self.model_zoo.load_model(v)

        try:
            for device in tf.config.list_physical_devices('GPU'):

@@ -62,11 +61,6 @@ class Enhancer:
        except:
            self.logger.warning("no GPU device available")

        self.model_page = self.our_load_model(self.model_page_dir)
        self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
        self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
        self.model_bin = self.our_load_model(self.model_dir_of_binarization)

    def cache_images(self, image_filename=None, image_pil=None, dpi=None):
        ret = {}
        if image_filename:

@@ -103,23 +97,11 @@
    def isNaN(self, num):
        return num != num

    @staticmethod
    def our_load_model(model_file):
        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
            # prefer SavedModel over HDF5 format if it exists
            model_file = model_file[:-3]
        try:
            model = load_model(model_file, compile=False)
        except:
            model = load_model(model_file, compile=False, custom_objects={
                "PatchEncoder": PatchEncoder, "Patches": Patches})
        return model

    def predict_enhancement(self, img):
        self.logger.debug("enter predict_enhancement")

        img_height_model = self.model_enhancement.layers[-1].output_shape[1]
        img_width_model = self.model_enhancement.layers[-1].output_shape[2]
        img_height_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[1]
        img_width_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[2]
        if img.shape[0] < img_height_model:
            img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
        if img.shape[1] < img_width_model:

@@ -160,7 +142,7 @@ class Enhancer:
                    index_y_d = img_h - img_height_model

                img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
                label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
                label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose=0)
                seg = label_p_pred[0, :, :, :] * 255

                if i == 0 and j == 0:

@@ -246,7 +228,7 @@ class Enhancer:
        else:
            img = self.imread()
        img = cv2.GaussianBlur(img, (5, 5), 0)
        img_page_prediction = self.do_prediction(False, img, self.model_page)
        img_page_prediction = self.do_prediction(False, img, self.model_zoo.get('page'))

        imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(imgray, 0, 255, 0)

@@ -291,7 +273,7 @@ class Enhancer:
        self.logger.info("Detected %s DPI", dpi)
        if self.input_binary:
            img = self.imread()
            prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
            prediction_bin = self.do_prediction(True, img, self.model_zoo.get('binarization'), n_batch_inference=5)
            prediction_bin = 255 * (prediction_bin[:,:,0]==0)
            prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
            img = np.copy(prediction_bin)

@@ -332,7 +314,7 @@ class Enhancer:
            img_in[0, :, :, 1] = img_1ch[:, :]
            img_in[0, :, :, 2] = img_1ch[:, :]

            label_p_pred = self.model_classifier.predict(img_in, verbose=0)
            label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
            num_col = np.argmax(label_p_pred[0]) + 1
        elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
            if self.input_binary:

@@ -352,7 +334,7 @@ class Enhancer:
            img_in[0, :, :, 1] = img_1ch[:, :]
            img_in[0, :, :, 2] = img_1ch[:, :]

            label_p_pred = self.model_classifier.predict(img_in, verbose=0)
            label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
            num_col = np.argmax(label_p_pred[0]) + 1

            if num_col > self.num_col_upper:

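The refactor above replaces four hard-coded model-directory attributes with a single `EynollahModelZoo`. A minimal sketch of the resulting access pattern, assuming a local model directory (the `models_eynollah` path is a placeholder, not taken from this commit):

```python
from keras.models import Model

from eynollah.model_zoo import EynollahModelZoo

# Placeholder path; point this at an extracted model distribution.
zoo = EynollahModelZoo(basedir="models_eynollah")
for category in ['binarization', 'enhancement', 'col_classifier', 'page']:
    zoo.load_model(category)

# get() returns the previously loaded model; passing a type additionally
# asserts at runtime that the loaded object really is a keras Model.
enhancement = zoo.get('enhancement', Model)
print(enhancement.layers[-1].output_shape)
```
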
@@ -10,12 +10,13 @@ from pathlib import Path
import xml.etree.ElementTree as ET

import cv2
from keras.models import Model
import numpy as np
from ocrd_utils import getLogger
import statistics
import tensorflow as tf
from tensorflow.keras.models import load_model

from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image
from .utils.contour import (
    find_new_features_of_contours,

@@ -23,7 +24,6 @@ from .utils.contour import (
    return_parent_contours,
)
from .utils import is_xml_filename
from .eynollah import PatchEncoder, Patches

DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)

@@ -45,21 +45,11 @@ class machine_based_reading_order_on_layout:
        except:
            self.logger.warning("no GPU device available")

        self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
        self.model_zoo = EynollahModelZoo(basedir=dir_models)
        self.model_zoo.load_model('reading_order')
        # FIXME: light_version is always true, no need for checks in the code
        self.light_version = True

    @staticmethod
    def our_load_model(model_file):
        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
            # prefer SavedModel over HDF5 format if it exists
            model_file = model_file[:-3]
        try:
            model = load_model(model_file, compile=False)
        except:
            model = load_model(model_file, compile=False, custom_objects={
                "PatchEncoder": PatchEncoder, "Patches": Patches})
        return model

    def read_xml(self, xml_file):
        tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
        root1=tree1.getroot()

@@ -69,6 +59,7 @@ class machine_based_reading_order_on_layout:
        index_tot_regions = []
        tot_region_ref = []

        y_len, x_len = 0, 0
        for jj in root1.iter(link+'Page'):
            y_len=int(jj.attrib['imageHeight'])
            x_len=int(jj.attrib['imageWidth'])

@@ -81,13 +72,13 @@ class machine_based_reading_order_on_layout:
        co_printspace = []
        if link+'PrintSpace' in alltags:
            region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')])
        elif link+'Border' in alltags:
        else:
            region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')])

        for tag in region_tags_printspace:
            if link+'PrintSpace' in alltags:
                tag_endings_printspace = ['}PrintSpace','}printspace']
            elif link+'Border' in alltags:
            else:
                tag_endings_printspace = ['}Border','}border']

            if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]):

@@ -683,7 +674,7 @@ class machine_based_reading_order_on_layout:
                    tot_counter += 1
                    batch.append(j)
                    if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
                        y_pr = self.model_reading_order.predict(input_1, verbose=0)
                        y_pr = self.model_zoo.get('reading_order', Model).predict(input_1, verbose=0)
                        for jb, j in enumerate(batch):
                            if y_pr[jb][0]>=0.5:
                                post_list.append(j)

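The flush condition in the hunk above (`tot_counter % inference_bs == 0 or tot_counter == len(ij_list)`) runs one prediction per full batch plus one for the trailing remainder. A self-contained sketch of that batching pattern, with made-up names:

```python
def batched(items, batch_size):
    # Collect items and flush when the batch is full or the input is
    # exhausted, so the trailing remainder is never dropped.
    batch = []
    for counter, item in enumerate(items, start=1):
        batch.append(item)
        if counter % batch_size == 0 or counter == len(items):
            yield batch
            batch = []

for chunk in batched(list(range(10)), 4):
    print(chunk)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
```
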
@@ -802,6 +793,7 @@ class machine_based_reading_order_on_layout:
        alltags=[elem.tag for elem in root_xml.iter()]

        ET.register_namespace("",name_space)
        assert dir_out
        tree_xml.write(os.path.join(dir_out, file_name+'.xml'),
                       xml_declaration=True,
                       method='xml',

4
src/eynollah/model_zoo/__init__.py
Normal file

@@ -0,0 +1,4 @@
__all__ = [
    'EynollahModelZoo',
]
from .model_zoo import EynollahModelZoo

314
src/eynollah/model_zoo/default_specs.py
Normal file

@@ -0,0 +1,314 @@
from .specs import EynollahModelSpec, EynollahModelSpecSet
from .types import KerasModel, TrOCRProcessor, List

# NOTE: This needs to change whenever models/versions change
ZENODO = "https://zenodo.org/records/17295988/files"
MODELS_VERSION = "v0_7_0"

def dist_url(dist_name: str) -> str:
    return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip'

DEFAULT_MODEL_SPECS = EynollahModelSpecSet([

    EynollahModelSpec(
        category="enhancement",
        variant='',
        filename="models_eynollah/eynollah-enhancement_20210425",
        dists=['enhancement', 'layout'],
        dist_url=dist_url("enhancement"),
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="binarization",
        variant='',
        filename="models_eynollah/eynollah-binarization-hybrid_20230504",
        dists=['layout', 'binarization'],
        dist_url=dist_url("binarization"),
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="binarization",
        variant='20210309',
        filename="models_eynollah/eynollah-binarization_20210309",
        dists=['binarization'],
        dist_url=dist_url("binarization"),
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="binarization",
        variant='augment',
        filename="models_eynollah/eynollah-binarization_20210425",
        dists=['binarization'],
        dist_url=dist_url("binarization"),
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="binarization_multi_1",
        variant='',
        filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin1",
        dist_url=dist_url("binarization"),
        dists=['binarization'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="binarization_multi_2",
        variant='',
        filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin2",
        dist_url=dist_url("binarization"),
        dists=['binarization'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="binarization_multi_3",
        variant='',
        filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin3",
        dist_url=dist_url("binarization"),
        dists=['binarization'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="binarization_multi_4",
        variant='',
        filename="models_eynollah/eynollah-binarization-multi_2020_01_16/model_bin4",
        dist_url=dist_url("binarization"),
        dists=['binarization'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="col_classifier",
        variant='',
        filename="models_eynollah/eynollah-column-classifier_20210425",
        dist_url=dist_url("layout"),
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="page",
        variant='',
        filename="models_eynollah/model_eynollah_page_extraction_20250915",
        dist_url=dist_url("layout"),
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="region",
        variant='',
        filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
        dist_url=dist_url("layout"),
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="region",
        variant='extract_only_images',
        filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
        dist_url=dist_url("layout"),
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="region",
        variant='light',
        filename="models_eynollah/eynollah-main-regions_20220314",
        dist_url=dist_url("layout"),
        help="early layout",
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="region_p2",
        variant='',
        filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425",
        dist_url=dist_url("layout"),
        help="early layout, non-light, 2nd part",
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="region_1_2",
        variant='',
        #filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n",
        #filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8",
        #filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18",
        #filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige",
        #filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige",
        filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024",
        dist_url=dist_url("layout"),
        dists=['layout'],
        help="early layout, light, 1-or-2-column",
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="region_fl_np",
        variant='',
        #'filename="models_eynollah/modelens_full_lay_1_3_031124",
        #'filename="models_eynollah/modelens_full_lay_13__3_19_241024",
        #'filename="models_eynollah/model_full_lay_13_241024",
        #'filename="models_eynollah/modelens_full_lay_13_17_231024",
        #'filename="models_eynollah/modelens_full_lay_1_2_221024",
        #'filename="models_eynollah/eynollah-full-regions-1column_20210425",
        filename="models_eynollah/modelens_full_lay_1__4_3_091124",
        dist_url=dist_url("layout"),
        help="full layout / no patches",
        dists=['layout'],
        type=KerasModel,
    ),

    # FIXME: Why is region_fl and region_fl_np the same model?
    EynollahModelSpec(
        category="region_fl",
        variant='',
        # filename="models_eynollah/eynollah-full-regions-3+column_20210425",
        # filename="models_eynollah/model_2_full_layout_new_trans",
        # filename="models_eynollah/modelens_full_lay_1_3_031124",
        # filename="models_eynollah/modelens_full_lay_13__3_19_241024",
        # filename="models_eynollah/model_full_lay_13_241024",
        # filename="models_eynollah/modelens_full_lay_13_17_231024",
        # filename="models_eynollah/modelens_full_lay_1_2_221024",
        # filename="models_eynollah/modelens_full_layout_24_till_28",
        # filename="models_eynollah/model_2_full_layout_new_trans",
        filename="models_eynollah/modelens_full_lay_1__4_3_091124",
        dist_url=dist_url("layout"),
        help="full layout / with patches",
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="reading_order",
        variant='',
        #filename="models_eynollah/model_mb_ro_aug_ens_11",
        #filename="models_eynollah/model_step_3200000_mb_ro",
        #filename="models_eynollah/model_ens_reading_order_machine_based",
        #filename="models_eynollah/model_mb_ro_aug_ens_8",
        #filename="models_eynollah/model_ens_reading_order_machine_based",
        filename="models_eynollah/model_eynollah_reading_order_20250824",
        dist_url=dist_url("reading_order"),
        dists=['layout', 'reading_order'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="textline",
        variant='',
        #filename="models_eynollah/modelens_textline_1_4_16092024",
        #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
        #filename="models_eynollah/modelens_textline_1_3_4_20240915",
        #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
        #filename="models_eynollah/modelens_textline_9_12_13_14_15",
        #filename="models_eynollah/eynollah-textline_20210425",
        filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
        dist_url=dist_url("layout"),
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="textline",
        variant='light',
        #filename="models_eynollah/eynollah-textline_light_20210425",
        filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
        dist_url=dist_url("layout"),
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="table",
        variant='',
        filename="models_eynollah/eynollah-tables_20210319",
        dist_url=dist_url("layout"),
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="table",
        variant='light',
        filename="models_eynollah/modelens_table_0t4_201124",
        dist_url=dist_url("layout"),
        dists=['layout'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="ocr",
        variant='',
        filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
        dist_url=dist_url("ocr"),
        dists=['layout', 'ocr'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="ocr",
        variant='degraded',
        filename="models_eynollah/model_eynollah_ocr_cnnrnn__degraded_20250805/",
        help="slightly better at degraded Fraktur",
        dist_url=dist_url("ocr"),
        dists=['ocr'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="num_to_char",
        variant='',
        filename="characters_org.txt",
        dist_url=dist_url("ocr"),
        dists=['ocr'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="characters",
        variant='',
        filename="characters_org.txt",
        dist_url=dist_url("ocr"),
        dists=['ocr'],
        type=list,
    ),

    EynollahModelSpec(
        category="ocr",
        variant='tr',
        filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
        dist_url=dist_url("trocr"),
        help='much slower transformer-based',
        dists=['trocr'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="trocr_processor",
        variant='',
        filename="models_eynollah/microsoft/trocr-base-printed",
        dist_url=dist_url("trocr"),
        dists=['trocr'],
        type=KerasModel,
    ),

    EynollahModelSpec(
        category="trocr_processor",
        variant='htr',
        filename="models_eynollah/microsoft/trocr-base-handwritten",
        dist_url=dist_url("trocr"),
        dists=['trocr'],
        type=TrOCRProcessor,
    ),

])

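For reference, this is what the URL construction above yields; the snippet only re-evaluates `dist_url` outside the module and adds nothing new:

```python
ZENODO = "https://zenodo.org/records/17295988/files"
MODELS_VERSION = "v0_7_0"

def dist_url(dist_name: str) -> str:
    return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip'

print(dist_url("layout"))
# -> https://zenodo.org/records/17295988/files/models_layout_v0_7_0.zip
```
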
189
src/eynollah/model_zoo/model_zoo.py
Normal file

@@ -0,0 +1,189 @@
import json
import logging
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union

from keras.layers import StringLookup
from keras.models import Model as KerasModel
from keras.models import load_model
from tabulate import tabulate
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from ..patch_encoder import PatchEncoder, Patches
from .specs import EynollahModelSpecSet
from .default_specs import DEFAULT_MODEL_SPECS
from .types import AnyModel, T


class EynollahModelZoo:
    """
    Wrapper class that handles storage and loading of models for all eynollah runners.
    """

    model_basedir: Path
    specs: EynollahModelSpecSet

    def __init__(
        self,
        basedir: str,
        model_overrides: Optional[List[Tuple[str, str, str]]] = None,
    ) -> None:
        self.model_basedir = Path(basedir)
        self.logger = logging.getLogger('eynollah.model_zoo')
        self.specs = deepcopy(DEFAULT_MODEL_SPECS)
        if model_overrides:
            self.override_models(*model_overrides)
        self._loaded: Dict[str, AnyModel] = {}

    def override_models(
        self,
        *model_overrides: Tuple[str, str, str],
    ):
        """
        Override the default model versions
        """
        for model_category, model_variant, model_filename in model_overrides:
            spec = self.specs.get(model_category, model_variant)
            self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
            self.specs.get(model_category, model_variant).filename = model_filename

    def model_path(
        self,
        model_category: str,
        model_variant: str = '',
        absolute: bool = True,
    ) -> Path:
        """
        Translate model_{type,variant} tuple into an absolute (or relative) Path
        """
        spec = self.specs.get(model_category, model_variant)
        if spec.category in ('characters', 'num_to_char'):
            return self.model_path('ocr') / spec.filename
        if not Path(spec.filename).is_absolute() and absolute:
            model_path = Path(self.model_basedir).joinpath(spec.filename)
        else:
            model_path = Path(spec.filename)
        return model_path

    def load_models(
        self,
        *all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
    ) -> Dict:
        """
        Load all models by calling load_model and return a dictionary mapping model_category to loaded model
        """
        ret = {}
        for load_args in all_load_args:
            if isinstance(load_args, str):
                ret[load_args] = self.load_model(load_args)
            else:
                ret[load_args[0]] = self.load_model(*load_args)
        return ret

    def load_model(
        self,
        model_category: str,
        model_variant: str = '',
    ) -> AnyModel:
        """
        Load any model
        """
        model_path = self.model_path(model_category, model_variant)
        if model_path.suffix == '.h5' and model_path.with_suffix('').exists():
            # prefer SavedModel over HDF5 format if it exists
            model_path = model_path.with_suffix('')
        if model_category == 'ocr':
            model = self._load_ocr_model(variant=model_variant)
        elif model_category == 'num_to_char':
            model = self._load_num_to_char()
        elif model_category == 'characters':
            model = self._load_characters()
        elif model_category == 'trocr_processor':
            return TrOCRProcessor.from_pretrained(model_path)
        else:
            try:
                model = load_model(model_path, compile=False)
            except Exception as e:
                self.logger.exception(e)
                model = load_model(
                    model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
                )
        self._loaded[model_category] = model
        return model  # type: ignore

    def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T:
        if model_category not in self._loaded:
            raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(...)"')
        ret = self._loaded[model_category]
        if model_type:
            assert isinstance(ret, model_type)
        return ret  # type: ignore # FIXME: convince typing that we're returning generic type

    def _load_ocr_model(self, variant: str) -> AnyModel:
        """
        Load OCR model
        """
        ocr_model_dir = self.model_path('ocr', variant)
        if variant == 'tr':
            return VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
        else:
            ocr_model = load_model(ocr_model_dir, compile=False)
            assert isinstance(ocr_model, KerasModel)
            return KerasModel(
                ocr_model.get_layer(name="image").input,  # type: ignore
                ocr_model.get_layer(name="dense2").output,  # type: ignore
            )

    def _load_characters(self) -> List[str]:
        """
        Load encoding for OCR
        """
        with open(self.model_path('num_to_char'), "r") as config_file:
            return json.load(config_file)

    def _load_num_to_char(self) -> StringLookup:
        """
        Load decoder for OCR
        """
        characters = self._load_characters()
        # Mapping characters to integers.
        char_to_num = StringLookup(vocabulary=characters, mask_token=None)
        # Mapping integers back to original characters.
        return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)

    def __str__(self):
        return tabulate(
            [
                [
                    spec.type.__name__,
                    spec.category,
                    spec.variant,
                    spec.help,
                    ', '.join(spec.dists),
                    f'Yes, at {self.model_path(spec.category, spec.variant)}'
                    if self.model_path(spec.category, spec.variant).exists()
                    else f'No, download {spec.dist_url}',
                    # self.model_path(spec.category, spec.variant),
                ]
                for spec in self.specs.specs
            ],
            headers=[
                'Type',
                'Category',
                'Variant',
                'Help',
                'Used in',
                'Installed',
            ],
            tablefmt='github',
        )

    def shutdown(self):
        """
        Ensure that loaded models are no longer referenced by ``self._loaded``
        """
        if hasattr(self, '_loaded') and getattr(self, '_loaded'):
            for needle in list(self._loaded):
                if self._loaded[needle]:
                    del self._loaded[needle]

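A minimal usage sketch for the zoo as defined above; the base directory and the override filename are illustrative values, not part of this commit:

```python
from eynollah.model_zoo import EynollahModelZoo

zoo = EynollahModelZoo(
    basedir="models_eynollah",  # placeholder path
    model_overrides=[("binarization", "", "my_custom_binarization_model")],  # hypothetical override
)
print(zoo)  # GitHub-flavoured table: type, category, variant, installed / download URL

# load_models() accepts bare categories or (category, variant) tuples.
zoo.load_models('binarization', ('textline', 'light'))
binarization = zoo.get('binarization')
zoo.shutdown()
```
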
55
src/eynollah/model_zoo/specs.py
Normal file

@@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple, Type
from .types import AnyModel


@dataclass
class EynollahModelSpec():
    """
    Describing a single model abstractly.
    """
    category: str
    # Relative filename to the models_eynollah directory in the dists
    filename: str
    # basename of the ZIP files that should contain this model
    dists: List[str]
    # URL to the smallest model distribution containing this model (link to Zenodo)
    dist_url: str
    type: Type[AnyModel]
    variant: str = ''
    help: str = ''

class EynollahModelSpecSet():
    """
    List of all used models for eynollah.
    """
    specs: List[EynollahModelSpec]

    def __init__(self, specs: List[EynollahModelSpec]) -> None:
        self.specs = sorted(specs, key=lambda x: x.category + '0' + x.variant)
        self.categories: Set[str] = set([spec.category for spec in self.specs])
        self.variants: Dict[str, Set[str]] = {
            spec.category: set([x.variant for x in self.specs if x.category == spec.category])
            for spec in self.specs
        }
        self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = {
            (spec.category, spec.variant): spec
            for spec in self.specs
        }

    def asdict(self) -> Dict[str, Dict[str, str]]:
        return {
            spec.category: {
                spec.variant: spec.filename
            }
            for spec in self.specs
        }

    def get(self, category: str, variant: str) -> EynollahModelSpec:
        if category not in self.categories:
            raise ValueError(f"Unknown category '{category}', must be one of {self.categories}")
        if variant not in self.variants[category]:
            raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}")
        return self._index_category_variant[(category, variant)]

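A toy spec set exercising the lookup logic above; the single spec is invented for the example:

```python
from keras.models import Model as KerasModel

from eynollah.model_zoo.specs import EynollahModelSpec, EynollahModelSpecSet

specs = EynollahModelSpecSet([
    EynollahModelSpec(
        category="binarization",
        filename="models_eynollah/some-binarization-model",  # made up
        dists=["binarization"],
        dist_url="https://example.com/models_binarization.zip",  # made up
        type=KerasModel,
    ),
])
print(specs.get("binarization", "").filename)
specs.get("binarization", "nope")  # raises ValueError listing the known variants
```
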
6
src/eynollah/model_zoo/types.py
Normal file

@@ -0,0 +1,6 @@
from typing import List, TypeVar, Union
from keras.models import Model as KerasModel
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
T = TypeVar('T')

@@ -83,10 +83,10 @@
    },
    "resources": [
      {
        "url": "https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1",
        "name": "models_layout_v0_5_0",
        "url": "https://zenodo.org/records/17295988/files/models_layout_v0_6_0.tar.gz?download=1",
        "name": "models_layout_v0_6_0",
        "type": "archive",
        "path_in_archive": "models_layout_v0_5_0",
        "path_in_archive": "models_layout_v0_6_0",
        "size": 3525684179,
        "description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement",
        "version_range": ">= v0.5.0"

52
src/eynollah/patch_encoder.py
Normal file

@@ -0,0 +1,52 @@
from keras import layers
import tensorflow as tf

projection_dim = 64
patch_size = 1
num_patches = 21*21  # alternatives seen historically: 14*14, 28*28

class PatchEncoder(layers.Layer):

    def __init__(self):
        super().__init__()
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)

    def call(self, patch):
        positions = tf.range(start=0, limit=num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_patches': num_patches,
            'projection': self.projection,
            'position_embedding': self.position_embedding,
        })
        return config

class Patches(layers.Layer):
    def __init__(self, **kwargs):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config

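Because these layers carry configuration that Keras cannot reconstruct on its own, deserializing a saved model that uses them requires `custom_objects`, exactly as the loaders elsewhere in this commit do. A sketch with a placeholder path:

```python
from keras.models import load_model

from eynollah.patch_encoder import PatchEncoder, Patches

model = load_model(
    "path/to/saved_model",  # placeholder
    compile=False,
    custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches},
)
```
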

@@ -40,8 +40,8 @@ class EynollahPlotter:
        self.image_filename_stem = image_filename_stem
        # XXX TODO hacky these cannot be set at init time
        self.image_org = image_org
        self.scale_x = scale_x
        self.scale_y = scale_y
        self.scale_x : float = scale_x
        self.scale_y : float = scale_y

    def save_plot_of_layout_main(self, text_regions_p, image_page):
        if self.dir_of_layout is not None:

@@ -2,18 +2,19 @@
Tool to load model and binarize a given image.
"""

import sys
from glob import glob
import os
import logging
from pathlib import Path
from typing import Dict, List

from keras.models import Model
import numpy as np
from PIL import Image
import cv2
from ocrd_utils import tf_disable_interactive_logs

from eynollah.model_zoo import EynollahModelZoo
tf_disable_interactive_logs()
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend

from .utils import is_image_filename

@@ -23,40 +24,37 @@ def resize_image(img_in, input_height, input_width):

class SbbBinarizer:

    def __init__(self, model_dir, logger=None):
        self.model_dir = model_dir
    def __init__(self, model_dir: str, mode: str, logger=None):
        if mode not in ('single', 'multi'):
            raise ValueError(f"'mode' must be either 'multi' or 'single', not {mode}")
        self.log = logger if logger else logging.getLogger('SbbBinarizer')

        self.start_new_session()

        self.model_files = glob(self.model_dir+"/*/", recursive = True)

        self.models = []
        for model_file in self.model_files:
            self.models.append(self.load_model(model_file))
        self.model_zoo = EynollahModelZoo(basedir=model_dir)
        self.models = self.setup_models(mode)
        self.session = self.start_new_session()

    def start_new_session(self):
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        self.session = tf.compat.v1.Session(config=config)  # tf.InteractiveSession()
        tensorflow_backend.set_session(self.session)
        session = tf.compat.v1.Session(config=config)  # tf.InteractiveSession()
        tensorflow_backend.set_session(session)
        return session

    def setup_models(self, mode: str) -> Dict[Path, Model]:
        return {
            self.model_zoo.model_path(v): self.model_zoo.load_model(v)
            for v in (['binarization'] if mode == 'single' else [f'binarization_multi_{i}' for i in range(1, 5)])
        }

    def end_session(self):
        tensorflow_backend.clear_session()
        self.session.close()
        del self.session

    def load_model(self, model_name):
        model = load_model(os.path.join(self.model_dir, model_name), compile=False)
    def predict(self, img, use_patches, n_batch_inference=5):
        model = self.model_zoo.get('binarization', Model)
        model_height = model.layers[len(model.layers)-1].output_shape[1]
        model_width = model.layers[len(model.layers)-1].output_shape[2]
        n_classes = model.layers[len(model.layers)-1].output_shape[3]
        return model, model_height, model_width, n_classes

    def predict(self, model_in, img, use_patches, n_batch_inference=5):
        tensorflow_backend.set_session(self.session)
        model, model_height, model_width, n_classes = model_in

        img_org_h = img.shape[0]
        img_org_w = img.shape[1]

@@ -324,8 +322,8 @@ class SbbBinarizer:
        if image_path is not None:
            image = cv2.imread(image_path)
        img_last = 0
        for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
            self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
        for n, (model_file, model) in enumerate(self.models.items()):
            self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))

            res = self.predict(model, image, use_patches)

@@ -354,8 +352,8 @@ class SbbBinarizer:
            print(image_name,'image_name')
            image = cv2.imread(os.path.join(dir_in,image_name) )
            img_last = 0
            for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
                self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
            for n, (model_file, model) in enumerate(self.models.items()):
                self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))

                res = self.predict(model, image, use_patches)

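A sketch of the new constructor contract, assuming the module path `eynollah.sbb_binarize` and a placeholder model directory: `mode` now decides whether the single hybrid model or the four-model ensemble is loaded.

```python
from eynollah.sbb_binarize import SbbBinarizer

binarizer = SbbBinarizer("models_eynollah", mode="single")  # placeholder basedir
# mode="multi" would load binarization_multi_1 .. binarization_multi_4 instead,
# and mode must be one of the two values or __init__ raises ValueError.
binarizer.end_session()
```
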
@@ -393,7 +393,12 @@ def find_num_col_deskew(regions_without_separators, sigma_, multiplier=3.8):
    z = gaussian_filter1d(regions_without_separators_0, sigma_)
    return np.std(z)

def find_num_col(regions_without_separators, num_col_classifier, tables, multiplier=3.8):
def find_num_col(
        regions_without_separators,
        num_col_classifier,
        tables,
        multiplier=3.8,
):
    if not regions_without_separators.any():
        return 0, []
    #plt.imshow(regions_without_separators)

@@ -2,7 +2,7 @@
# pylint: disable=import-error
from pathlib import Path
import os.path
import xml.etree.ElementTree as ET
from typing import Optional
from .utils.xml import create_page_xml, xml_reading_order
from .utils.counter import EynollahIdCounter

@@ -10,7 +10,6 @@ from ocrd_utils import getLogger
from ocrd_models.ocrd_page import (
    BorderType,
    CoordsType,
    PcGtsType,
    TextLineType,
    TextEquivType,
    TextRegionType,

@@ -32,10 +31,10 @@ class EynollahXmlWriter:
        self.curved_line = curved_line
        self.textline_light = textline_light
        self.pcgts = pcgts
        self.scale_x = None  # XXX set outside __init__
        self.scale_y = None  # XXX set outside __init__
        self.height_org = None  # XXX set outside __init__
        self.width_org = None  # XXX set outside __init__
        self.scale_x: Optional[float] = None  # XXX set outside __init__
        self.scale_y: Optional[float] = None  # XXX set outside __init__
        self.height_org: Optional[int] = None  # XXX set outside __init__
        self.width_org: Optional[int] = None  # XXX set outside __init__

    @property
    def image_filename_stem(self):

@@ -135,6 +134,7 @@ class EynollahXmlWriter:
        # create the file structure
        pcgts = self.pcgts if self.pcgts else create_page_xml(self.image_filename, self.height_org, self.width_org)
        page = pcgts.get_Page()
        assert page
        page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_page_coords(cont_page))))

        counter = EynollahIdCounter()

@@ -152,6 +152,7 @@ class EynollahXmlWriter:
                Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord,
                                                                       skip_layout_reading_order))
            )
            assert textregion.Coords
            if conf_contours_textregions:
                textregion.Coords.set_conf(conf_contours_textregions[mm])
            page.add_TextRegion(textregion)

@@ -168,6 +169,7 @@ class EynollahXmlWriter:
                id=counter.next_region_id, type_='heading',
                Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord))
            )
            assert textregion.Coords
            if conf_contours_textregions_h:
                textregion.Coords.set_conf(conf_contours_textregions_h[mm])
            page.add_TextRegion(textregion)

@@ -16,10 +16,13 @@ from ocrd_models.constants import NAMESPACES as NS

testdir = Path(__file__).parent.resolve()

MODELS_LAYOUT = environ.get('MODELS_LAYOUT', str(testdir.joinpath('..', 'models_layout_v0_5_0').resolve()))
MODELS_OCR = environ.get('MODELS_OCR', str(testdir.joinpath('..', 'models_ocr_v0_5_1').resolve()))
MODELS_LAYOUT = environ.get('MODELS_LAYOUT', str(testdir.joinpath('..', 'models_layout_v0_6_0').resolve()))
MODELS_OCR = environ.get('MODELS_OCR', str(testdir.joinpath('..', 'models_ocr_v0_6_0').resolve()))
MODELS_BIN = environ.get('MODELS_BIN', str(testdir.joinpath('..', 'default-2021-03-09').resolve()))

def only_eynollah(logrec):
    return logrec.name.startswith('eynollah')

@pytest.mark.parametrize(
    "options",
    [

@@ -50,8 +53,6 @@ def test_run_eynollah_layout_filename(tmp_path, pytestconfig, caplog, options):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'eynollah'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(layout_cli, args + options, catch_exceptions=False)

@@ -85,8 +86,6 @@ def test_run_eynollah_layout_filename2(tmp_path, pytestconfig, caplog, options):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'eynollah'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(layout_cli, args + options, catch_exceptions=False)

@@ -116,8 +115,6 @@ def test_run_eynollah_layout_directory(tmp_path, pytestconfig, caplog):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'eynollah'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(layout_cli, args, catch_exceptions=False)

@@ -144,8 +141,6 @@ def test_run_eynollah_binarization_filename(tmp_path, pytestconfig, caplog, options):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'SbbBinarizer'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(binarization_cli, args + options, catch_exceptions=False)

@@ -170,8 +165,6 @@ def test_run_eynollah_binarization_directory(tmp_path, pytestconfig, caplog):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'SbbBinarizer'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(binarization_cli, args, catch_exceptions=False)

@@ -197,8 +190,6 @@ def test_run_eynollah_enhancement_filename(tmp_path, pytestconfig, caplog, options):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'enhancement'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(enhancement_cli, args + options, catch_exceptions=False)

@@ -223,8 +214,6 @@ def test_run_eynollah_enhancement_directory(tmp_path, pytestconfig, caplog):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'enhancement'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(enhancement_cli, args, catch_exceptions=False)

@@ -244,8 +233,6 @@ def test_run_eynollah_mbreorder_filename(tmp_path, pytestconfig, caplog):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'mbreorder'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(mbreorder_cli, args, catch_exceptions=False)

@@ -273,8 +260,6 @@ def test_run_eynollah_mbreorder_directory(tmp_path, pytestconfig, caplog):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'mbreorder'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(mbreorder_cli, args, catch_exceptions=False)

@@ -306,8 +291,6 @@ def test_run_eynollah_ocr_filename(tmp_path, pytestconfig, caplog, options):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.DEBUG)
    def only_eynollah(logrec):
        return logrec.name == 'eynollah'
    runner = CliRunner()
    if "-doit" in options:
        options.insert(options.index("-doit") + 1, str(outrenderfile.parent))

@@ -339,8 +322,6 @@ def test_run_eynollah_ocr_directory(tmp_path, pytestconfig, caplog):
    if pytestconfig.getoption('verbose') > 0:
        args.extend(['-l', 'DEBUG'])
    caplog.set_level(logging.INFO)
    def only_eynollah(logrec):
        return logrec.name == 'eynollah'
    runner = CliRunner()
    with caplog.filtering(only_eynollah):
        result = runner.invoke(ocr_cli, args, catch_exceptions=False)

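The tests now share one module-level `only_eynollah` predicate instead of redefining it per test. A minimal, self-contained sketch of the same `caplog.filtering` pattern (assuming pytest >= 7.3, where that context manager was introduced):

```python
import logging

def only_eynollah(logrec):
    return logrec.name.startswith('eynollah')

def test_log_filtering(caplog):
    caplog.set_level(logging.INFO)
    with caplog.filtering(only_eynollah):
        logging.getLogger('eynollah.model_zoo').info("kept")
        logging.getLogger('something.else').info("dropped")
    assert [r.getMessage() for r in caplog.records] == ["kept"]
```
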
@@ -22,14 +22,14 @@ Download our pretrained weights and add them to a `train/pretrained_model` folder:

```sh
cd train
wget -O pretrained_model.tar.gz https://zenodo.org/records/17243320/files/pretrained_model_v0_5_1.tar.gz?download=1
wget -O pretrained_model.tar.gz "https://zenodo.org/records/17295988/files/pretrained_model_v0_6_0.tar.gz?download=1"
tar xf pretrained_model.tar.gz
```

### Binarization training data

A small sample of training data for binarization experiment can be found [on
zenodo](https://zenodo.org/records/17243320/files/training_data_sample_binarization_v0_5_1.tar.gz?download=1),
zenodo](https://zenodo.org/records/17295988/files/training_data_sample_binarization_v0_6_0.tar.gz?download=1),
which contains `images` and `labels` folders.

### Helpful tools