Merge branch 'main' into ro-fixes and resolve conflicts…

major conflicts resolved manually:

- branches for non-`light` segmentation already removed in main
- Keras/TF setup and no TF1 sessions, esp. in new ModelZoo
- changes to binarizer and its CLI (`mode`, `overwrite`, `run_single()`)
- writer: `build...` w/ kwargs instead of positional
- training for segmentation/binarization/enhancement tasks:
  * drop unused `generate_data_from_folder()`
  * simplify `preprocess_imgs()`: turn `preprocess_img()`, `get_patches()`
    and `get_patches_num_scale_new()` into generators, only writing
    result files in the caller (top-level loop) instead of passing
    output directories and a file counter (see the generator sketch below)
- training for new OCR task:
  * `train`: put keys into additional `config_params` where they belong
    (conditioned on existing keys, respectively), and w/ better documentation
  * `train`: add new keys as kwargs to `run()` to make them usable
  * `utils`: instead of custom data loader `data_gen_ocr()`, re-use
    existing `preprocess_imgs()` (for cfg capture and top-level loop),
    but extended w/ new kwargs and calling new `preprocess_img_ocr()`;
    the latter as single-image generator (also much simplified)
  * `train`: use a tf.data loader pipeline from that generator w/ standard
    mechanisms for batching, shuffling, prefetching etc.
    (see the tf.data sketch below)
  * `utils` and `train`: instead of `vectorize_label`, use `Dataset.padded_batch`
  * add TensorBoard callback and re-use our checkpoint callback
  * also use standard Keras top-level loop for training
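
For illustration, a minimal sketch of the generator-based patch preprocessing
described above; the signatures and the patching scheme here are assumptions,
not the actual training code:

```python
from pathlib import Path
from typing import Iterator, Tuple
import numpy as np
import cv2

def get_patches(img: np.ndarray, lbl: np.ndarray,
                height: int, width: int) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    # yield aligned (image, label) patches instead of writing them to disk
    for y in range(0, img.shape[0] - height + 1, height):
        for x in range(0, img.shape[1] - width + 1, width):
            yield img[y:y + height, x:x + width], lbl[y:y + height, x:x + width]

def preprocess_imgs(img_paths, lbl_paths, out_dir: Path, height: int, width: int):
    # top-level loop: the only place that writes result files, so the
    # generators need no output directories or file counter
    counter = 0
    for img_path, lbl_path in zip(img_paths, lbl_paths):
        img = cv2.imread(str(img_path))
        lbl = cv2.imread(str(lbl_path))
        for patch_img, patch_lbl in get_patches(img, lbl, height, width):
            cv2.imwrite(str(out_dir / f"img_{counter}.png"), patch_img)
            cv2.imwrite(str(out_dir / f"lbl_{counter}.png"), patch_lbl)
            counter += 1
```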

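Likewise, a hedged sketch of the tf.data pipeline with `Dataset.padded_batch`
(the stand-in generator, the shapes, and `model` are assumptions):

```python
import numpy as np
import tensorflow as tf

def sample_gen():
    # stand-in for the single-image generator described above; the real
    # code would iterate preprocess_img_ocr() over the training images
    for _ in range(100):
        width = int(np.random.randint(50, 200))
        image = np.random.rand(32, width, 3).astype(np.float32)
        label = np.random.randint(1, 90, size=int(np.random.randint(5, 20))).astype(np.int64)
        yield image, label

dataset = (
    tf.data.Dataset.from_generator(
        sample_gen,
        output_signature=(
            tf.TensorSpec(shape=(32, None, 3), dtype=tf.float32),   # line image
            tf.TensorSpec(shape=(None,), dtype=tf.int64)))          # label codes
    .shuffle(1000)
    # Dataset.padded_batch replaces a custom vectorize_label: label
    # sequences are padded to the longest one in each batch
    .padded_batch(8, padding_values=(0.0, tf.constant(0, tf.int64)))
    .prefetch(tf.data.AUTOTUNE))

# standard Keras top-level loop with TensorBoard and checkpoint callbacks
# (`model` is assumed to be a compiled tf.keras.Model)
callbacks = [tf.keras.callbacks.TensorBoard(log_dir="logs"),
             tf.keras.callbacks.ModelCheckpoint("ckpt_{epoch}.keras")]
# model.fit(dataset, epochs=10, callbacks=callbacks)
```
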
still problematic (substantially unresolved):
- `Patches` now only w/ fixed implicit size
  (ignoring training config params)
- `PatchEncoder` now only w/ fixed implicit num patches and projection dim
  (ignoring training config params; see the config-driven layer sketch below)
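
For reference, a sketch of how these layers could take their geometry from the
training config instead (standard Keras ViT-style layers; the exact parameter
plumbing is an assumption):

```python
import tensorflow as tf

class Patches(tf.keras.layers.Layer):
    # patch size comes from the training config instead of a fixed constant
    def __init__(self, patch_size: int, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID")
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

class PatchEncoder(tf.keras.layers.Layer):
    # num_patches and projection_dim likewise taken from the config
    def __init__(self, num_patches: int, projection_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection = tf.keras.layers.Dense(projection_dim)
        self.position_embedding = tf.keras.layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim)

    def call(self, patches):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return self.projection(patches) + self.position_embedding(positions)
```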
Robert Sachunsky, 2026-02-07 14:05:56 +01:00
commit 27f43c175f
77 changed files with 5597 additions and 4952 deletions

@@ -1,589 +0,0 @@
import sys
import click
import logging
from ocrd_utils import initLogging, getLevelName, getLogger
from eynollah.eynollah import Eynollah, Eynollah_ocr
from eynollah.sbb_binarize import SbbBinarizer
from eynollah.image_enhancer import Enhancer
from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout
@click.group()
def main():
pass
@main.command()
@click.option(
"--input",
"-i",
help="PAGE-XML input filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--dir_in",
"-di",
help="directory of PAGE-XML input files (instead of --input)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--out",
"-o",
help="directory for output images",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--model",
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this",
)
def machine_based_reading_order(input, dir_in, out, model, log_level):
assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
orderer = machine_based_reading_order_on_layout(model)
if log_level:
orderer.logger.setLevel(getLevelName(log_level))
orderer.run(xml_filename=input,
dir_in=dir_in,
dir_out=out,
)
@main.command()
@click.option('--patches/--no-patches', default=True, help='if enabled, the model processes the image in patches.')
@click.option('--model_dir', '-m', type=click.Path(exists=True, file_okay=False), required=True, help='directory containing models for prediction')
@click.option(
"--input-image", "--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False)
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--output",
"-o",
help="output image (if using -i) or output image directory (if using -di)",
type=click.Path(file_okay=True, dir_okay=True),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this",
)
def binarization(patches, model_dir, input_image, dir_in, output, overwrite, log_level):
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
binarizer = SbbBinarizer(model_dir)
if log_level:
binarizer.logger.setLevel(getLevelName(log_level))
binarizer.run(overwrite=overwrite,
use_patches=patches,
image_path=input_image,
output=output,
dir_in=dir_in)
@main.command()
@click.option(
"--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--out",
"-o",
help="directory for output PAGE-XML files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--model",
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--num_col_upper",
"-ncu",
help="lower limit of columns in document image",
)
@click.option(
"--num_col_lower",
"-ncl",
help="upper limit of columns in document image",
)
@click.option(
"--save_org_scale/--no_save_org_scale",
"-sos/-nosos",
is_flag=True,
help="if this parameter set to true, this tool will save the enhanced image in org scale.",
)
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this",
)
def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_lower, save_org_scale, log_level):
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
initLogging()
enhancer = Enhancer(
model,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
save_org_scale=save_org_scale,
)
if log_level:
enhancer.logger.setLevel(getLevelName(log_level))
enhancer.run(overwrite=overwrite,
dir_in=dir_in,
image_filename=image,
dir_out=out,
)
@main.command()
@click.option(
"--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--out",
"-o",
help="directory for output PAGE-XML files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--model",
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--model_version",
"-mv",
help="override default versions of model categories",
type=(str, str),
multiple=True,
)
@click.option(
"--save_images",
"-si",
help="if a directory is given, images in documents will be cropped and saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_layout",
"-sl",
help="if a directory is given, plot of layout will be saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_deskewed",
"-sd",
help="if a directory is given, deskewed image will be saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_all",
"-sa",
help="if a directory is given, all plots needed for documentation will be saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_page",
"-sp",
help="if a directory is given, page crop of image will be saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--enable-plotting/--disable-plotting",
"-ep/-noep",
is_flag=True,
help="If set, will plot intermediary files and images",
)
@click.option(
"--extract_only_images/--disable-extracting_only_images",
"-eoi/-noeoi",
is_flag=True,
help="If a directory is given, only images in documents will be cropped and saved there and the other processing will not be done",
)
@click.option(
"--allow-enhancement/--no-allow-enhancement",
"-ae/-noae",
is_flag=True,
help="if this parameter set to true, this tool would check that input image need resizing and enhancement or not. If so output of resized and enhanced image and corresponding layout data will be written in out directory",
)
@click.option(
"--curved-line/--no-curvedline",
"-cl/-nocl",
is_flag=True,
help="if this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline. This should be taken into account that with this option the tool need more time to do process.",
)
@click.option(
"--textline_light/--no-textline_light",
"-tll/-notll",
is_flag=True,
help="if this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline with a faster method.",
)
@click.option(
"--full-layout/--no-full-layout",
"-fl/-nofl",
is_flag=True,
help="if this parameter set to true, this tool will try to return all elements of layout.",
)
@click.option(
"--tables/--no-tables",
"-tab/-notab",
is_flag=True,
help="if this parameter set to true, this tool will try to detect tables.",
)
@click.option(
"--right2left/--left2right",
"-r2l/-l2r",
is_flag=True,
help="if this parameter set to true, this tool will extract right-to-left reading order.",
)
@click.option(
"--input_binary/--input-RGB",
"-ib/-irgb",
is_flag=True,
help="in general, eynollah uses RGB as input but if the input document is strongly dark, bright or for any other reason you can turn binarized input on. This option does not mean that you have to provide a binary image, otherwise this means that the tool itself will binarized the RGB input document.",
)
@click.option(
"--allow_scaling/--no-allow-scaling",
"-as/-noas",
is_flag=True,
help="if this parameter set to true, this tool would check the scale and if needed it will scale it to perform better layout detection",
)
@click.option(
"--headers_off/--headers-on",
"-ho/-noho",
is_flag=True,
help="if this parameter set to true, this tool would ignore headers role in reading order",
)
@click.option(
"--light_version/--original",
"-light/-org",
is_flag=True,
help="if this parameter set to true, this tool would use lighter version",
)
@click.option(
"--ignore_page_extraction/--extract_page_included",
"-ipe/-epi",
is_flag=True,
help="if this parameter set to true, this tool would ignore page extraction",
)
@click.option(
"--reading_order_machine_based/--heuristic_reading_order",
"-romb/-hro",
is_flag=True,
help="if this parameter set to true, this tool would apply machine based reading order detection",
)
@click.option(
"--do_ocr",
"-ocr/-noocr",
is_flag=True,
help="if this parameter set to true, this tool will try to do ocr",
)
@click.option(
"--transformer_ocr",
"-tr/-notr",
is_flag=True,
help="if this parameter set to true, this tool will apply transformer ocr",
)
@click.option(
"--batch_size_ocr",
"-bs_ocr",
help="number of inference batch size of ocr model. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
)
@click.option(
"--num_col_upper",
"-ncu",
help="lower limit of columns in document image",
)
@click.option(
"--num_col_lower",
"-ncl",
help="upper limit of columns in document image",
)
@click.option(
"--threshold_art_class_layout",
"-tharl",
help="threshold of artifical class in the case of layout detection. The default value is 0.1",
)
@click.option(
"--threshold_art_class_textline",
"-thart",
help="threshold of artifical class in the case of textline detection. The default value is 0.1",
)
@click.option(
"--skip_layout_and_reading_order",
"-slro/-noslro",
is_flag=True,
help="if this parameter set to true, this tool will ignore layout detection and reading order. It means that textline detection will be done within printspace and contours of textline will be written in xml output file.",
)
# TODO move to top-level CLI context
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override 'eynollah' log level globally to this",
)
#
@click.option(
"--setup-logging",
is_flag=True,
help="Setup a basic console logger",
)
def layout(image, out, overwrite, dir_in, model, model_version, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
if setup_logging:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
console_handler.setFormatter(formatter)
getLogger('eynollah').addHandler(console_handler)
getLogger('eynollah').setLevel(logging.INFO)
else:
initLogging()
assert enable_plotting or not save_layout, "Plotting with -sl also requires -ep"
assert enable_plotting or not save_deskewed, "Plotting with -sd also requires -ep"
assert enable_plotting or not save_all, "Plotting with -sa also requires -ep"
assert enable_plotting or not save_page, "Plotting with -sp also requires -ep"
assert enable_plotting or not save_images, "Plotting with -si also requires -ep"
assert enable_plotting or not allow_enhancement, "Plotting with -ae also requires -ep"
assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \
"Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae"
assert textline_light == light_version, "Both light textline detection -tll and light version -light must be set or unset equally"
assert not extract_only_images or not allow_enhancement, "Image extraction -eoi can not be set alongside allow_enhancement -ae"
assert not extract_only_images or not allow_scaling, "Image extraction -eoi can not be set alongside allow_scaling -as"
assert not extract_only_images or not light_version, "Image extraction -eoi can not be set alongside light_version -light"
assert not extract_only_images or not curved_line, "Image extraction -eoi can not be set alongside curved_line -cl"
assert not extract_only_images or not textline_light, "Image extraction -eoi can not be set alongside textline_light -tll"
assert not extract_only_images or not full_layout, "Image extraction -eoi can not be set alongside full_layout -fl"
assert not extract_only_images or not tables, "Image extraction -eoi can not be set alongside tables -tab"
assert not extract_only_images or not right2left, "Image extraction -eoi can not be set alongside right2left -r2l"
assert not extract_only_images or not headers_off, "Image extraction -eoi can not be set alongside headers_off -ho"
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
eynollah = Eynollah(
model,
model_versions=model_version,
extract_only_images=extract_only_images,
enable_plotting=enable_plotting,
allow_enhancement=allow_enhancement,
curved_line=curved_line,
textline_light=textline_light,
full_layout=full_layout,
tables=tables,
right2left=right2left,
input_binary=input_binary,
allow_scaling=allow_scaling,
headers_off=headers_off,
light_version=light_version,
ignore_page_extraction=ignore_page_extraction,
reading_order_machine_based=reading_order_machine_based,
do_ocr=do_ocr,
transformer_ocr=transformer_ocr,
batch_size_ocr=batch_size_ocr,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
skip_layout_and_reading_order=skip_layout_and_reading_order,
threshold_art_class_textline=threshold_art_class_textline,
threshold_art_class_layout=threshold_art_class_layout,
)
if log_level:
eynollah.logger.setLevel(getLevelName(log_level))
eynollah.run(overwrite=overwrite,
image_filename=image,
dir_in=dir_in,
dir_out=out,
dir_of_cropped_images=save_images,
dir_of_layout=save_layout,
dir_of_deskewed=save_deskewed,
dir_of_all=save_all,
dir_save_page=save_page,
)
@main.command()
@click.option(
"--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--dir_in_bin",
"-dib",
help="directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png' suffix).\nPerform prediction using both RGB and binary images. (This does not necessarily improve results, however it may be beneficial for certain document images.)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--dir_xmls",
"-dx",
help="directory of input PAGE-XML files (in addition to --dir_in; filename stems must match the image files, with '.xml' suffix).",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--out",
"-o",
help="directory for output PAGE-XML files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--dir_out_image_text",
"-doit",
help="directory for output images, newly rendered with predicted text",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--model",
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--model_name",
help="Specific model file path to use for OCR",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--tr_ocr",
"-trocr/-notrocr",
is_flag=True,
help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.",
)
@click.option(
"--export_textline_images_and_text",
"-etit/-noetit",
is_flag=True,
help="if this parameter set to true, images and text in xml will be exported into output dir. This files can be used for training a OCR engine.",
)
@click.option(
"--do_not_mask_with_textline_contour",
"-nmtc/-mtc",
is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
)
@click.option(
"--batch_size",
"-bs",
help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
)
@click.option(
"--dataset_abbrevation",
"-ds_pref",
help="in the case of extracting textline and text from a xml GT file user can add an abbrevation of dataset name to generated dataset",
)
@click.option(
"--min_conf_value_of_textline_text",
"-min_conf",
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
)
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this",
)
def ocr(image, dir_in, dir_in_bin, dir_xmls, out, dir_out_image_text, overwrite, model, model_name, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, batch_size, dataset_abbrevation, min_conf_value_of_textline_text, log_level):
initLogging()
assert bool(model) != bool(model_name), "Either -m (model directory) or --model_name (specific model name) must be provided."
assert not export_textline_images_and_text or not tr_ocr, "Exporting textline and text -etit can not be set alongside transformer ocr -tr_ocr"
assert not export_textline_images_and_text or not model, "Exporting textline and text -etit can not be set alongside model -m"
assert not export_textline_images_and_text or not batch_size, "Exporting textline and text -etit can not be set alongside batch size -bs"
assert not export_textline_images_and_text or not dir_in_bin, "Exporting textline and text -etit can not be set alongside directory of bin images -dib"
assert not export_textline_images_and_text or not dir_out_image_text, "Exporting textline and text -etit can not be set alongside directory of images with predicted text -doit"
assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
eynollah_ocr = Eynollah_ocr(
dir_models=model,
model_name=model_name,
tr_ocr=tr_ocr,
export_textline_images_and_text=export_textline_images_and_text,
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
batch_size=batch_size,
pref_of_dataset=dataset_abbrevation,
min_conf_value_of_textline_text=min_conf_value_of_textline_text,
)
if log_level:
eynollah_ocr.logger.setLevel(getLevelName(log_level))
eynollah_ocr.run(overwrite=overwrite,
dir_in=dir_in,
dir_in_bin=dir_in_bin,
image_filename=image,
dir_xmls=dir_xmls,
dir_out_image_text=dir_out_image_text,
dir_out=out,
)
if __name__ == "__main__":
main()

@@ -0,0 +1,22 @@
# NOTE: For predictable order of imports of torch/shapely/tensorflow
# this must be the first import of the CLI!
from ..eynollah_imports import imported_libs
from .cli_models import models_cli
from .cli import main
from .cli_binarize import binarize_cli
from .cli_enhance import enhance_cli
from .cli_extract_images import extract_images_cli
from .cli_layout import layout_cli
from .cli_ocr import ocr_cli
from .cli_readingorder import readingorder_cli
main.add_command(binarize_cli, 'binarization')
main.add_command(enhance_cli, 'enhancement')
main.add_command(layout_cli, 'layout')
main.add_command(readingorder_cli, 'machine-based-reading-order')
main.add_command(models_cli, 'models')
main.add_command(ocr_cli, 'ocr')
main.add_command(extract_images_cli, 'extract-images')

src/eynollah/cli/cli.py (new file)
@@ -0,0 +1,66 @@
from dataclasses import dataclass
import logging
import sys
import os
from typing import Union
import click
from ..model_zoo import EynollahModelZoo
from .cli_models import models_cli
@dataclass
class EynollahCliCtx:
    """
    Holds options relevant for all eynollah subcommands
    """
    model_zoo: EynollahModelZoo
    log_level: Union[str, None] = 'INFO'
@click.group()
@click.option(
"--model-basedir",
"-m",
help="directory of models",
# NOTE: not mandatory to exist so --help for subcommands works but will log a warning
# and raise exception when trying to load models in the CLI
# type=click.Path(exists=True),
default=f'{os.getcwd()}/models_eynollah',
)
@click.option(
"--model-overrides",
"-mv",
help="override default versions of model categories, syntax is 'CATEGORY VARIANT PATH', e.g 'region light /path/to/model'. See eynollah list-models for the full list",
type=(str, str, str),
multiple=True,
)
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this",
)
@click.pass_context
def main(ctx, model_basedir, model_overrides, log_level):
"""
eynollah - Document Layout Analysis, Image Enhancement, OCR
"""
# Initialize logging
console_handler = logging.StreamHandler(sys.stderr)
console_handler.setLevel(logging.NOTSET)
formatter = logging.Formatter('%(asctime)s.%(msecs)03d %(levelname)s %(name)s - %(message)s', datefmt='%H:%M:%S')
console_handler.setFormatter(formatter)
logging.getLogger('eynollah').addHandler(console_handler)
logging.getLogger('eynollah').setLevel(log_level or logging.INFO)
# Initialize model zoo
model_zoo = EynollahModelZoo(basedir=model_basedir, model_overrides=model_overrides)
# Initialize CLI context
ctx.obj = EynollahCliCtx(
model_zoo=model_zoo,
log_level=log_level,
)
if __name__ == "__main__":
main()

@@ -0,0 +1,52 @@
import click
@click.command()
@click.option('--patches/--no-patches', default=True, help='if enabled, the model processes the image in patches.')
@click.option(
"--input-image", "--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False)
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--output",
"-o",
help="output image (if using -i) or output image directory (if using -di)",
type=click.Path(file_okay=True, dir_okay=True),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.pass_context
def binarize_cli(
ctx,
patches,
input_image,
dir_in,
output,
overwrite,
):
"""
Binarize images with an ML model
"""
from ..sbb_binarize import SbbBinarizer
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo)
binarizer.run(
image_path=input_image,
use_patches=patches,
output=output,
dir_in=dir_in,
overwrite=overwrite
)

@@ -0,0 +1,63 @@
import click
@click.command()
@click.option(
"--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--out",
"-o",
help="directory for output PAGE-XML files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--num_col_upper",
"-ncu",
help="lower limit of columns in document image",
)
@click.option(
"--num_col_lower",
"-ncl",
help="upper limit of columns in document image",
)
@click.option(
"--save_org_scale/--no_save_org_scale",
"-sos/-nosos",
is_flag=True,
help="if this parameter set to true, this tool will save the enhanced image in org scale.",
)
@click.pass_context
def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale):
"""
Enhance image
"""
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
from ..image_enhancer import Enhancer
enhancer = Enhancer(
model_zoo=ctx.obj.model_zoo,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
save_org_scale=save_org_scale,
)
enhancer.run(overwrite=overwrite,
dir_in=dir_in,
image_filename=image,
dir_out=out,
)

@@ -0,0 +1,100 @@
import click
@click.command()
@click.option(
"--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--out",
"-o",
help="directory for output PAGE-XML files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_images",
"-si",
help="if a directory is given, images in documents will be cropped and saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--enable-plotting/--disable-plotting",
"-ep/-noep",
is_flag=True,
help="If set, will plot intermediary files and images",
)
@click.option(
"--input_binary/--input-RGB",
"-ib/-irgb",
is_flag=True,
help="In general, eynollah uses RGB as input but if the input document is very dark, very bright or for any other reason you can turn on input binarization. When this flag is set, eynollah will binarize the RGB input document, you should always provide RGB images to eynollah.",
)
@click.option(
"--ignore_page_extraction/--extract_page_included",
"-ipe/-epi",
is_flag=True,
help="if this parameter set to true, this tool would ignore page extraction",
)
@click.option(
"--num_col_upper",
"-ncu",
help="lower limit of columns in document image",
)
@click.option(
"--num_col_lower",
"-ncl",
help="upper limit of columns in document image",
)
@click.pass_context
def extract_images_cli(
ctx,
image,
out,
overwrite,
dir_in,
save_images,
enable_plotting,
input_binary,
num_col_upper,
num_col_lower,
ignore_page_extraction,
):
"""
Extract images from the document (skipping other layout processing)
"""
assert enable_plotting or not save_images, "Plotting with -si also requires -ep"
assert not enable_plotting or save_images, "Plotting with -ep also requires -si"
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
from ..extract_images import EynollahImageExtractor
extractor = EynollahImageExtractor(
model_zoo=ctx.obj.model_zoo,
enable_plotting=enable_plotting,
input_binary=input_binary,
ignore_page_extraction=ignore_page_extraction,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
)
extractor.run(overwrite=overwrite,
image_filename=image,
dir_in=dir_in,
dir_out=out,
dir_of_cropped_images=save_images,
)

@@ -0,0 +1,223 @@
import click
@click.command()
@click.option(
"--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--out",
"-o",
help="directory for output PAGE-XML files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_images",
"-si",
help="if a directory is given, images in documents will be cropped and saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_layout",
"-sl",
help="if a directory is given, plot of layout will be saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_deskewed",
"-sd",
help="if a directory is given, deskewed image will be saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_all",
"-sa",
help="if a directory is given, all plots needed for documentation will be saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--save_page",
"-sp",
help="if a directory is given, page crop of image will be saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--enable-plotting/--disable-plotting",
"-ep/-noep",
is_flag=True,
help="If set, will plot intermediary files and images",
)
@click.option(
"--allow-enhancement/--no-allow-enhancement",
"-ae/-noae",
is_flag=True,
help="if this parameter set to true, this tool would check that input image need resizing and enhancement or not. If so output of resized and enhanced image and corresponding layout data will be written in out directory",
)
@click.option(
"--curved-line/--no-curvedline",
"-cl/-nocl",
is_flag=True,
help="if this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline. This should be taken into account that with this option the tool need more time to do process.",
)
@click.option(
"--full-layout/--no-full-layout",
"-fl/-nofl",
is_flag=True,
help="if this parameter set to true, this tool will try to return all elements of layout.",
)
@click.option(
"--tables/--no-tables",
"-tab/-notab",
is_flag=True,
help="if this parameter set to true, this tool will try to detect tables.",
)
@click.option(
"--right2left/--left2right",
"-r2l/-l2r",
is_flag=True,
help="if this parameter set to true, this tool will extract right-to-left reading order.",
)
@click.option(
"--input_binary/--input-RGB",
"-ib/-irgb",
is_flag=True,
help="In general, eynollah uses RGB as input but if the input document is very dark, very bright or for any other reason you can turn on input binarization. When this flag is set, eynollah will binarize the RGB input document, you should always provide RGB images to eynollah.",
)
@click.option(
"--allow_scaling/--no-allow-scaling",
"-as/-noas",
is_flag=True,
help="if this parameter set to true, this tool would check the scale and if needed it will scale it to perform better layout detection",
)
@click.option(
"--headers_off/--headers-on",
"-ho/-noho",
is_flag=True,
help="if this parameter set to true, this tool would ignore headers role in reading order",
)
@click.option(
"--ignore_page_extraction/--extract_page_included",
"-ipe/-epi",
is_flag=True,
help="if this parameter set to true, this tool would ignore page extraction",
)
@click.option(
"--reading_order_machine_based/--heuristic_reading_order",
"-romb/-hro",
is_flag=True,
help="if this parameter set to true, this tool would apply machine based reading order detection",
)
@click.option(
"--num_col_upper",
"-ncu",
help="lower limit of columns in document image",
)
@click.option(
"--num_col_lower",
"-ncl",
help="upper limit of columns in document image",
)
@click.option(
"--threshold_art_class_layout",
"-tharl",
help="threshold of artifical class in the case of layout detection. The default value is 0.1",
)
@click.option(
"--threshold_art_class_textline",
"-thart",
help="threshold of artifical class in the case of textline detection. The default value is 0.1",
)
@click.option(
"--skip_layout_and_reading_order",
"-slro/-noslro",
is_flag=True,
help="if this parameter set to true, this tool will ignore layout detection and reading order. It means that textline detection will be done within printspace and contours of textline will be written in xml output file.",
)
@click.pass_context
def layout_cli(
ctx,
image,
out,
overwrite,
dir_in,
save_images,
save_layout,
save_deskewed,
save_all,
save_page,
enable_plotting,
allow_enhancement,
curved_line,
full_layout,
tables,
right2left,
input_binary,
allow_scaling,
headers_off,
reading_order_machine_based,
num_col_upper,
num_col_lower,
threshold_art_class_textline,
threshold_art_class_layout,
skip_layout_and_reading_order,
ignore_page_extraction,
):
"""
Detect Layout (with optional image enhancement and reading order detection)
"""
from ..eynollah import Eynollah
assert enable_plotting or not save_layout, "Plotting with -sl also requires -ep"
assert enable_plotting or not save_deskewed, "Plotting with -sd also requires -ep"
assert enable_plotting or not save_all, "Plotting with -sa also requires -ep"
assert enable_plotting or not save_page, "Plotting with -sp also requires -ep"
assert enable_plotting or not save_images, "Plotting with -si also requires -ep"
assert enable_plotting or not allow_enhancement, "Plotting with -ae also requires -ep"
assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \
"Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae"
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
eynollah = Eynollah(
model_zoo=ctx.obj.model_zoo,
enable_plotting=enable_plotting,
allow_enhancement=allow_enhancement,
curved_line=curved_line,
full_layout=full_layout,
tables=tables,
right2left=right2left,
input_binary=input_binary,
allow_scaling=allow_scaling,
headers_off=headers_off,
ignore_page_extraction=ignore_page_extraction,
reading_order_machine_based=reading_order_machine_based,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
skip_layout_and_reading_order=skip_layout_and_reading_order,
threshold_art_class_textline=threshold_art_class_textline,
threshold_art_class_layout=threshold_art_class_layout,
)
eynollah.run(overwrite=overwrite,
image_filename=image,
dir_in=dir_in,
dir_out=out,
dir_of_cropped_images=save_images,
dir_of_layout=save_layout,
dir_of_deskewed=save_deskewed,
dir_of_all=save_all,
dir_save_page=save_page,
)

@@ -0,0 +1,69 @@
from pathlib import Path
from typing import Set, Tuple
import click
from eynollah.model_zoo.default_specs import MODELS_VERSION
@click.group()
@click.pass_context
def models_cli(
ctx,
):
"""
Organize models for the various runners in eynollah.
"""
assert ctx.obj.model_zoo
@models_cli.command('list')
@click.pass_context
def list_models(
ctx,
):
"""
List all the models in the zoo
"""
print(f"Model basedir: {ctx.obj.model_zoo.model_basedir}")
print(f"Model overrides: {ctx.obj.model_zoo.model_overrides}")
print(ctx.obj.model_zoo)
@models_cli.command('package')
@click.option(
'--set-version', '-V', 'version', help="Version to use for packaging", default=MODELS_VERSION, show_default=True
)
@click.argument('output_dir')
@click.pass_context
def package(
ctx,
version,
output_dir,
):
"""
Generate shell code to copy all the models in the zoo into properly named folders in OUTPUT_DIR for distribution.
eynollah models -m SRC package OUTPUT_DIR
SRC should contain a directory "models_eynollah" containing all the models.
"""
mkdirs: Set[Path] = set()
copies: Set[Tuple[Path, Path]] = set()
for spec in ctx.obj.model_zoo.specs.specs:
# skip these as they are dependent on the ocr model
if spec.category in ('num_to_char', 'characters'):
continue
src: Path = ctx.obj.model_zoo.model_path(spec.category, spec.variant)
# Only copy the top-most directory relative to models_eynollah
while src.parent.name != 'models_eynollah':
src = src.parent
for dist in spec.dists:
dist_dir = Path(f"{output_dir}/models_{dist}_{version}/models_eynollah")
copies.add((src, dist_dir))
mkdirs.add(dist_dir)
for dist_dir in mkdirs:
    print(f"mkdir -vp {dist_dir}")
for (src, dst) in copies:
    print(f"cp -vr {src} {dst}")
for dist_dir in mkdirs:
    zip_path = Path(f'../{dist_dir.parent.name}.zip')
    print(f"(cd {dist_dir}/..; zip -vr {zip_path} models_eynollah)")

src/eynollah/cli/cli_ocr.py (new file)
@@ -0,0 +1,103 @@
import click
@click.command()
@click.option(
"--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--dir_in_bin",
"-dib",
help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png' \n Perform prediction using both RGB and binary images. (This does not necessarily improve results, however it may be beneficial for certain document images."),
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--dir_xmls",
"-dx",
help="directory of input PAGE-XML files (in addition to --dir_in; filename stems must match the image files, with '.xml' suffix).",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--out",
"-o",
help="directory for output PAGE-XML files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--dir_out_image_text",
"-doit",
help="directory for output images, newly rendered with predicted text",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--tr_ocr",
"-trocr/-notrocr",
is_flag=True,
help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.",
)
@click.option(
"--do_not_mask_with_textline_contour",
"-nmtc/-mtc",
is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
)
@click.option(
"--batch_size",
"-bs",
help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
)
@click.option(
"--min_conf_value_of_textline_text",
"-min_conf",
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
)
@click.pass_context
def ocr_cli(
ctx,
image,
dir_in,
dir_in_bin,
dir_xmls,
out,
dir_out_image_text,
overwrite,
tr_ocr,
do_not_mask_with_textline_contour,
batch_size,
min_conf_value_of_textline_text,
):
"""
Recognize text with a CNN/RNN or transformer ML model.
"""
assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
from ..eynollah_ocr import Eynollah_ocr
eynollah_ocr = Eynollah_ocr(
model_zoo=ctx.obj.model_zoo,
tr_ocr=tr_ocr,
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
batch_size=batch_size,
min_conf_value_of_textline_text=min_conf_value_of_textline_text)
eynollah_ocr.run(overwrite=overwrite,
dir_in=dir_in,
dir_in_bin=dir_in_bin,
image_filename=image,
dir_xmls=dir_xmls,
dir_out_image_text=dir_out_image_text,
dir_out=out,
)

@@ -0,0 +1,35 @@
import click
@click.command()
@click.option(
"--input",
"-i",
help="PAGE-XML input filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--dir_in",
"-di",
help="directory of PAGE-XML input files (instead of --input)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--out",
"-o",
help="directory for output images",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.pass_context
def readingorder_cli(ctx, input, dir_in, out):
"""
Generate ReadingOrder with an ML model
"""
from ..mb_ro_on_layout import machine_based_reading_order_on_layout
assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
orderer = machine_based_reading_order_on_layout(model_zoo=ctx.obj.model_zoo)
orderer.run(xml_filename=input,
dir_in=dir_in,
dir_out=out,
)

@@ -0,0 +1,281 @@
"""
Extract images from documents.
"""
from concurrent.futures import ProcessPoolExecutor
import logging
from multiprocessing import cpu_count
import os
import time
from typing import Optional
from pathlib import Path
import tensorflow as tf
import numpy as np
import cv2
from eynollah.utils.contour import filter_contours_area_of_image, return_contours_of_image, return_contours_of_interested_region
from eynollah.utils.resize import resize_image
from .model_zoo.model_zoo import EynollahModelZoo
from .eynollah import Eynollah
from .utils import box2rect, is_image_filename
from .plot import EynollahPlotter
class EynollahImageExtractor(Eynollah):
def __init__(
self,
*,
model_zoo: EynollahModelZoo,
enable_plotting : bool = False,
input_binary : bool = False,
ignore_page_extraction : bool = False,
num_col_upper : Optional[int] = None,
num_col_lower : Optional[int] = None,
full_layout : bool = False,
tables : bool = False,
curved_line : bool = False,
allow_enhancement : bool = False,
):
self.logger = logging.getLogger('eynollah.extract_images')
self.model_zoo = model_zoo
self.plotter = None
self.tables = tables
self.curved_line = curved_line
self.allow_enhancement = allow_enhancement
self.enable_plotting = enable_plotting
# --input-binary sensible if image is very dark, if layout is not working.
self.input_binary = input_binary
self.ignore_page_extraction = ignore_page_extraction
self.full_layout = full_layout
if num_col_upper:
self.num_col_upper = int(num_col_upper)
else:
self.num_col_upper = num_col_upper
if num_col_lower:
self.num_col_lower = int(num_col_lower)
else:
self.num_col_lower = num_col_lower
# for parallelization of CPU-intensive tasks:
self.executor = ProcessPoolExecutor(max_workers=cpu_count())
t_start = time.time()
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except Exception:
self.logger.warning("no GPU device available")
self.logger.info("Loading models...")
self.setup_models()
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
def setup_models(self):
loadable = [
"col_classifier",
"binarization",
"page",
"extract_images",
]
self.model_zoo.load_models(*loadable)
def get_regions_light_v_extract_only_images(self, img, num_col_classifier):
self.logger.debug("enter get_regions_extract_images_only")
erosion_hurts = False
img_org = np.copy(img)
img_height_h = img_org.shape[0]
img_width_h = img_org.shape[1]
if num_col_classifier == 1:
img_w_new = 700
elif num_col_classifier == 2:
img_w_new = 900
elif num_col_classifier == 3:
img_w_new = 1500
elif num_col_classifier == 4:
img_w_new = 1800
elif num_col_classifier == 5:
img_w_new = 2200
elif num_col_classifier == 6:
img_w_new = 2500
else:
raise ValueError("num_col_classifier must be in range 1..6")
img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new)
img_resized = resize_image(img,img_h_new, img_w_new )
prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.model_zoo.get("extract_images"))
prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h )
image_page, page_coord, cont_page = self.extract_page()
prediction_regions_org = prediction_regions_org[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
prediction_regions_org=prediction_regions_org[:,:,0]
mask_seps_only = (prediction_regions_org[:,:] ==3)*1
mask_texts_only = (prediction_regions_org[:,:] ==1)*1
mask_images_only=(prediction_regions_org[:,:] ==2)*1
polygons_seplines, hir_seplines = return_contours_of_image(mask_seps_only)
polygons_seplines = filter_contours_area_of_image(
mask_seps_only, polygons_seplines, hir_seplines, max_area=1, min_area=0.00001, dilate=1)
polygons_of_only_texts = return_contours_of_interested_region(mask_texts_only,1,0.00001)
polygons_of_only_seps = return_contours_of_interested_region(mask_seps_only,1,0.00001)
text_regions_p_true = np.zeros(prediction_regions_org.shape)
text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_seps, color=(3,3,3))
text_regions_p_true[:,:][mask_images_only[:,:] == 1] = 2
text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts=polygons_of_only_texts, color=(1,1,1))
text_regions_p_true[text_regions_p_true.shape[0]-15:text_regions_p_true.shape[0], :] = 0
text_regions_p_true[:, text_regions_p_true.shape[1]-15:text_regions_p_true.shape[1]] = 0
##polygons_of_images = return_contours_of_interested_region(text_regions_p_true, 2, 0.0001)
polygons_of_images = return_contours_of_interested_region(text_regions_p_true, 2, 0.001)
polygons_of_images_fin = []
for poly_img_ind in polygons_of_images:
    box = _, _, w, h = cv2.boundingRect(poly_img_ind)
    if h < 150 or w < 150:
        continue
    page_coord_img = box2rect(box)  # type: ignore
    polygons_of_images_fin.append(np.array([[page_coord_img[2], page_coord_img[0]],
                                            [page_coord_img[3], page_coord_img[0]],
                                            [page_coord_img[3], page_coord_img[1]],
                                            [page_coord_img[2], page_coord_img[1]]]))
self.logger.debug("exit get_regions_extract_images_only")
return (text_regions_p_true,
erosion_hurts,
polygons_seplines,
polygons_of_images_fin,
image_page,
page_coord,
cont_page)
def run(self,
overwrite: bool = False,
image_filename: Optional[str] = None,
dir_in: Optional[str] = None,
dir_out: Optional[str] = None,
dir_of_cropped_images: Optional[str] = None,
dir_of_layout: Optional[str] = None,
dir_of_deskewed: Optional[str] = None,
dir_of_all: Optional[str] = None,
dir_save_page: Optional[str] = None,
):
"""
Get image and scales, then extract the page of scanned image
"""
self.logger.debug("enter run")
t0_tot = time.time()
# Log enabled features directly
enabled_modes = []
if self.full_layout:
enabled_modes.append("Full layout analysis")
if self.tables:
enabled_modes.append("Table detection")
if enabled_modes:
self.logger.info("Enabled modes: " + ", ".join(enabled_modes))
if self.enable_plotting:
self.logger.info("Saving debug plots")
if dir_of_cropped_images:
self.logger.info(f"Saving cropped images to: {dir_of_cropped_images}")
if dir_of_layout:
self.logger.info(f"Saving layout plots to: {dir_of_layout}")
if dir_of_deskewed:
self.logger.info(f"Saving deskewed images to: {dir_of_deskewed}")
if dir_in:
ls_imgs = [os.path.join(dir_in, image_filename)
for image_filename in filter(is_image_filename,
os.listdir(dir_in))]
elif image_filename:
ls_imgs = [image_filename]
else:
raise ValueError("run requires either a single image filename or a directory")
for img_filename in ls_imgs:
self.logger.info(img_filename)
t0 = time.time()
self.reset_file_name_dir(img_filename, dir_out)
if self.enable_plotting:
self.plotter = EynollahPlotter(dir_out=dir_out,
dir_of_all=dir_of_all,
dir_save_page=dir_save_page,
dir_of_deskewed=dir_of_deskewed,
dir_of_cropped_images=dir_of_cropped_images,
dir_of_layout=dir_of_layout,
image_filename_stem=Path(img_filename).stem)
#print("text region early -11 in %.1fs", time.time() - t0)
if os.path.exists(self.writer.output_filename):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", self.writer.output_filename)
else:
self.logger.warning("will skip input for existing output file '%s'", self.writer.output_filename)
continue
pcgts = self.run_single()
self.logger.info("Job done in %.1fs", time.time() - t0)
self.writer.write_pagexml(pcgts)
if dir_in:
self.logger.info("All jobs done in %.1fs", time.time() - t0_tot)
def run_single(self):
t0 = time.time()
self.logger.info(f"Processing file: {self.writer.image_filename}")
self.logger.info("Step 1/5: Image Enhancement")
img_res, is_image_enhanced, num_col_classifier, _ = \
self.run_enhancement()
self.logger.info(f"Image: {self.image.shape[1]}x{self.image.shape[0]}, "
f"{self.dpi} DPI, {num_col_classifier} columns")
if is_image_enhanced:
self.logger.info("Enhancement applied")
self.logger.info(f"Enhancement complete ({time.time() - t0:.1f}s)")
# Image Extraction Mode
self.logger.info("Step 2/5: Image Extraction Mode")
_, _, _, polygons_of_images, \
image_page, page_coord, cont_page = \
self.get_regions_light_v_extract_only_images(img_res, num_col_classifier)
pcgts = self.writer.build_pagexml_no_full_layout(
found_polygons_text_region=[],
page_coord=page_coord,
order_of_texts=[],
all_found_textline_polygons=[],
all_box_coord=[],
found_polygons_text_region_img=polygons_of_images,
found_polygons_marginals_left=[],
found_polygons_marginals_right=[],
all_found_textline_polygons_marginals_left=[],
all_found_textline_polygons_marginals_right=[],
all_box_coord_marginals_left=[],
all_box_coord_marginals_right=[],
slopes=[],
slopes_marginals_left=[],
slopes_marginals_right=[],
cont_page=cont_page,
polygons_seplines=[],
found_polygons_tables=[],
)
if self.plotter:
self.plotter.write_images_into_directory(polygons_of_images, image_page)
self.logger.info("Image extraction complete")
return pcgts

(File diff suppressed because it is too large.)

@@ -0,0 +1,10 @@
"""
Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
"""
from ocrd_utils import tf_disable_interactive_logs
import torch
tf_disable_interactive_logs()
import tensorflow.keras
import shapely
imported_libs = True
__all__ = ['imported_libs']

@@ -0,0 +1,837 @@
# FIXME: fix all of those...
# pyright: reportOptionalSubscript=false
from logging import Logger, getLogger
from typing import List, Optional
from pathlib import Path
import os
import gc
import math
from dataclasses import dataclass
import cv2
from cv2.typing import MatLike
from xml.etree import ElementTree as ET
from PIL import Image, ImageDraw
import numpy as np
from eynollah.model_zoo import EynollahModelZoo
from eynollah.utils.font import get_font
from eynollah.utils.xml import etree_namespace_for_element_tag
try:
import torch
except ImportError:
torch = None
from .utils import is_image_filename
from .utils.resize import resize_image
from .utils.utils_ocr import (
break_curved_line_into_small_pieces_and_then_merge,
decode_batch_predictions,
fit_text_single_line,
get_contours_and_bounding_boxes,
get_orientation_moments,
preprocess_and_resize_image_for_ocrcnn_model,
return_textlines_split_if_needed,
rotate_image_with_padding,
)
# TODO: refine typing
@dataclass
class EynollahOcrResult:
extracted_texts_merged: List
extracted_conf_value_merged: Optional[List]
cropped_lines_region_indexer: List
total_bb_coordinates: List
class Eynollah_ocr:
def __init__(
self,
*,
model_zoo: EynollahModelZoo,
tr_ocr=False,
batch_size: Optional[int]=None,
do_not_mask_with_textline_contour: bool=False,
min_conf_value_of_textline_text : Optional[float]=None,
logger: Optional[Logger]=None,
):
self.tr_ocr = tr_ocr
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
self.logger = logger if logger else getLogger('eynollah.ocr')
self.model_zoo = model_zoo
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3
self.b_s = int(batch_size) if batch_size is not None else (2 if tr_ocr else 8)
if tr_ocr:
self.model_zoo.load_model('trocr_processor')
self.model_zoo.load_model('ocr', 'tr')
self.model_zoo.get('ocr').to(self.device)
else:
self.model_zoo.load_model('ocr', '')
self.model_zoo.load_model('num_to_char')
self.model_zoo.load_model('characters')
self.end_character = len(self.model_zoo.get('characters', list)) + 2
@property
def device(self):
assert torch
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
return torch.device("cuda:0")
else:
self.logger.info("Using CPU processing")
return torch.device("cpu")
def run_trocr(
self,
*,
img: MatLike,
page_tree: ET.ElementTree,
page_ns,
tr_ocr_input_height_and_width,
) -> EynollahOcrResult:
total_bb_coordinates = []
cropped_lines = []
cropped_lines_region_indexer = []
cropped_lines_merging_indexing = []
extracted_texts = []
indexer_text_region = 0

def flush_batch():
    # run the accumulated line crops through TrOCR and collect the decoded texts
    if not cropped_lines:
        return
    processor = self.model_zoo.get('trocr_processor')
    pixel_values = processor(list(cropped_lines), return_tensors="pt").pixel_values
    generated_ids = self.model_zoo.get('ocr').generate(pixel_values.to(self.device))
    extracted_texts.extend(processor.batch_decode(generated_ids, skip_special_tokens=True))
    cropped_lines.clear()

def queue_line(line_img, merge_index):
    # enqueue one line crop, flushing whenever a full batch is ready
    cropped_lines.append(line_img)
    cropped_lines_merging_indexing.append(merge_index)
    if len(cropped_lines) == self.b_s:
        flush_batch()
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
for child_textlines in child_textregion:
if child_textlines.tag.endswith("Coords"):
cropped_lines_region_indexer.append(indexer_text_region)
p_h=child_textlines.attrib['points'].split(' ')
textline_coords = np.array( [ [int(x.split(',')[0]),
int(x.split(',')[1]) ]
for x in p_h] )
x,y,w,h = cv2.boundingRect(textline_coords)
total_bb_coordinates.append([x,y,w,h])
h2w_ratio = h/float(w)
img_poly_on_img = np.copy(img)
mask_poly = np.zeros(img.shape)
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
mask_poly = mask_poly[y:y+h, x:x+w, :]
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
img_crop[mask_poly==0] = 255
self.logger.debug("processing %d lines for '%s'",
len(cropped_lines), nn.attrib['id'])
if h2w_ratio > 0.1:
    queue_line(resize_image(img_crop,
                            tr_ocr_input_height_and_width,
                            tr_ocr_input_height_and_width), 0)
else:
    splited_images, _ = return_textlines_split_if_needed(img_crop, None)
    if splited_images:
        # the line was split into two halves: queue both and mark
        # them (1, -1) for re-merging after prediction
        queue_line(resize_image(splited_images[0],
                                tr_ocr_input_height_and_width,
                                tr_ocr_input_height_and_width), 1)
        queue_line(resize_image(splited_images[1],
                                tr_ocr_input_height_and_width,
                                tr_ocr_input_height_and_width), -1)
    else:
        queue_line(img_crop, 0)
indexer_text_region = indexer_text_region + 1
# predict any remaining lines in the final, incomplete batch
flush_batch()
####extracted_texts = []
####n_iterations = math.ceil(len(cropped_lines) / self.b_s)
####for i in range(n_iterations):
####if i==(n_iterations-1):
####n_start = i*self.b_s
####imgs = cropped_lines[n_start:]
####else:
####n_start = i*self.b_s
####n_end = (i+1)*self.b_s
####imgs = cropped_lines[n_start:n_end]
####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
####generated_ids_merged = self.model_ocr.generate(
#### pixel_values_merged.to(self.device))
####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
#### generated_ids_merged, skip_special_tokens=True)
####extracted_texts = extracted_texts + generated_text_merged
del cropped_lines
gc.collect()
extracted_texts_merged = [extracted_texts[ind]
if cropped_lines_meging_indexing[ind]==0
else extracted_texts[ind]+" "+extracted_texts[ind+1]
if cropped_lines_meging_indexing[ind]==1
else None
for ind in range(len(cropped_lines_meging_indexing))]
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
#print(extracted_texts_merged, len(extracted_texts_merged))
return EynollahOcrResult(
extracted_texts_merged=extracted_texts_merged,
extracted_conf_value_merged=None,
cropped_lines_region_indexer=cropped_lines_region_indexer,
total_bb_coordinates=total_bb_coordinates,
)
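# Editor's note: the four identical flush blocks above (and the trailing
# `if indexer_b_s != 0` block) could be factored into a helper; a sketch
# (hypothetical, not part of this commit), using only the calls already
# present above:
def _flush_trocr_batch(self, imgs):
    # preprocess one batch of line images, generate token ids, decode to text
    processor = self.model_zoo.get('trocr_processor')
    pixel_values = processor(imgs, return_tensors="pt").pixel_values
    generated_ids = self.model_zoo.get('ocr').generate(pixel_values.to(self.device))
    return processor.batch_decode(generated_ids, skip_special_tokens=True)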
def run_cnn(
self,
*,
img: MatLike,
img_bin: Optional[MatLike],
page_tree: ET.ElementTree,
page_ns,
image_width,
image_height,
) -> EynollahOcrResult:
total_bb_coordinates = []
cropped_lines = []
img_crop_bin = None
imgs_bin = None
imgs_bin_ver_flipped = None
cropped_lines_bin = []
cropped_lines_ver_index = []
cropped_lines_region_indexer = []
cropped_lines_meging_indexing = []
indexer_text_region = 0
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
try:
type_textregion = nn.attrib['type']
except:
type_textregion = 'paragraph'
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
for child_textlines in child_textregion:
if child_textlines.tag.endswith("Coords"):
cropped_lines_region_indexer.append(indexer_text_region)
p_h=child_textlines.attrib['points'].split(' ')
textline_coords = np.array( [ [int(x.split(',')[0]),
int(x.split(',')[1]) ]
for x in p_h] )
x,y,w,h = cv2.boundingRect(textline_coords)
angle_radians = math.atan2(h, w)
# Convert to degrees
angle_degrees = math.degrees(angle_radians)
if type_textregion=='drop-capital':
angle_degrees = 0
total_bb_coordinates.append([x,y,w,h])
w_scaled = w * image_height/float(h)
img_poly_on_img = np.copy(img)
if img_bin is not None:
img_poly_on_img_bin = np.copy(img_bin)
img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :]
mask_poly = np.zeros(img.shape)
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
mask_poly = mask_poly[y:y+h, x:x+w, :]
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
# print(file_name, angle_degrees, w*h,
# mask_poly[:,:,0].sum(),
# mask_poly[:,:,0].sum() /float(w*h) ,
# 'didi')
if angle_degrees > 3:
better_des_slope = get_orientation_moments(textline_coords)
img_crop = rotate_image_with_padding(img_crop, better_des_slope)
if img_bin is not None:
img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope)
mask_poly = rotate_image_with_padding(mask_poly, better_des_slope)
mask_poly = mask_poly.astype('uint8')
#new bounding box
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0])
mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :]
img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :]
if not self.do_not_mask_with_textline_contour:
img_crop[mask_poly==0] = 255
if img_bin is not None:
img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :]
if not self.do_not_mask_with_textline_contour:
img_crop_bin[mask_poly==0] = 255
if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90:
if img_bin is not None:
img_crop, img_crop_bin = \
break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly, img_crop_bin)
else:
img_crop, _ = \
break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly)
else:
better_des_slope = 0
if not self.do_not_mask_with_textline_contour:
img_crop[mask_poly==0] = 255
if img_bin is not None:
if not self.do_not_mask_with_textline_contour:
img_crop_bin[mask_poly==0] = 255
if type_textregion=='drop-capital':
pass
else:
if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90:
if img_bin is not None:
img_crop, img_crop_bin = \
break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly, img_crop_bin)
else:
img_crop, _ = \
break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly)
if w_scaled < 750:#1.5*image_width:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop, image_height, image_width)
cropped_lines.append(img_fin)
if abs(better_des_slope) > 45:
cropped_lines_ver_index.append(1)
else:
cropped_lines_ver_index.append(0)
cropped_lines_meging_indexing.append(0)
if img_bin is not None:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop_bin, image_height, image_width)
cropped_lines_bin.append(img_fin)
else:
splited_images, splited_images_bin = return_textlines_split_if_needed(
    img_crop, img_crop_bin if img_bin is not None else None)
if splited_images:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images[0], image_height, image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(1)
if abs(better_des_slope) > 45:
cropped_lines_ver_index.append(1)
else:
cropped_lines_ver_index.append(0)
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images[1], image_height, image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(-1)
if abs(better_des_slope) > 45:
cropped_lines_ver_index.append(1)
else:
cropped_lines_ver_index.append(0)
if img_bin is not None:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images_bin[0], image_height, image_width)
cropped_lines_bin.append(img_fin)
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images_bin[1], image_height, image_width)
cropped_lines_bin.append(img_fin)
else:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop, image_height, image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(0)
if abs(better_des_slope) > 45:
cropped_lines_ver_index.append(1)
else:
cropped_lines_ver_index.append(0)
if img_bin is not None:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop_bin, image_height, image_width)
cropped_lines_bin.append(img_fin)
indexer_text_region = indexer_text_region +1
extracted_texts = []
extracted_conf_value = []
n_iterations = math.ceil(len(cropped_lines) / self.b_s)
# FIXME: copy pasta
for i in range(n_iterations):
if i==(n_iterations-1):
n_start = i*self.b_s
imgs = cropped_lines[n_start:]
imgs = np.array(imgs)
imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
ver_imgs = np.array( cropped_lines_ver_index[n_start:] )
indices_ver = np.where(ver_imgs == 1)[0]
#print(indices_ver, 'indices_ver')
if len(indices_ver)>0:
imgs_ver_flipped = imgs[indices_ver, : ,: ,:]
imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:]
#print(imgs_ver_flipped, 'imgs_ver_flipped')
else:
imgs_ver_flipped = None
if img_bin is not None:
imgs_bin = cropped_lines_bin[n_start:]
imgs_bin = np.array(imgs_bin)
imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3)
if len(indices_ver)>0:
imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:]
imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:]
#print(imgs_ver_flipped, 'imgs_ver_flipped')
else:
imgs_bin_ver_flipped = None
else:
n_start = i*self.b_s
n_end = (i+1)*self.b_s
imgs = cropped_lines[n_start:n_end]
imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3)
ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] )
indices_ver = np.where(ver_imgs == 1)[0]
#print(indices_ver, 'indices_ver')
if len(indices_ver)>0:
imgs_ver_flipped = imgs[indices_ver, : ,: ,:]
imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:]
#print(imgs_ver_flipped, 'imgs_ver_flipped')
else:
imgs_ver_flipped = None
if img_bin is not None:
imgs_bin = cropped_lines_bin[n_start:n_end]
imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3)
if len(indices_ver)>0:
imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:]
imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:]
#print(imgs_ver_flipped, 'imgs_ver_flipped')
else:
imgs_bin_ver_flipped = None
self.logger.debug("processing next %d lines", len(imgs))
preds = self.model_zoo.get('ocr').predict(imgs, verbose=0)
if len(indices_ver)>0:
preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0)
preds_max_fliped = np.max(preds_flipped, axis=2 )
preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
masked_means_flipped = \
np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \
np.sum(pred_max_not_unk_mask_bool_flipped, axis=1)
masked_means_flipped[np.isnan(masked_means_flipped)] = 0
preds_max = np.max(preds, axis=2 )
preds_max_args = np.argmax(preds, axis=2 )
pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
masked_means = \
np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
np.sum(pred_max_not_unk_mask_bool, axis=1)
masked_means[np.isnan(masked_means)] = 0
masked_means_ver = masked_means[indices_ver]
#print(masked_means_ver, 'pred_max_not_unk')
indices_where_flipped_conf_value_is_higher = \
np.where(masked_means_flipped > masked_means_ver)[0]
#print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher')
if len(indices_where_flipped_conf_value_is_higher)>0:
indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher]
preds[indices_to_be_replaced,:,:] = \
preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]
if img_bin is not None:
preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0)
if len(indices_ver)>0:
preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0)
preds_max_fliped = np.max(preds_flipped, axis=2 )
preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
masked_means_flipped = \
np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \
np.sum(pred_max_not_unk_mask_bool_flipped, axis=1)
masked_means_flipped[np.isnan(masked_means_flipped)] = 0
preds_max = np.max(preds, axis=2 )
preds_max_args = np.argmax(preds, axis=2 )
pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
masked_means = \
np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
np.sum(pred_max_not_unk_mask_bool, axis=1)
masked_means[np.isnan(masked_means)] = 0
masked_means_ver = masked_means[indices_ver]
#print(masked_means_ver, 'pred_max_not_unk')
indices_where_flipped_conf_value_is_higher = \
np.where(masked_means_flipped > masked_means_ver)[0]
#print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher')
if len(indices_where_flipped_conf_value_is_higher)>0:
indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher]
preds_bin[indices_to_be_replaced,:,:] = \
preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]
preds = (preds + preds_bin) / 2.
pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char'))
preds_max = np.max(preds, axis=2 )
preds_max_args = np.argmax(preds, axis=2 )
pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
masked_means = \
np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
np.sum(pred_max_not_unk_mask_bool, axis=1)
for ib in range(imgs.shape[0]):
pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
if masked_means[ib] >= self.min_conf_value_of_textline_text:
extracted_texts.append(pred_texts_ib)
extracted_conf_value.append(masked_means[ib])
else:
extracted_texts.append("")
extracted_conf_value.append(0)
del cropped_lines
del cropped_lines_bin
gc.collect()
extracted_texts_merged = [extracted_texts[ind]
if cropped_lines_meging_indexing[ind]==0
else extracted_texts[ind]+" "+extracted_texts[ind+1]
if cropped_lines_meging_indexing[ind]==1
else None
for ind in range(len(cropped_lines_meging_indexing))]
extracted_conf_value_merged = [extracted_conf_value[ind] # type: ignore
if cropped_lines_meging_indexing[ind]==0
else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2.
if cropped_lines_meging_indexing[ind]==1
else None
for ind in range(len(cropped_lines_meging_indexing))]
extracted_conf_value_merged: List[float] = [extracted_conf_value_merged[ind_cfm]
for ind_cfm in range(len(extracted_texts_merged))
if extracted_texts_merged[ind_cfm] is not None]
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
return EynollahOcrResult(
extracted_texts_merged=extracted_texts_merged,
extracted_conf_value_merged=extracted_conf_value_merged,
cropped_lines_region_indexer=cropped_lines_region_indexer,
total_bb_coordinates=total_bb_coordinates,
)
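# Editor's note: a standalone sketch (not part of this commit) of the
# confidence score used above, i.e. the mean of the per-timestep maximum
# softmax probability, restricted to timesteps whose argmax is not the
# end character:
def _masked_mean_confidence(preds, end_character):
    preds_max = np.max(preds, axis=2)
    not_end = np.argmax(preds, axis=2) != end_character
    means = np.sum(preds_max * not_end, axis=1) / np.sum(not_end, axis=1)
    means[np.isnan(means)] = 0  # lines predicted as end character throughout
    return means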
def write_ocr(
self,
*,
result: EynollahOcrResult,
page_tree: ET.ElementTree,
out_file_ocr,
page_ns,
img,
out_image_with_text,
):
cropped_lines_region_indexer = result.cropped_lines_region_indexer
total_bb_coordinates = result.total_bb_coordinates
extracted_texts_merged = result.extracted_texts_merged
extracted_conf_value_merged = result.extracted_conf_value_merged
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text)
font = get_font()
for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0]
y_bb = bb_ind[1]
w_bb = bb_ind[2]
h_bb = bb_ind[3]
font = fit_text_single_line(draw, extracted_texts_merged[indexer_text],
font.path, w_bb, int(h_bb*0.4) )
##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2)
text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally
text_y = y_bb + (h_bb - text_height) // 2 # Center vertically
# Draw the text
draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font)
image_text.save(out_image_with_text)
text_by_textregion = []
for ind in unique_cropped_lines_region_indexer:
mask = np.array(cropped_lines_region_indexer)==ind
extracted_texts_merged_un = np.array(extracted_texts_merged)[mask]
if len(extracted_texts_merged_un)>1:
text_by_textregion_ind = ""
next_glue = ""
for indt in range(len(extracted_texts_merged_un)):
if (extracted_texts_merged_un[indt].endswith('⸗') or
        extracted_texts_merged_un[indt].endswith('-') or
        extracted_texts_merged_un[indt].endswith('¬')):
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1]
next_glue = ""
else:
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt]
next_glue = " "
text_by_textregion.append(text_by_textregion_ind)
else:
text_by_textregion.append(" ".join(extracted_texts_merged_un))
indexer = 0
indexer_textregion = 0
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
is_textregion_text = False
for childtest in nn:
if childtest.tag.endswith("TextEquiv"):
is_textregion_text = True
if not is_textregion_text:
text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')
has_textline = False
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
is_textline_text = False
for childtest2 in child_textregion:
if childtest2.tag.endswith("TextEquiv"):
is_textline_text = True
if not is_textline_text:
text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
if extracted_conf_value_merged:
text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
unicode_textline = ET.SubElement(text_subelement, 'Unicode')
unicode_textline.text = extracted_texts_merged[indexer]
else:
for childtest3 in child_textregion:
if childtest3.tag.endswith("TextEquiv"):
for child_uc in childtest3:
if child_uc.tag.endswith("Unicode"):
if extracted_conf_value_merged:
childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
child_uc.text = extracted_texts_merged[indexer]
indexer = indexer + 1
has_textline = True
if has_textline:
if is_textregion_text:
for child4 in nn:
if child4.tag.endswith("TextEquiv"):
for childtr_uc in child4:
if childtr_uc.tag.endswith("Unicode"):
childtr_uc.text = text_by_textregion[indexer_textregion]
else:
unicode_textregion.text = text_by_textregion[indexer_textregion]
indexer_textregion = indexer_textregion + 1
ET.register_namespace("",page_ns)
page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None)
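# Editor's note (illustrative, not part of this commit): the glue loop above
# joins line texts per region and dehyphenates trailing '⸗', '-' or '¬', e.g.
#   ["Luft-", "schiff"] -> "Luftschiff"
#   ["ein", "Wort"]     -> "ein Wort"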
def run(
self,
*,
overwrite: bool = False,
dir_in: Optional[str] = None,
dir_in_bin: Optional[str] = None,
image_filename: Optional[str] = None,
dir_xmls: str,
dir_out_image_text: Optional[str] = None,
dir_out: str,
):
"""
Run OCR.
Args:
dir_in_bin (str): Prediction with RGB and binarized images for selected pages, should not be the default
"""
if dir_in:
ls_imgs = [os.path.join(dir_in, image_filename)
for image_filename in filter(is_image_filename,
os.listdir(dir_in))]
else:
assert image_filename
ls_imgs = [image_filename]
for img_filename in ls_imgs:
file_stem = Path(img_filename).stem
page_file_in = os.path.join(dir_xmls, file_stem+'.xml')
out_file_ocr = os.path.join(dir_out, file_stem+'.xml')
if os.path.exists(out_file_ocr):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", out_file_ocr)
else:
self.logger.warning("will skip input for existing output file '%s'", out_file_ocr)
continue
img = cv2.imread(img_filename)
page_tree = ET.parse(page_file_in, parser = ET.XMLParser(encoding="utf-8"))
page_ns = etree_namespace_for_element_tag(page_tree.getroot().tag)
out_image_with_text = None
if dir_out_image_text:
out_image_with_text = os.path.join(dir_out_image_text, file_stem + '.png')
img_bin = None
if dir_in_bin:
img_bin = cv2.imread(os.path.join(dir_in_bin, file_stem+'.png'))
if self.tr_ocr:
result = self.run_trocr(
img=img,
page_tree=page_tree,
page_ns=page_ns,
tr_ocr_input_height_and_width = 384
)
else:
result = self.run_cnn(
img=img,
page_tree=page_tree,
page_ns=page_ns,
img_bin=img_bin,
image_width=512,
image_height=32,
)
self.write_ocr(
result=result,
img=img,
page_tree=page_tree,
page_ns=page_ns,
out_file_ocr=out_file_ocr,
out_image_with_text=out_image_with_text,
)
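# Editor's usage sketch (hypothetical paths; construction of Eynollah_ocr
# elided, see class definition above):
#   ocr = Eynollah_ocr(...)
#   ocr.run(dir_in='pages/', dir_xmls='page_xml/', dir_out='page_xml_ocr/',
#           overwrite=True)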
View file
@@ -2,7 +2,12 @@
Image enhancer. The output can be written as same scale of input or in new predicted scale.
"""
from logging import Logger
# FIXME: fix all of those...
# pyright: reportUnboundVariable=false
# pyright: reportCallIssue=false
# pyright: reportArgumentType=false
import logging
import os
import time
from typing import Optional
@@ -10,19 +15,18 @@ from pathlib import Path
import gc
import cv2
from keras.models import Model
import numpy as np
from ocrd_utils import getLogger, tf_disable_interactive_logs
import tensorflow as tf
import tensorflow as tf # type: ignore
from skimage.morphology import skeletonize
from tensorflow.keras.models import load_model
from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image
from .utils.pil_cv2 import pil2cv
from .utils import (
is_image_filename,
crop_image_inside_box
)
from .eynollah import PatchEncoder, Patches
DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)
@@ -31,14 +35,13 @@ KERNEL = np.ones((5, 5), np.uint8)
class Enhancer:
def __init__(
self,
dir_models : str,
*,
model_zoo: EynollahModelZoo,
num_col_upper : Optional[int] = None,
num_col_lower : Optional[int] = None,
save_org_scale : bool = False,
logger : Optional[Logger] = None,
):
self.input_binary = False
self.light_version = False
self.save_org_scale = save_org_scale
if num_col_upper:
self.num_col_upper = int(num_col_upper)
@@ -49,12 +52,10 @@ class Enhancer:
else:
self.num_col_lower = num_col_lower
self.logger = logger if logger else getLogger('enhancement')
self.dir_models = dir_models
self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425"
self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915"
self.logger = logging.getLogger('eynollah.enhance')
self.model_zoo = model_zoo
for v in ['binarization', 'enhancement', 'col_classifier', 'page']:
self.model_zoo.load_model(v)
try:
for device in tf.config.list_physical_devices('GPU'):
@@ -62,25 +63,14 @@ class Enhancer:
except:
self.logger.warning("no GPU device available")
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
ret = {}
if image_filename:
ret['img'] = cv2.imread(image_filename)
if self.light_version:
self.dpi = 100
else:
self.dpi = 0#check_dpi(image_filename)
self.dpi = 100
else:
ret['img'] = pil2cv(image_pil)
if self.light_version:
self.dpi = 100
else:
self.dpi = 0#check_dpi(image_pil)
self.dpi = 100
ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY)
for prefix in ('', '_grayscale'):
ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8)
@@ -100,26 +90,11 @@ class Enhancer:
key += '_uint8'
return self._imgs[key].copy()
def isNaN(self, num):
return num != num
@staticmethod
def our_load_model(model_file):
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except:
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement")
img_height_model = self.model_enhancement.layers[-1].output_shape[1]
img_width_model = self.model_enhancement.layers[-1].output_shape[2]
img_height_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[1]
img_width_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[2]
if img.shape[0] < img_height_model:
img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
if img.shape[1] < img_width_model:
@@ -160,7 +135,7 @@ class Enhancer:
index_y_d = img_h - img_height_model
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose='0')
seg = label_p_pred[0, :, :, :] * 255
if i == 0 and j == 0:
@@ -246,7 +221,7 @@ class Enhancer:
else:
img = self.imread()
img = cv2.GaussianBlur(img, (5, 5), 0)
img_page_prediction = self.do_prediction(False, img, self.model_page)
img_page_prediction = self.do_prediction(False, img, self.model_zoo.get('page'))
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
@@ -285,13 +260,13 @@ class Enhancer:
return img_new, num_column_is_classified
def resize_and_enhance_image_with_column_classifier(self, light_version):
def resize_and_enhance_image_with_column_classifier(self):
self.logger.debug("enter resize_and_enhance_image_with_column_classifier")
dpi = 0#self.dpi
self.logger.info("Detected %s DPI", dpi)
if self.input_binary:
img = self.imread()
prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
prediction_bin = self.do_prediction(True, img, self.model_zoo.get('binarization'), n_batch_inference=5)
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
img= np.copy(prediction_bin)
@@ -332,7 +307,7 @@ class Enhancer:
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
if self.input_binary:
@@ -352,7 +327,7 @@ class Enhancer:
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
if num_col > self.num_col_upper:
@@ -368,16 +343,13 @@ class Enhancer:
self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5))
if dpi < DPI_THRESHOLD:
if light_version and num_col in (1,2):
if num_col in (1,2):
img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2(
img, num_col, width_early, label_p_pred)
else:
img_new, num_column_is_classified = self.calculate_width_height_by_columns(
img, num_col, width_early, label_p_pred)
if light_version:
image_res = np.copy(img_new)
else:
image_res = self.predict_enhancement(img_new)
image_res = np.copy(img_new)
is_image_enhanced = True
else:
@@ -671,11 +643,11 @@ class Enhancer:
gc.collect()
return prediction_true
def run_enhancement(self, light_version):
def run_enhancement(self):
t_in = time.time()
self.logger.info("Resizing and enhancing image...")
is_image_enhanced, img_org, img_res, num_col_classifier, num_column_is_classified, img_bin = \
self.resize_and_enhance_image_with_column_classifier(light_version)
self.resize_and_enhance_image_with_column_classifier()
self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ')
return img_res, is_image_enhanced, num_col_classifier, num_column_is_classified
@@ -683,9 +655,9 @@ class Enhancer:
def run_single(self):
t0 = time.time()
img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(light_version=False)
img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement()
return img_res
return img_res, is_image_enhanced
def run(self,
@@ -723,9 +695,18 @@ class Enhancer:
self.logger.warning("will skip input for existing output file '%s'", self.output_filename)
continue
image_enhanced = self.run_single()
did_resize = False
image_enhanced, did_enhance = self.run_single()
if self.save_org_scale:
image_enhanced = resize_image(image_enhanced, self.h_org, self.w_org)
did_resize = True
self.logger.info(
"Image %s was %senhanced%s.",
img_filename,
'' if did_enhance else 'not ',
' and resized' if did_resize else ''
)
cv2.imwrite(self.output_filename, image_enhanced)
View file
@@ -1,8 +1,12 @@
"""
Image enhancer. The output can be written as same scale of input or in new predicted scale.
Machine learning based reading order detection
"""
from logging import Logger
# pyright: reportCallIssue=false
# pyright: reportUnboundVariable=false
# pyright: reportArgumentType=false
import logging
import os
import time
from typing import Optional
@@ -10,12 +14,12 @@ from pathlib import Path
import xml.etree.ElementTree as ET
import cv2
from keras.models import Model
import numpy as np
from ocrd_utils import getLogger
import statistics
import tensorflow as tf
from tensorflow.keras.models import load_model
from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image
from .utils.contour import (
find_new_features_of_contours,
@@ -23,7 +27,6 @@ from .utils.contour import (
return_parent_contours,
)
from .utils import is_xml_filename
from .eynollah import PatchEncoder, Patches
DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)
@@ -32,12 +35,12 @@ KERNEL = np.ones((5, 5), np.uint8)
class machine_based_reading_order_on_layout:
def __init__(
self,
dir_models : str,
logger : Optional[Logger] = None,
*,
model_zoo: EynollahModelZoo,
logger : Optional[logging.Logger] = None,
):
self.logger = logger if logger else getLogger('mbreorder')
self.dir_models = dir_models
self.model_reading_order_dir = dir_models + "/model_eynollah_reading_order_20250824"
self.logger = logger or logging.getLogger('eynollah.mbreorder')
self.model_zoo = model_zoo
try:
for device in tf.config.list_physical_devices('GPU'):
@@ -45,20 +48,7 @@ class machine_based_reading_order_on_layout:
except:
self.logger.warning("no GPU device available")
self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
self.light_version = True
@staticmethod
def our_load_model(model_file):
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except:
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
self.model_zoo.load_model('reading_order')
def read_xml(self, xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
@@ -69,6 +59,7 @@ class machine_based_reading_order_on_layout:
index_tot_regions = []
tot_region_ref = []
y_len, x_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
@@ -81,13 +72,13 @@ class machine_based_reading_order_on_layout:
co_printspace = []
if link+'PrintSpace' in alltags:
region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')])
elif link+'Border' in alltags:
else:
region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')])
for tag in region_tags_printspace:
if link+'PrintSpace' in alltags:
tag_endings_printspace = ['}PrintSpace','}printspace']
elif link+'Border' in alltags:
else:
tag_endings_printspace = ['}Border','}border']
if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]):
@@ -524,7 +515,7 @@ class machine_based_reading_order_on_layout:
min_cont_size_to_be_dilated = 10
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
cx_conts, cy_conts, x_min_conts, x_max_conts, y_min_conts, y_max_conts, _ = find_new_features_of_contours(contours_only_text_parent)
args_cont_located = np.array(range(len(contours_only_text_parent)))
@@ -624,13 +615,13 @@ class machine_based_reading_order_on_layout:
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,
int(x_min_main[j]):int(x_max_main[j])] = 1
co_text_all_org = contours_only_text_parent + contours_only_text_parent_h
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
co_text_all = contours_only_dilated + contours_only_text_parent_h
else:
co_text_all = contours_only_text_parent + contours_only_text_parent_h
else:
co_text_all_org = contours_only_text_parent
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
co_text_all = contours_only_dilated
else:
co_text_all = contours_only_text_parent
@@ -683,7 +674,7 @@ class machine_based_reading_order_on_layout:
tot_counter += 1
batch.append(j)
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
y_pr = self.model_reading_order.predict(input_1 , verbose=0)
y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0')
for jb, j in enumerate(batch):
if y_pr[jb][0]>=0.5:
post_list.append(j)
@@ -709,7 +700,7 @@ class machine_based_reading_order_on_layout:
##id_all_text = np.array(id_all_text)[index_sort]
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
org_contours_indexes = []
for ind in range(len(ordered)):
region_with_curr_order = ordered[ind]
@@ -802,6 +793,7 @@ class machine_based_reading_order_on_layout:
alltags=[elem.tag for elem in root_xml.iter()]
ET.register_namespace("",name_space)
assert dir_out
tree_xml.write(os.path.join(dir_out, file_name+'.xml'),
xml_declaration=True,
method='xml',
View file
@@ -0,0 +1,4 @@
__all__ = [
'EynollahModelZoo',
]
from .model_zoo import EynollahModelZoo
View file
@@ -0,0 +1,252 @@
from .specs import EynollahModelSpec, EynollahModelSpecSet
# NOTE: This needs to change whenever models/versions change
ZENODO = "https://zenodo.org/records/17295988/files"
MODELS_VERSION = "v0_7_0"
def dist_url(dist_name: str="layout") -> str:
return f'{ZENODO}/models_{dist_name}_{MODELS_VERSION}.zip'
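# e.g. dist_url() == f"{ZENODO}/models_layout_v0_7_0.zip"
# and dist_url("ocr") == f"{ZENODO}/models_ocr_v0_7_0.zip"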
DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
EynollahModelSpec(
category="enhancement",
variant='',
filename="models_eynollah/eynollah-enhancement_20210425",
dist_url=dist_url(),
type='Keras',
),
EynollahModelSpec(
category="binarization",
variant='hybrid',
filename="models_eynollah/eynollah-binarization-hybrid_20230504/model_bin_hybrid_trans_cnn_sbb_ens",
dist_url=dist_url(),
type='Keras',
),
EynollahModelSpec(
category="binarization",
variant='20210309',
filename="models_eynollah/eynollah-binarization_20210309",
dist_url=dist_url("extra"),
type='Keras',
),
EynollahModelSpec(
category="binarization",
variant='',
filename="models_eynollah/eynollah-binarization_20210425",
dist_url=dist_url("extra"),
type='Keras',
),
EynollahModelSpec(
category="col_classifier",
variant='',
filename="models_eynollah/eynollah-column-classifier_20210425",
dist_url=dist_url(),
type='Keras',
),
EynollahModelSpec(
category="page",
variant='',
filename="models_eynollah/model_eynollah_page_extraction_20250915",
dist_url=dist_url(),
type='Keras',
),
EynollahModelSpec(
category="region",
variant='',
filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
dist_url=dist_url(),
type='Keras',
),
EynollahModelSpec(
category="extract_images",
variant='',
filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
dist_url=dist_url(),
type='Keras',
),
EynollahModelSpec(
category="region",
variant='',
filename="models_eynollah/eynollah-main-regions_20220314",
dist_url=dist_url(),
help="early layout",
type='Keras',
),
EynollahModelSpec(
category="region_p2",
variant='non-light',
filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425",
dist_url=dist_url('extra'),
help="early layout, non-light, 2nd part",
type='Keras',
),
EynollahModelSpec(
category="region_1_2",
variant='',
#filename="models_eynollah/modelens_12sp_elay_0_3_4__3_6_n",
#filename="models_eynollah/modelens_earlylayout_12spaltige_2_3_5_6_7_8",
#filename="models_eynollah/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18",
#filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige",
#filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige",
filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024",
dist_url=dist_url("layout"),
help="early layout, light, 1-or-2-column",
type='Keras',
),
EynollahModelSpec(
category="region_fl_np",
variant='',
#filename="models_eynollah/modelens_full_lay_1_3_031124",
#filename="models_eynollah/modelens_full_lay_13__3_19_241024",
#filename="models_eynollah/model_full_lay_13_241024",
#filename="models_eynollah/modelens_full_lay_13_17_231024",
#filename="models_eynollah/modelens_full_lay_1_2_221024",
#filename="models_eynollah/eynollah-full-regions-1column_20210425",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist_url=dist_url(),
help="full layout / no patches",
type='Keras',
),
# FIXME: Why are region_fl and region_fl_np the same model?
EynollahModelSpec(
category="region_fl",
variant='',
# filename="models_eynollah/eynollah-full-regions-3+column_20210425",
# filename="models_eynollah/model_2_full_layout_new_trans",
# filename="models_eynollah/modelens_full_lay_1_3_031124",
# filename="models_eynollah/modelens_full_lay_13__3_19_241024",
# filename="models_eynollah/model_full_lay_13_241024",
# filename="models_eynollah/modelens_full_lay_13_17_231024",
# filename="models_eynollah/modelens_full_lay_1_2_221024",
# filename="models_eynollah/modelens_full_layout_24_till_28",
# filename="models_eynollah/model_2_full_layout_new_trans",
filename="models_eynollah/modelens_full_lay_1__4_3_091124",
dist_url=dist_url(),
help="full layout / with patches",
type='Keras',
),
EynollahModelSpec(
category="reading_order",
variant='',
#filename="models_eynollah/model_mb_ro_aug_ens_11",
#filename="models_eynollah/model_step_3200000_mb_ro",
#filename="models_eynollah/model_ens_reading_order_machine_based",
#filename="models_eynollah/model_mb_ro_aug_ens_8",
#filename="models_eynollah/model_ens_reading_order_machine_based",
filename="models_eynollah/model_eynollah_reading_order_20250824",
dist_url=dist_url(),
type='Keras',
),
EynollahModelSpec(
category="textline",
variant='non-light',
#filename="models_eynollah/modelens_textline_1_4_16092024",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_1_3_4_20240915",
#filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
#filename="models_eynollah/modelens_textline_9_12_13_14_15",
#filename="models_eynollah/eynollah-textline_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist_url=dist_url('extra'),
type='Keras',
),
EynollahModelSpec(
category="textline",
variant='',
#filename="models_eynollah/eynollah-textline_light_20210425",
filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
dist_url=dist_url(),
type='Keras',
),
EynollahModelSpec(
category="table",
variant='non-light',
filename="models_eynollah/eynollah-tables_20210319",
dist_url=dist_url('extra'),
type='Keras',
),
EynollahModelSpec(
category="table",
variant='',
filename="models_eynollah/modelens_table_0t4_201124",
dist_url=dist_url(),
type='Keras',
),
EynollahModelSpec(
category="ocr",
variant='',
filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
dist_url=dist_url("ocr"),
type='Keras',
),
EynollahModelSpec(
category="ocr",
variant='degraded',
filename="models_eynollah/model_eynollah_ocr_cnnrnn__degraded_20250805/",
help="slightly better at degraded Fraktur",
dist_url=dist_url("ocr"),
type='Keras',
),
EynollahModelSpec(
category="num_to_char",
variant='',
filename="characters_org.txt",
dist_url=dist_url("ocr"),
type='decoder',
),
EynollahModelSpec(
category="characters",
variant='',
filename="characters_org.txt",
dist_url=dist_url("ocr"),
type='List[str]',
),
EynollahModelSpec(
category="ocr",
variant='tr',
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
dist_url=dist_url("ocr"),
help='much slower transformer-based',
type='Keras',
),
EynollahModelSpec(
category="trocr_processor",
variant='',
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
dist_url=dist_url("ocr"),
type='TrOCRProcessor',
),
EynollahModelSpec(
category="trocr_processor",
variant='htr',
filename="models_eynollah/microsoft/trocr-base-handwritten",
dist_url=dist_url("extra"),
type='TrOCRProcessor',
),
])
View file
@@ -0,0 +1,206 @@
import os
import json
import logging
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
from tensorflow.keras.layers import StringLookup
from tensorflow.keras.models import Model as KerasModel
from tensorflow.keras.models import load_model
from tabulate import tabulate
from ..patch_encoder import PatchEncoder, Patches
from .specs import EynollahModelSpecSet
from .default_specs import DEFAULT_MODEL_SPECS
from .types import AnyModel, T
class EynollahModelZoo:
"""
Wrapper class that handles storage and loading of models for all eynollah runners.
"""
model_basedir: Path
specs: EynollahModelSpecSet
def __init__(
self,
basedir: str,
model_overrides: Optional[List[Tuple[str, str, str]]] = None,
) -> None:
self.model_basedir = Path(basedir)
self.logger = logging.getLogger('eynollah.model_zoo')
if not self.model_basedir.exists():
self.logger.warning(f"Model basedir does not exist: {basedir}. Set eynollah --model-basedir to the correct directory.")
self.specs = deepcopy(DEFAULT_MODEL_SPECS)
self._overrides = []
if model_overrides:
self.override_models(*model_overrides)
self._loaded: Dict[str, AnyModel] = {}
@property
def model_overrides(self):
return self._overrides
def override_models(
self,
*model_overrides: Tuple[str, str, str],
):
"""
Override the default model versions
"""
for model_category, model_variant, model_filename in model_overrides:
spec = self.specs.get(model_category, model_variant)
self.logger.warning("Overriding filename for model spec %s to %s", spec, model_filename)
spec.filename = model_filename
self._overrides += model_overrides
def model_path(
self,
model_category: str,
model_variant: str = '',
absolute: bool = True,
) -> Path:
"""
Translate model_{type,variant} tuple into an absolute (or relative) Path
"""
spec = self.specs.get(model_category, model_variant)
if spec.category in ('characters', 'num_to_char'):
return self.model_path('ocr') / spec.filename
if not Path(spec.filename).is_absolute() and absolute:
model_path = Path(self.model_basedir).joinpath(spec.filename)
else:
model_path = Path(spec.filename)
return model_path
def load_models(
self,
*all_load_args: Union[str, Tuple[str], Tuple[str, str], Tuple[str, str, str]],
) -> Dict:
"""
Load all models by calling load_model and return a dictionary mapping model_category to loaded model
"""
ret = {}
for load_args in all_load_args:
if isinstance(load_args, str):
ret[load_args] = self.load_model(load_args)
else:
ret[load_args[0]] = self.load_model(*load_args)
return ret
def load_model(
self,
model_category: str,
model_variant: str = '',
model_path_override: Optional[str] = None,
) -> AnyModel:
"""
Load any model
"""
if model_path_override:
self.override_models((model_category, model_variant, model_path_override))
model_path = self.model_path(model_category, model_variant)
if model_path.suffix == '.h5' and model_path.with_suffix('').exists():
    # prefer SavedModel over HDF5 format if it exists
    # (with_suffix keeps the parent directory, unlike Path(model_path.stem))
    model_path = model_path.with_suffix('')
if model_category == 'ocr':
model = self._load_ocr_model(variant=model_variant)
elif model_category == 'num_to_char':
model = self._load_num_to_char()
elif model_category == 'characters':
model = self._load_characters()
elif model_category == 'trocr_processor':
from transformers import TrOCRProcessor
model = TrOCRProcessor.from_pretrained(model_path)
else:
try:
model = load_model(model_path, compile=False)
except Exception as e:
self.logger.exception(e)
model = load_model(
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
)
self._loaded[model_category] = model
return model # type: ignore
def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T:
if model_category not in self._loaded:
raise ValueError(f'Model "{model_category}" not previously loaded with load_model(...)')
ret = self._loaded[model_category]
if model_type:
assert isinstance(ret, model_type)
return ret # type: ignore # FIXME: convince typing that we're returning generic type
def _load_ocr_model(self, variant: str) -> AnyModel:
"""
Load OCR model
"""
ocr_model_dir = self.model_path('ocr', variant)
if variant == 'tr':
from transformers import VisionEncoderDecoderModel
ret = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
assert isinstance(ret, VisionEncoderDecoderModel)
return ret
else:
ocr_model = load_model(ocr_model_dir, compile=False)
assert isinstance(ocr_model, KerasModel)
return KerasModel(
ocr_model.get_layer(name="image").input, # type: ignore
ocr_model.get_layer(name="dense2").output, # type: ignore
)
def _load_characters(self) -> List[str]:
"""
Load encoding for OCR
"""
with open(self.model_path('num_to_char'), "r") as config_file:
return json.load(config_file)
def _load_num_to_char(self) -> StringLookup:
"""
Load decoder for OCR
"""
characters = self._load_characters()
# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=characters, mask_token=None)
# Mapping integers back to original characters.
return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)
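# Editor's note: with invert=True this lookup maps label indices back to
# characters, so a decoded index sequence can be joined into a string, e.g.
#   tf.strings.reduce_join(num_to_char(indices))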
def __str__(self):
return tabulate(
[
[
spec.type,
spec.category,
spec.variant,
spec.help,
f'Yes, at {self.model_path(spec.category, spec.variant)}'
if self.model_path(spec.category, spec.variant).exists()
else f'No, download {spec.dist_url}',
# self.model_path(spec.category, spec.variant),
]
for spec in sorted(self.specs.specs, key=lambda x: x.dist_url)
],
headers=[
'Type',
'Category',
'Variant',
'Help',
'Installed',
],
tablefmt='github',
)
def shutdown(self):
"""
Ensure that loaded models are no longer referenced by ``self._loaded``
"""
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
for needle in list(self._loaded.keys()):
del self._loaded[needle]
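# Editor's usage sketch (hypothetical model directory):
#   zoo = EynollahModelZoo(basedir='/path/to/models_eynollah')
#   zoo.load_models('col_classifier', ('binarization', 'hybrid'), ('ocr', ''))
#   ocr_model = zoo.get('ocr', KerasModel)
#   zoo.shutdown()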
View file
@@ -0,0 +1,52 @@
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple
@dataclass
class EynollahModelSpec():
"""
Describing a single model abstractly.
"""
category: str
# Relative filename to the models_eynollah directory in the dists
filename: str
# URL to the smallest model distribution containing this model (link to Zenodo)
dist_url: str
type: str
variant: str = ''
help: str = ''
class EynollahModelSpecSet():
"""
List of all used models for eynollah.
"""
specs: List[EynollahModelSpec]
def __init__(self, specs: List[EynollahModelSpec]) -> None:
self.specs = sorted(specs, key=lambda x: x.category + '0' + x.variant)
self.categories: Set[str] = set([spec.category for spec in self.specs])
self.variants: Dict[str, Set[str]] = {
spec.category: set([x.variant for x in self.specs if x.category == spec.category])
for spec in self.specs
}
self._index_category_variant: Dict[Tuple[str, str], EynollahModelSpec] = {
(spec.category, spec.variant): spec
for spec in self.specs
}
def asdict(self) -> Dict[str, Dict[str, str]]:
    # group every variant under its category (a dict comprehension keyed
    # on category would keep only the last variant of each category)
    ret: Dict[str, Dict[str, str]] = {}
    for spec in self.specs:
        ret.setdefault(spec.category, {})[spec.variant] = spec.filename
    return ret
def get(self, category: str, variant: str) -> EynollahModelSpec:
if category not in self.categories:
raise ValueError(f"Unknown category '{category}', must be one of {self.categories}")
if variant not in self.variants[category]:
raise ValueError(f"Unknown variant {variant} for {category}. Known variants: {self.variants[category]}")
return self._index_category_variant[(category, variant)]
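# Editor's illustration: lookups are exact on (category, variant), e.g.
#   specs.get('binarization', 'hybrid').filename   # -> configured path
#   specs.get('binarization', 'unknown')           # -> ValueError listing known variants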
View file
@@ -0,0 +1,7 @@
from typing import TypeVar
# NOTE: Creating an actual union type requires loading transformers which is expensive and error-prone
# from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# AnyModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, KerasModel, List]
AnyModel = object
T = TypeVar('T')
View file
@@ -1,5 +1,5 @@
{
"version": "0.6.0",
"version": "0.7.0",
"git_url": "https://github.com/qurator-spk/eynollah",
"dockerhub": "ocrd/eynollah",
"tools": {
@@ -29,16 +19,6 @@
"type": "boolean",
"default": true,
"description": "Try to detect all element subtypes, including drop-caps and headings"
},
"light_version": {
"type": "boolean",
"default": true,
"description": "Try to detect all element subtypes in light version (faster+simpler method for main region detection and deskewing)"
},
"textline_light": {
"type": "boolean",
"default": true,
"description": "Light version need textline light. If this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline with a faster method."
},
"tables": {
"type": "boolean",
@@ -83,12 +73,20 @@
},
"resources": [
{
"url": "https://zenodo.org/records/17194824/files/models_layout_v0_5_0.tar.gz?download=1",
"name": "models_layout_v0_5_0",
"url": "https://zenodo.org/records/17580627/files/models_all_v0_7_0.zip?download=1",
"name": "models_layout_v0_7_0",
"type": "archive",
"path_in_archive": "models_layout_v0_5_0",
"size": 6119874002,
"description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement and OCR",
"version_range": ">= v0.7.0"
},
{
"url": "https://zenodo.org/records/17295988/files/models_layout_v0_6_0.tar.gz?download=1",
"name": "models_layout_v0_6_0",
"type": "archive",
"path_in_archive": "models_layout_v0_6_0",
"size": 3525684179,
"description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement",
"description": "Models for layout detection, reading order detection, textline detection, page extraction, column classification, table detection, binarization, image enhancement and OCR",
"version_range": ">= v0.5.0"
},
{
View file
@@ -1,3 +1,6 @@
# NOTE: For predictable order of imports of torch/shapely/tensorflow
# this must be the first import of the CLI!
from .eynollah_imports import imported_libs
from .processor import EynollahProcessor
from click import command
from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor
View file
@@ -1,6 +1,8 @@
from functools import cached_property
from typing import Optional
from PIL import Image
from frozendict import frozendict
import numpy as np
import cv2
from click import command
@@ -9,6 +11,8 @@ from ocrd import Processor, OcrdPageResult, OcrdPageResultImage
from ocrd_models.ocrd_page import OcrdPage, AlternativeImageType
from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor
from eynollah.model_zoo.model_zoo import EynollahModelZoo
from .sbb_binarize import SbbBinarizer
@@ -25,7 +29,7 @@ class SbbBinarizeProcessor(Processor):
# already employs GPU (without singleton process atm)
max_workers = 1
@property
@cached_property
def executable(self):
return 'ocrd-sbb-binarize'
@@ -34,8 +38,9 @@
Set up the model prior to processing.
"""
# resolve relative path via OCR-D ResourceManager
model_path = self.resolve_resource(self.parameter['model'])
self.binarizer = SbbBinarizer(model_dir=model_path, logger=self.logger)
assert isinstance(self.parameter, frozendict)
model_zoo = EynollahModelZoo(basedir=self.parameter['model'])
self.binarizer = SbbBinarizer(model_zoo=model_zoo, logger=self.logger)
def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional[str] = None) -> OcrdPageResult:
"""
@@ -98,7 +103,7 @@
line_image_bin = cv2pil(self.binarizer.run_single(image=pil2cv(line_image), use_patches=True))
# update PAGE (reference the image file):
line_image_ref = AlternativeImageType(comments=line_xywh['features'] + ',binarized')
line.add_AlternativeImage(region_image_ref)
line.add_AlternativeImage(line_image_ref)
result.images.append(OcrdPageResultImage(line_image_bin, line.id + '.IMG-BIN', line_image_ref))
return result
View file
@@ -0,0 +1,54 @@
import os
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
from tensorflow.keras import layers
projection_dim = 64
patch_size = 1
num_patches = 21*21  #14*14#28*28
class PatchEncoder(layers.Layer):
def __init__(self):
super().__init__()
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
def call(self, patch):
positions = tf.range(start=0, limit=num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
class Patches(layers.Layer):
def __init__(self, **kwargs):
    # pass standard layer kwargs (e.g. name) through to the base class
    super().__init__(**kwargs)
    self.patch_size = patch_size
def call(self, images):
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
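# Editor's shape sketch (matching the fixed constants above):
#   x = tf.zeros((1, 21, 21, 3))
#   p = Patches()(x)         # -> (1, 441, 3): 21*21 patches of size 1
#   e = PatchEncoder()(p)    # -> (1, 441, 64): projected + position-embedded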
View file
@@ -40,8 +40,8 @@ class EynollahPlotter:
self.image_filename_stem = image_filename_stem
# XXX TODO hacky these cannot be set at init time
self.image_org = image_org
self.scale_x = scale_x
self.scale_y = scale_y
self.scale_x : float = scale_x
self.scale_y : float = scale_y
def save_plot_of_layout_main(self, text_regions_p, image_page):
if self.dir_of_layout is not None:
View file
@@ -3,6 +3,8 @@ from typing import Optional
from ocrd_models import OcrdPage
from ocrd import OcrdPageResultImage, Processor, OcrdPageResult
from eynollah.model_zoo.model_zoo import EynollahModelZoo
from .eynollah import Eynollah, EynollahXmlWriter
class EynollahProcessor(Processor):
@@ -16,24 +18,20 @@ class EynollahProcessor(Processor):
def setup(self) -> None:
assert self.parameter
if self.parameter['textline_light'] != self.parameter['light_version']:
raise ValueError("Error: You must set or unset both parameter 'textline_light' (to enable light textline detection), "
"and parameter 'light_version' (faster+simpler method for main region detection and deskewing)")
model_zoo = EynollahModelZoo(basedir=self.parameter['models'])
self.eynollah = Eynollah(
self.resolve_resource(self.parameter['models']),
model_zoo=model_zoo,
allow_enhancement=self.parameter['allow_enhancement'],
curved_line=self.parameter['curved_line'],
right2left=self.parameter['right_to_left'],
reading_order_machine_based=self.parameter['reading_order_machine_based'],
ignore_page_extraction=self.parameter['ignore_page_extraction'],
light_version=self.parameter['light_version'],
textline_light=self.parameter['textline_light'],
full_layout=self.parameter['full_layout'],
allow_scaling=self.parameter['allow_scaling'],
headers_off=self.parameter['headers_off'],
tables=self.parameter['tables'],
logger=self.logger
)
self.eynollah.logger = self.logger
self.eynollah.plotter = None
def shutdown(self):
@@ -90,7 +88,6 @@ class EynollahProcessor(Processor):
dir_out=None,
image_filename=image_filename,
curved_line=self.eynollah.curved_line,
textline_light=self.eynollah.textline_light,
pcgts=pcgts)
self.eynollah.run_single()
return result
View file
@@ -2,20 +2,25 @@
Tool to load model and binarize a given image.
"""
from glob import glob
# pyright: reportIndexIssue=false
# pyright: reportCallIssue=false
# pyright: reportArgumentType=false
# pyright: reportPossiblyUnboundVariable=false
import os
import logging
from PIL import Image
from pathlib import Path
from typing import Optional
import numpy as np
import cv2
from ocrd_utils import tf_disable_interactive_logs
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
import tensorflow as tf
from tensorflow.keras.models import load_model
from .model_zoo import EynollahModelZoo
from .utils import is_image_filename
def resize_image(img_in, input_height, input_width):
@@ -23,30 +28,24 @@ def resize_image(img_in, input_height, input_width):
class SbbBinarizer:
def __init__(self, model_dir, logger=None):
self.model_dir = model_dir
self.logger = logger if logger else logging.getLogger('SbbBinarizer')
def __init__(
self,
*,
model_zoo: EynollahModelZoo,
logger: Optional[logging.Logger] = None,
):
self.logger = logger if logger else logging.getLogger('eynollah.binarization')
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except:
self.logger.warning("no GPU device available")
self.models = (model_zoo.model_path('binarization'), model_zoo.load_model('binarization'))
self.logger.info('Loaded model %s [%s]', self.models[1], self.models[0])
self.model_files = glob(self.model_dir + "/*/", recursive=True)
self.models = []
for model_file in self.model_files:
self.models.append(self.load_model(model_file))
def load_model(self, model_name):
model = load_model(os.path.join(self.model_dir, model_name), compile=False)
def predict(self, model, img, use_patches, n_batch_inference=5):
model_height = model.layers[len(model.layers)-1].output_shape[1]
model_width = model.layers[len(model.layers)-1].output_shape[2]
n_classes = model.layers[len(model.layers)-1].output_shape[3]
return model, model_height, model_width, n_classes
def predict(self, model_in, img, use_patches, n_batch_inference=5):
model, model_height, model_width, n_classes = model_in
img_org_h = img.shape[0]
img_org_w = img.shape[1]
@@ -305,44 +304,57 @@ class SbbBinarizer:
prediction_true = prediction_true.astype(np.uint8)
return prediction_true[:,:,0]
def run(self, image_path=None, output=None, dir_in=None, use_patches=False, overwrite=False):
if dir_in:
ls_imgs = [(os.path.join(dir_in, image_filename),
os.path.join(output, os.path.splitext(image_filename)[0] + '.png'))
for image_filename in filter(is_image_filename,
os.listdir(dir_in))]
def run(self, image=None, image_path=None, output=None, use_patches=False, dir_in=None, overwrite=False):
if not dir_in:
if (image is None) == (image_path is None):
raise ValueError("Must pass either an OpenCV image or an image_path")
if image_path is not None:
image = cv2.imread(image_path)
img_last = self.run_single(image, use_patches)
if output:
if os.path.exists(output):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", output)
else:
self.logger.warning("output file already exists '%s'", output)
return img_last
self.logger.info('Writing binarized image to %s', output)
cv2.imwrite(output, img_last)
return img_last
else:
ls_imgs = [(image_path, output)]
for input_path, output_path in ls_imgs:
print(input_path, 'image_name')
if os.path.exists(output_path):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", output_path)
else:
self.logger.warning("will skip input for existing output file '%s'", output_path)
image = cv2.imread(input_path)
result = self.run_single(image, use_patches)
cv2.imwrite(output_path, result)
ls_imgs = list(filter(is_image_filename, os.listdir(dir_in)))
self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in)
for i, image_path in enumerate(ls_imgs):
image_stem = os.path.splitext(image_path)[0]
output_path = os.path.join(output, image_stem + '.png')
if os.path.exists(output_path):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", output_path)
else:
self.logger.warning("will skip input for existing output file '%s'", output_path)
continue
self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_path)
image = cv2.imread(os.path.join(dir_in, image_path))
img_last = self.run_single(image, use_patches)
self.logger.info('Writing binarized image to %s', output_path)
cv2.imwrite(output_path, img_last)
def run_single(self, image: np.ndarray, use_patches=False):
img_last = 0
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
self.logger.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
model_file, model = self.models
res = self.predict(model, image, use_patches)
res = self.predict(model, image, use_patches)
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2
res = res - 1
res = res * 255
img_fin[:, :, 0] = res
img_fin[:, :, 1] = res
img_fin[:, :, 2] = res
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2
res = res - 1
res = res * 255
img_fin[:, :, 0] = res
img_fin[:, :, 1] = res
img_fin[:, :, 2] = res
img_fin = img_fin.astype(np.uint8)
img_fin = (res[:, :] == 0) * 255
img_last = img_last + img_fin
img_fin = img_fin.astype(np.uint8)
img_fin = (res[:, :] == 0) * 255
img_last = img_last + img_fin
kernel = np.ones((5, 5), np.uint8)
img_last[:, :][img_last[:, :] > 0] = 255
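For orientation, a minimal usage sketch of the reworked binarizer API above; the models basedir and file paths are illustrative assumptions, not part of this diff:

    import cv2
    from eynollah.model_zoo import EynollahModelZoo
    from eynollah.sbb_binarize import SbbBinarizer

    # the binarizer now takes a model zoo instead of a raw model directory
    model_zoo = EynollahModelZoo(basedir='models_eynollah')
    binarizer = SbbBinarizer(model_zoo=model_zoo)

    # single in-memory image, returned as an array:
    bin_img = binarizer.run(image=cv2.imread('page.tif'), use_patches=True)

    # or a whole input directory, skipping existing outputs unless overwrite=True:
    binarizer.run(dir_in='pages/', output='bin/', use_patches=True)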


@@ -8,6 +8,8 @@ from .build_model_load_pretrained_weights_and_save import build_model_load_pretr
from .generate_gt_for_training import main as generate_gt_cli
from .inference import main as inference_cli
from .train import ex
from .extract_line_gt import linegt_cli
from .weights_ensembling import main as ensemble_cli
@click.command(context_settings=dict(
ignore_unknown_options=True,
@@ -24,3 +26,5 @@ main.add_command(build_model_load_pretrained_weights_and_save)
main.add_command(generate_gt_cli, 'generate-gt')
main.add_command(inference_cli, 'inference')
main.add_command(train_cli, 'train')
main.add_command(linegt_cli, 'export_textline_images_and_text')
main.add_command(ensemble_cli, 'ensembling')


@@ -0,0 +1,134 @@
from logging import Logger, getLogger
from typing import Optional
from pathlib import Path
import os
import click
import cv2
import xml.etree.ElementTree as ET
import numpy as np
from ..utils import is_image_filename
@click.command()
@click.option(
"--image",
"-i",
help="input image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--dir_in",
"-di",
help="directory of input images (instead of --image)",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--dir_xmls",
"-dx",
help="directory of input PAGE-XML files (in addition to --dir_in; filename stems must match the image files, with '.xml' suffix).",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--out",
"-o",
'dir_out',
help="directory for output PAGE-XML files",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--dataset_abbrevation",
"-ds_pref",
'pref_of_dataset',
help="in the case of extracting textline and text from a xml GT file user can add an abbrevation of dataset name to generated dataset",
)
@click.option(
"--do_not_mask_with_textline_contour",
"-nmtc/-mtc",
is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
)
def linegt_cli(
image,
dir_in,
dir_xmls,
dir_out,
pref_of_dataset,
do_not_mask_with_textline_contour,
):
assert bool(dir_in) ^ bool(image), "Set --dir_in or --image, not both"
if dir_in:
ls_imgs = [
os.path.join(dir_in, image) for image in filter(is_image_filename, os.listdir(dir_in))
]
else:
assert image
ls_imgs = [image]
for dir_img in ls_imgs:
file_name = Path(dir_img).stem
dir_xml = os.path.join(dir_xmls, file_name + '.xml')
img = cv2.imread(dir_img)
total_bb_coordinates = []
tree1 = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8"))
root1 = tree1.getroot()
alltags = [elem.tag for elem in root1.iter()]
name_space = alltags[0].split('}')[0]
name_space = name_space.split('{')[1]
region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')])
cropped_lines_region_indexer = []
indexer_text_region = 0
indexer_textlines = 0
# FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether
for nn in root1.iter(region_tags):
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
for child_textlines in child_textregion:
if child_textlines.tag.endswith("Coords"):
cropped_lines_region_indexer.append(indexer_text_region)
p_h = child_textlines.attrib['points'].split(' ')
textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])
x, y, w, h = cv2.boundingRect(textline_coords)
total_bb_coordinates.append([x, y, w, h])
img_poly_on_img = np.copy(img)
mask_poly = np.zeros(img.shape)
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
mask_poly = mask_poly[y : y + h, x : x + w, :]
img_crop = img_poly_on_img[y : y + h, x : x + w, :]
if not do_not_mask_with_textline_contour:
img_crop[mask_poly == 0] = 255
if img_crop.shape[0] == 0 or img_crop.shape[1] == 0:
continue
if child_textlines.tag.endswith("TextEquiv"):
for child_text in child_textlines:
if child_text.tag.endswith("Unicode"):
textline_text = child_text.text
if textline_text:
base_name = os.path.join(
dir_out, file_name + '_line_' + str(indexer_textlines)
)
if pref_of_dataset:
base_name += '_' + pref_of_dataset
if not do_not_mask_with_textline_contour:
base_name += '_masked'
with open(base_name + '.txt', 'w') as text_file:
text_file.write(textline_text)
cv2.imwrite(base_name + '.png', img_crop)
indexer_textlines += 1
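Since this command is registered as export_textline_images_and_text in the training CLI (see above), a hedged invocation sketch via click's test runner; the module path and directory names are assumptions for illustration:

    from click.testing import CliRunner
    from eynollah.training.extract_line_gt import linegt_cli

    runner = CliRunner()
    result = runner.invoke(linegt_cli, [
        '--dir_in', 'images/',
        '--dir_xmls', 'page/',
        '--out', 'lines/',
        '--dataset_abbrevation', 'gt4hist',
    ])
    print(result.exit_code, result.output)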


@@ -480,7 +480,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, img)


@@ -18,7 +18,7 @@ with warnings.catch_warnings():
warnings.simplefilter("ignore")
def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, img):
def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, img):
alpha = 0.5
blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255
@@ -31,6 +31,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
col_sep = (255, 0, 0)
col_marginal = (106, 90, 205)
col_table = (0, 90, 205)
col_map = (90, 90, 205)
if len(co_image)>0:
cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour
@@ -55,6 +56,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
if len(co_table)>0:
cv2.drawContours(blank_image, co_table, -1, col_table, thickness=cv2.FILLED) # Fill the contour
if len(co_map)>0:
cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour
img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB)
@@ -234,7 +238,12 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y
con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size )
try:
co_text_eroded.append(con_eroded[0])
if len(con_eroded)>1:
cnt_size = np.array([cv2.contourArea(con_eroded[j]) for j in range(len(con_eroded))])
cnt = con_eroded[np.argmax(cnt_size)]
co_text_eroded.append(cnt)
else:
co_text_eroded.append(con_eroded[0])
except:
co_text_eroded.append(con)
@@ -255,6 +264,7 @@ def get_textline_contours_for_visualization(xml_file):
x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
@@ -296,6 +306,7 @@ def get_textline_contours_and_ocr_text(xml_file):
x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
@@ -365,7 +376,7 @@ def get_layout_contours_for_visualization(xml_file):
link=alltags[0].split('}')[0]+'}'
x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
@@ -378,6 +389,7 @@ def get_layout_contours_for_visualization(xml_file):
co_sep=[]
co_img=[]
co_table=[]
co_map=[]
co_noise=[]
types_text = []
@@ -594,6 +606,31 @@ def get_layout_contours_for_visualization(xml_file):
elif vv.tag!=link+'Point' and sumi>=1:
break
co_table.append(np.array(c_t_in))
if tag.endswith('}MapRegion') or tag.endswith('}mapregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
for vv in nn.iter():
# check the format of coords
if vv.tag==link+'Coords':
coords=bool(vv.attrib)
if coords:
p_h=vv.attrib['points'].split(' ')
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
break
else:
pass
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_map.append(np.array(c_t_in))
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
@@ -620,7 +657,7 @@ def get_layout_contours_for_visualization(xml_file):
elif vv.tag!=link+'Point' and sumi>=1:
break
co_noise.append(np.array(c_t_in))
return co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len
return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len
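The @points parsing just added for MapRegion repeats the pattern used for the other region types; for reference, a sketch of an equivalent helper (hypothetical, not part of this commit; the multi-<Point> fallback branch is not covered):

    import numpy as np

    def parse_coords_points(points_attr):
        # '10,20 30,40' -> array([[10, 20], [30, 40]])
        return np.array([[int(p.split(',')[0]), int(p.split(',')[1])]
                         for p in points_attr.split(' ')])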
def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images):
"""
@@ -643,24 +680,21 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
link=alltags[0].split('}')[0]+'}'
x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
if 'columns_width' in list(config_params.keys()):
columns_width_dict = config_params['columns_width']
# FIXME: look in /Page/@custom as well
metadata_element = root1.find(link+'Metadata')
comment_is_sub_element = False
num_col = None
for child in metadata_element:
tag2 = child.tag
if tag2.endswith('}Comments') or tag2.endswith('}comments'):
text_comments = child.text
num_col = int(text_comments.split('num_col')[1])
comment_is_sub_element = True
if not comment_is_sub_element:
# FIXME: look in /Page/@custom as well
num_col = None
if num_col:
x_new = columns_width_dict[str(num_col)]
@@ -812,7 +846,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
types_graphic_label = list(types_graphic_dict.values())
labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0)]
labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255)]
region_tags=np.unique([x for x in alltags if x.endswith('Region')])
@@ -823,6 +857,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
co_sep=[]
co_img=[]
co_table=[]
co_map=[]
co_noise=[]
for tag in region_tags:
@@ -1033,6 +1068,32 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
elif vv.tag!=link+'Point' and sumi>=1:
break
co_table.append(np.array(c_t_in))
if 'mapregion' in keys:
if tag.endswith('}MapRegion') or tag.endswith('}mapregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
for vv in nn.iter():
# check the format of coords
if vv.tag==link+'Coords':
coords=bool(vv.attrib)
if coords:
p_h=vv.attrib['points'].split(' ')
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
break
else:
pass
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_map.append(np.array(c_t_in))
if 'noiseregion' in keys:
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
@@ -1106,6 +1167,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
erosion_rate = 0#2
dilation_rate = 3#4
co_table, img_boundary = update_region_contours(co_table, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
if "mapregion" in elements_with_artificial_class:
erosion_rate = 0#2
dilation_rate = 3#4
co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
@@ -1131,6 +1196,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']])
if 'tableregion' in keys:
img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']])
if 'mapregion' in keys:
img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']])
if 'noiseregion' in keys:
img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']])
@@ -1192,6 +1259,9 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if 'tableregion' in keys:
color_label = config_params['tableregion']
img_poly=cv2.fillPoly(img, pts =co_table, color=(color_label,color_label,color_label))
if 'mapregion' in keys:
color_label = config_params['mapregion']
img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label))
if 'noiseregion' in keys:
color_label = config_params['noiseregion']
img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label))
@@ -1690,15 +1760,15 @@ def read_xml(xml_file):
index_tot_regions,
img_poly)
def bounding_box(cnt,color, corr_order_index ):
x, y, w, h = cv2.boundingRect(cnt)
x = int(x*scale_w)
y = int(y*scale_h)
w = int(w*scale_w)
h = int(h*scale_h)
return [x,y,w,h,int(color), int(corr_order_index)+1]
# def bounding_box(cnt,color, corr_order_index ):
# x, y, w, h = cv2.boundingRect(cnt)
# x = int(x*scale_w)
# y = int(y*scale_h)
#
# w = int(w*scale_w)
# h = int(h*scale_h)
#
# return [x,y,w,h,int(color), int(corr_order_index)+1]
def resize_image(seg_in,input_height,input_width):
return cv2.resize(seg_in,(input_width,input_height),interpolation=cv2.INTER_NEAREST)


@@ -4,17 +4,19 @@ Tool to load model and predict for given image.
import sys
import os
from typing import Tuple
import warnings
import json
import click
import numpy as np
from numpy.typing import NDArray
import cv2
import xml.etree.ElementTree as ET
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
from tensorflow.keras.models import load_model
import xml.etree.ElementTree as ET
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import StringLookup
from .gt_gen_utils import (
filter_contours_area_of_image,
@@ -32,6 +34,9 @@ from .metrics import (
weighted_categorical_crossentropy,
)
from .utils import scale_padd_image_for_ocr
from eynollah.utils.utils_ocr import (decode_batch_predictions)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
@@ -47,9 +52,10 @@ class SBBPredict:
save_layout,
ground_truth,
xml_file,
cpu,
out,
min_area):
min_area,
):
self.image=image
self.dir_in=dir_in
self.patches=patches
@@ -61,6 +67,7 @@ class SBBPredict:
self.config_params_model=config_params_model
self.xml_file = xml_file
self.out = out
self.cpu = cpu
if min_area:
self.min_area = float(min_area)
else:
@@ -111,30 +118,35 @@ class SBBPredict:
return mIoU
def start_new_session_and_model(self):
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except:
print("no GPU device available", file=sys.stderr)
if self.cpu:
tf.config.set_visible_devices([], 'GPU')
else:
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except:
print("no GPU device available", file=sys.stderr)
#tensorflow.keras.layers.custom_layer = PatchEncoder
#tensorflow.keras.layers.custom_layer = Patches
self.model = load_model(self.model_dir, compile=False,
custom_objects={"PatchEncoder": PatchEncoder,
"Patches": Patches})
#keras.losses.custom_loss = weighted_categorical_crossentropy
#self.model = load_model(self.model_dir, compile=False)
if self.task == "cnn-rnn-ocr":
self.model = Model(
self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output)
else:
self.model = load_model(self.model_dir, compile=False,
custom_objects={"PatchEncoder": PatchEncoder,
"Patches": Patches})
##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir)
assert isinstance(self.model, Model)
if self.task != 'classification' and self.task != 'reading_order':
last = self.model.layers[-1]
self.img_height = last.output_shape[1]
self.img_width = last.output_shape[2]
self.n_classes = last.output_shape[3]
def visualize_model_output(self, prediction, img, task):
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization":
prediction = prediction * -1
prediction = prediction + 1
@@ -173,9 +185,12 @@ class SBBPredict:
added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)
assert isinstance(added_image, np.ndarray)
assert isinstance(layout_only, np.ndarray)
return added_image, layout_only
def predict(self, image_dir):
assert isinstance(self.model, Model)
if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name']
img_1ch = cv2.imread(image_dir, 0) / 255.0
@@ -187,11 +202,35 @@ class SBBPredict:
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model.predict(img_in, verbose=0)
label_p_pred = self.model.predict(img_in, verbose='0')
index_class = np.argmax(label_p_pred[0])
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
elif self.task == "cnn-rnn-ocr":
img=cv2.imread(image_dir)
img = scale_padd_image_for_ocr(img, self.config_params_model['input_height'], self.config_params_model['input_width'])
img = img / 255.
with open(os.path.join(self.model_dir, "characters_org.txt"), 'r') as char_txt_f:
characters = json.load(char_txt_f)
AUTOTUNE = tf.data.AUTOTUNE
# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
# Mapping integers back to original characters.
num_to_char = StringLookup(
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
preds = self.model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]), verbose=0)
pred_texts = decode_batch_predictions(preds, num_to_char)
pred_texts = pred_texts[0].replace("[UNK]", "")
return pred_texts
elif self.task == 'reading_order':
img_height = self.config_params_model['input_height']
img_width = self.config_params_model['input_width']
@@ -311,7 +350,7 @@ class SBBPredict:
#input_1[:,:,1] = img3[:,:,0]/5.
if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
y_pr = self.model.predict(input_1 , verbose=0)
y_pr = self.model.predict(input_1 , verbose='0')
scalibility_num = scalibility_num+1
if batch_counter==inference_bs:
@@ -345,6 +384,7 @@ class SBBPredict:
name_space = name_space.split('{')[1]
page_element = root_xml.find(link+'Page')
assert isinstance(page_element, ET.Element)
"""
ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
@@ -439,7 +479,7 @@ class SBBPredict:
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
verbose=0)
verbose='0')
if self.task == 'enhancement':
seg = label_p_pred[0, :, :, :]
@@ -447,6 +487,8 @@ class SBBPredict:
elif self.task == 'segmentation' or self.task == 'binarization':
seg = np.argmax(label_p_pred, axis=3)[0]
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
else:
raise ValueError(f"Unhandled task {self.task}")
if i == 0 and j == 0:
@@ -501,6 +543,8 @@ class SBBPredict:
elif self.task == 'segmentation' or self.task == 'binarization':
seg = np.argmax(label_p_pred, axis=3)[0]
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
else:
raise ValueError(f"Unhandled task {self.task}")
prediction_true = seg.astype(int)
@@ -519,6 +563,8 @@ class SBBPredict:
elif self.task == 'enhancement':
if self.save:
cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr":
print(f"Detected text: {res}")
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save:
@@ -526,9 +572,9 @@ class SBBPredict:
if self.save_layout:
cv2.imwrite(self.save_layout, only_layout)
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
else:
ls_images = os.listdir(self.dir_in)
@@ -542,6 +588,8 @@ class SBBPredict:
elif self.task == 'enhancement':
self.save = os.path.join(self.out, f_name+'.png')
cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr":
print(f"Detected text for file name {f_name} is: {res}")
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
self.save = os.path.join(self.out, f_name+'_overlayed.png')
@@ -549,9 +597,9 @@ class SBBPredict:
self.save_layout = os.path.join(self.out, f_name+'_layout.png')
cv2.imwrite(self.save_layout, only_layout)
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
@@ -607,22 +655,27 @@ class SBBPredict:
"-xml",
help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
)
@click.option(
"--cpu",
"-cpu",
help="For OCR, the default device is the GPU. If this parameter is set to true, inference will be performed on the CPU",
is_flag=True,
)
@click.option(
"--min_area",
"-min",
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
)
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di input is required"
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
task = config_params_model['task']
if task != 'classification' and task != 'reading_order':
if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]:
assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s"
assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o"
x = SBBPredict(image, dir_in, model, task, config_params_model,
patches, save, save_layout, ground_truth, xml_file, out,
min_area)
patches, save, save_layout, ground_truth, xml_file,
cpu, out, min_area)
x.run()
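The decode_batch_predictions helper imported from eynollah.utils.utils_ocr above is not shown in this diff; a rough sketch of what such CTC greedy decoding typically looks like, modeled on the standard Keras OCR example rather than the actual eynollah implementation:

    import numpy as np
    import tensorflow as tf

    def decode_batch_predictions_sketch(preds, num_to_char, max_len=128):
        # one full-length input per batch item, then greedy CTC decoding
        input_len = np.ones(preds.shape[0]) * preds.shape[1]
        results = tf.keras.backend.ctc_decode(preds, input_length=input_len, greedy=True)[0][0]
        results = results[:, :max_len]
        # map label ids back to characters; blanks come out as "[UNK]",
        # which the caller strips (as in the cnn-rnn-ocr branch above)
        return [tf.strings.reduce_join(num_to_char(seq)).numpy().decode("utf-8")
                for seq in results]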


@@ -147,6 +147,7 @@ def generalized_dice_loss(y_true, y_pred):
return 1 - generalized_dice_coeff2(y_true, y_pred)
# TODO: document where this is from
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
"""
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
@@ -175,6 +176,7 @@ def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
# TODO: document where this is from
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False,
verbose=False):
"""
@@ -267,6 +269,8 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=T
return K.mean(non_zero_sum / non_zero_count)
# TODO: document where this is from
# TODO: Why a different implementation than IoU from utils?
def mean_iou(y_true, y_pred, **kwargs):
"""
Compute mean Intersection over Union of two segmentation masks, via Keras.
@@ -311,6 +315,7 @@ def iou_vahid(y_true, y_pred):
return K.mean(iou)
# TODO: copy from utils?
def IoU_metric(Yi, y_predi):
# mean Intersection over Union
# Mean IoU = TP/(FN + TP + FP)
@@ -337,6 +342,7 @@ def IoU_metric_keras(y_true, y_pred):
return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))
# TODO: unused, remove?
def jaccard_distance_loss(y_true, y_pred, smooth=100):
"""
Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)


@@ -2,12 +2,36 @@ import os
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import layers
from tensorflow.keras.layers import (
Activation,
Add,
AveragePooling2D,
BatchNormalization,
Bidirectional,
Conv1D,
Conv2D,
Dense,
Dropout,
Embedding,
Flatten,
Input,
Lambda,
Layer,
LayerNormalization,
LSTM,
MaxPool2D,  # used by cnn_rnn_ocr_model below
MaxPooling2D,
MultiHeadAttention,
Reshape,
UpSampling2D,
ZeroPadding2D,
add,
concatenate
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from eynollah.patch_encoder import Patches, PatchEncoder
##mlp_head_units = [512, 256]#[2048, 1024]
###projection_dim = 64
##transformer_layers = 2#8
@@ -19,96 +43,34 @@ RESNET50_WEIGHTS_URL = ('https://github.com/fchollet/deep-learning-models/releas
IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1
class CTCLayer(Layer):
def __init__(self, name=None):
super().__init__(name=name)
self.loss_fn = tf.keras.backend.ctc_batch_cost
def call(self, y_true, y_pred):
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
loss = self.loss_fn(y_true, y_pred, input_length, label_length)
self.add_loss(loss)
# At test time, just return the computed predictions.
return y_pred
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=tf.nn.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
x = Dense(units, activation=tf.nn.gelu)(x)
x = Dropout(dropout_rate)(x)
return x
class Patches(layers.Layer):
def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size_x = patch_size_x
self.patch_size_y = patch_size_y
def call(self, images):
#print(tf.shape(images)[1],'images')
#print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
strides=[1, self.patch_size_y, self.patch_size_x, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
#patch_dims = patches.shape[-1]
patch_dims = tf.shape(patches)[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size_x': self.patch_size_x,
'patch_size_y': self.patch_size_y,
})
return config
class Patches_old(layers.Layer):
def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
#print(tf.shape(images)[1],'images')
#print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
#print(patches.shape,patch_dims,'patch_dims')
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
def one_side_pad(x):
# rs: fixme: lambda layers are problematic for de/serialization!
# - can we use ZeroPadding1D instead of ZeroPadding2D+Lambda?
x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
if IMAGE_ORDERING == 'channels_first':
x = Lambda(lambda x: x[:, :, :-1, :-1])(x)
@@ -150,7 +112,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
x = layers.add([x, input_tensor])
x = add([x, input_tensor])
x = Activation('relu')(x)
return x
@@ -195,12 +157,12 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
name=conv_name_base + '1')(input_tensor)
shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
x = layers.add([x, shortcut])
x = add([x, shortcut])
x = Activation('relu')(x)
return x
def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segmentation", weight_decay=1e-6, pretraining=False):
def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
assert input_height % 32 == 0
assert input_width % 32 == 0
@@ -415,7 +377,7 @@ def vit_resnet50_unet(num_patches,
pretraining=False):
if transformer_mlp_head_units is None:
transformer_mlp_head_units = [128, 64]
inputs = layers.Input(shape=(input_height, input_width, 3))
inputs = Input(shape=(input_height, input_width, 3))
#transformer_units = [
#projection_dim * 2,
@@ -460,27 +422,35 @@ def vit_resnet50_unet(num_patches,
model = Model(inputs, x).load_weights(RESNET50_WEIGHTS_PATH)
#num_patches = x.shape[1]*x.shape[2]
patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(x)
# rs: fixme patch size not configurable anymore...
#patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
patches = Patches()(x)
assert transformer_patchsize_x == transformer_patchsize_y == 1
# Encode patches.
encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
# rs: fixme num patches and dim not configurable anymore...
#encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
encoded_patches = PatchEncoder()(patches)
assert num_patches == 21 * 21
assert transformer_projection_dim == 64
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = layers.MultiHeadAttention(
attention_output = MultiHeadAttention(
num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
x2 = Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
x3 = LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
encoded_patches = Add()([x3, x2])
assert isinstance(x, Layer)
encoded_patches = tf.reshape(encoded_patches,
[-1, x.shape[1], x.shape[2],
transformer_projection_dim // (transformer_patchsize_x *
@@ -551,7 +521,7 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
pretraining=False):
if transformer_mlp_head_units is None:
transformer_mlp_head_units = [128, 64]
inputs = layers.Input(shape=(input_height, input_width, 3))
inputs = Input(shape=(input_height, input_width, 3))
##transformer_units = [
##projection_dim * 2,
@@ -560,25 +530,32 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
IMAGE_ORDERING = 'channels_last'
bn_axis=3
patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
# rs: fixme patch size not configurable anymore...
#patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
patches = Patches()(inputs)
assert transformer_patchsize_x == transformer_patchsize_y == 1
# Encode patches.
encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
# rs: fixme num patches and dim not configurable anymore...
#encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
encoded_patches = PatchEncoder()(patches)
assert num_patches == 21 * 21
assert transformer_projection_dim == 64
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = layers.MultiHeadAttention(
attention_output = MultiHeadAttention(
num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
x2 = Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
x3 = LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
encoded_patches = Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches,
[-1,
@@ -734,9 +711,6 @@ def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=
x = Dense(n_classes, activation='softmax', name='fc1000')(x)
model = Model(img_input, x)
return model
def machine_based_reading_order_model(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
@@ -793,3 +767,81 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
model = Model(img_input , o)
return model
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None):
input_img = Input(shape=(image_height, image_width, 3), name="image")
labels = Input(name="label", shape=(None,))
x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img)
x = BatchNormalization(name="bn1")(x)
x = Activation("relu", name="relu1")(x)
x = Conv2D(64,kernel_size=(3,3),padding="same")(x)
x = BatchNormalization(name="bn2")(x)
x = Activation("relu", name="relu2")(x)
x = MaxPool2D(pool_size=(1,2),strides=(1,2))(x)
x = Conv2D(128,kernel_size=(3,3),padding="same")(x)
x = BatchNormalization(name="bn3")(x)
x = Activation("relu", name="relu3")(x)
x = Conv2D(128,kernel_size=(3,3),padding="same")(x)
x = BatchNormalization(name="bn4")(x)
x = Activation("relu", name="relu4")(x)
x = MaxPool2D(pool_size=(1,2),strides=(1,2))(x)
x = Conv2D(256,kernel_size=(3,3),padding="same")(x)
x = BatchNormalization(name="bn5")(x)
x = Activation("relu", name="relu5")(x)
x = Conv2D(256,kernel_size=(3,3),padding="same")(x)
x = BatchNormalization(name="bn6")(x)
x = Activation("relu", name="relu6")(x)
x = MaxPool2D(pool_size=(2,2),strides=(2,2))(x)
x = Conv2D(image_width,kernel_size=(3,3),padding="same")(x)
x = BatchNormalization(name="bn7")(x)
x = Activation("relu", name="relu7")(x)
x = Conv2D(image_width,kernel_size=(16,1))(x)
x = BatchNormalization(name="bn8")(x)
x = Activation("relu", name="relu8")(x)
x2d = MaxPool2D(pool_size=(1,2),strides=(1,2))(x)
x4d = MaxPool2D(pool_size=(1,2),strides=(1,2))(x2d)
new_shape = (x.shape[1]*x.shape[2], x.shape[3])
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
x = Reshape(target_shape=new_shape, name="reshape")(x)
x2d = Reshape(target_shape=new_shape2, name="reshape2")(x2d)
x4d = Reshape(target_shape=new_shape4, name="reshape4")(x4d)
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
xrnn2d = Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
xrnn4d = Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
xrnn2dup = Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
xrnn4dup = Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
out = BatchNormalization(name="bn9")(out)
out = Activation("relu", name="relu9")(out)
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
out = Dense(n_classes, activation="softmax", name="dense2")(out)
# Add CTC layer for calculating CTC loss at each step.
output = CTCLayer(name="ctc_loss")(labels, out)
model = Model(inputs=[input_img, labels], outputs=output, name="handwriting_recognizer")
return model
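How a model of this shape is driven end to end: the CTC loss is attached inside the graph by CTCLayer, so compile() only needs an optimizer, and inference later strips the label input (mirroring the "dense2" extraction in inference.py above). A hedged sketch with invented sizes:

    import tensorflow as tf
    from tensorflow.keras.models import Model

    model = cnn_rnn_ocr_model(image_height=32, image_width=512,
                              n_classes=100, max_seq=128)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-4))  # loss added by CTCLayer
    # ... fit on batches of (image, label) pairs ...
    # for inference, drop the label input and the CTC layer:
    infer_model = Model(model.get_layer("image").input,
                        model.get_layer("dense2").output)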


@@ -17,6 +17,7 @@ from eynollah.training.models import (
resnet50_unet,
vit_resnet50_unet,
vit_resnet50_unet_transformer_before_cnn,
cnn_rnn_ocr_model,
RESNET50_WEIGHTS_PATH,
RESNET50_WEIGHTS_URL
)
@@ -25,7 +26,6 @@ from eynollah.training.utils import (
generate_arrays_from_folder_reading_order,
get_one_hot,
preprocess_imgs,
return_number_of_total_training_data
)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -35,11 +35,10 @@ from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.metrics import MeanIoU, F1Score
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.layers import StringLookup
from tensorflow.keras.utils import image_dataset_from_directory
from sacred import Experiment
from sacred.config import create_captured_function
from tqdm import tqdm
from sklearn.metrics import f1_score
import numpy as np
import cv2
@@ -68,6 +67,7 @@ class SaveWeightsAfterSteps(ModelCheckpoint):
json.dump(self._config, fp) # encode dict into JSON
def configuration():
try:
for device in tf.config.list_physical_devices('GPU'):
@@ -111,6 +111,9 @@ def config_params():
n_classes = None # Number of classes. In the case of binary classification this should be 2.
n_epochs = 1 # Number of epochs to train.
n_batch = 1 # Number of images per batch at each iteration. (Try as large as fits on VRAM.)
if task == 'cnn-rnn-ocr':
max_len = None # Maximum sequence length (characters per line) for OCR output.
characters_txt_file = None # Path of the JSON file defining the character set needed for the OCR model.
input_height = 224 * 1 # Height of model's input in pixels.
input_width = 224 * 1 # Width of model's input in pixels.
weight_decay = 1e-6 # Weight decay of l2 regularization of model layers.
@@ -124,47 +127,74 @@ def config_params():
patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
augmentation = False # To apply any kind of augmentation, this parameter must be set to true.
if augmentation:
flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in config_params.json.
flip_aug = False # Whether different types of flipping will be applied to the image. Requires "flip_index" setting.
if flip_aug:
flip_index = None # Flip image for augmentation.
blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in config_params.json.
flip_index = None # List of codes (as in cv2.flip) for flip augmentation.
blur_aug = False # Whether images will be blurred. Requires "blur_k" setting.
if blur_aug:
blur_k = None # Blur image for augmentation.
blur_k = None # Method of blurring (gauss, median or blur).
padding_white = False # If true, white padding will be applied to the image.
if padding_white and task == 'cnn-rnn-ocr':
white_padds = None # List of padding sizes.
padd_colors = None # List of padding colors; allowed values are "white", "black", or both.
padding_black = False # If true, black padding will be applied to the image.
scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json.
scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image.
scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image.
scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
scaling = False # Whether images will be scaled up or down. Requires "scales" setting.
scaling_bluring = False # Whether a combination of scaling and blurring will be applied to the image.
scaling_binarization = False # Whether a combination of scaling and binarization will be applied to the image.
scaling_brightness = False # Whether a combination of scaling and brightening will be applied to the image.
scaling_flip = False # Whether a combination of scaling and flipping will be applied to the image.
if scaling or scaling_brightness or scaling_bluring or scaling_binarization or scaling_flip:
scales = None # Scale patches for augmentation.
shifting = False
degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json.
if degrading:
degrade_scales = None # Degrade image for augmentation.
brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json.
brightening = False # Whether images will be brightened. Requires "brightness" setting.
if brightening:
brightness = None # Brighten image for augmentation.
binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images.
brightness = None # List of intensity factors for brightening.
binarization = False # Whether binary images will be used, too. (Will use Otsu thresholding unless supplying precomputed images in "dir_img_bin".)
if binarization:
dir_img_bin = None # Directory of training dataset subdirectory of binarized images
add_red_textlines = False
adding_rgb_background = False
adding_rgb_background = False # Whether texture images will be added as artificial background.
if adding_rgb_background:
dir_rgb_backgrounds = None # Directory of texture images for synthetic background
adding_rgb_foreground = False
adding_rgb_foreground = False # Whether texture images will be added as artificial foreground.
if adding_rgb_foreground:
dir_rgb_foregrounds = None # Directory of texture images for synthetic foreground
if adding_rgb_background or adding_rgb_foreground:
number_of_backgrounds_per_image = 1
if task == 'cnn-rnn-ocr':
image_inversion = False # Whether the binarized images will be inverted.
textline_skewing_bin = False # Whether binarized textline images will be rotated.
textline_left_in_depth_bin = False # Whether left side of binary textline image will be displayed in depth.
textline_right_in_depth_bin = False # Whether right side of binary textline image will be displayed in depth.
textline_up_in_depth_bin = False # Whether upper side of binary textline image will be displayed in depth.
textline_down_in_depth_bin = False # Whether lower side of binary textline image will be displayed in depth.
pepper_bin_aug = False # Whether pepper noise will be added to binary textline images.
bin_deg = False # Whether a combination of degrading and binarization will be applied to the image.
degrading = False # Whether images will be artificially degraded. Requires the "degrade_scales" setting.
if degrading or (binarization and task == 'cnn-rnn-ocr' and bin_deg):
degrade_scales = None # List of quality factors for degradation.
channels_shuffling = False # Re-arrange color channels.
if channels_shuffling:
shuffle_indexes = None # Which channels to switch between.
rotation = False # If true, a 90 degree rotation will be implemeneted.
rotation_not_90 = False # If true rotation based on provided angles with thetha will be implemeneted.
shuffle_indexes = None # List of channels to switch between.
rotation = False # Whether images will be rotated by 90 degrees.
rotation_not_90 = False # Whether images will be rotated arbitrarily (skewed). Requires "thetha" setting.
if rotation_not_90:
thetha = None # Rotate image by these angles for augmentation.
thetha = None # List of rotation angles in degrees.
if task == 'cnn-rnn-ocr':
white_noise_strap = False # Whether white noise will be applied on some straps on the textline image.
textline_skewing = False # Whether textline images will be skewed for augmentation.
if textline_skewing or (binarization and textline_skewing_bin):
skewing_amplitudes = None # List of skewing angles in degrees like [5, 8]
textline_left_in_depth = False # If true, left side of textline image will be displayed in depth.
textline_right_in_depth = False # If true, right side of textline image will be displayed in depth.
textline_up_in_depth = False # If true, upper side of textline image will be displayed in depth.
textline_down_in_depth = False # If true, lower side of textline image will be displayed in depth.
pepper_aug = False # Whether pepper noise will be added to textline images.
if pepper_aug or (binarization and pepper_bin_aug):
pepper_indexes = None # List of pepper noise factors, e.g. [0.01, 0.005].
color_padding_rotation = False # Whether images will be rotated with color padding. Requires "thetha_padd" setting.
if color_padding_rotation:
thetha_padd = None # List of angles (in degrees) used for rotation alongside padding.
dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels".
dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels".
dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
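Pulling the new keys together, a hypothetical minimal config for the OCR task; the key names come from the config function above, but all values are purely illustrative:

    import json

    ocr_config = {
        "task": "cnn-rnn-ocr",
        "n_epochs": 10,
        "n_batch": 8,
        "input_height": 32,
        "input_width": 512,
        "max_len": 128,
        "characters_txt_file": "characters_org.txt",
        "augmentation": False,
        "dir_train": "train/",   # with "images" and "labels" subdirectories
        "dir_eval": "eval/",
        "dir_output": "out/",
    }
    with open("config_params.json", "w") as f:
        json.dump(ocr_config, f, indent=2)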
@@ -197,12 +227,15 @@ def run(_config,
augmentation,
# dependent config keys need a default,
# otherwise yields sacred.utils.ConfigAddedError
## if rotation_not_90
thetha=None,
is_loss_soft_dice=False,
weighted_loss=False,
## if continue_training
index_start=0,
dir_of_start_model=None,
backbone_type=None,
## if backbone_type=transformer
transformer_projection_dim=None,
transformer_mlp_head_units=None,
transformer_layers=None,
@@ -211,8 +244,33 @@ def run(_config,
transformer_patchsize_x=None,
transformer_patchsize_y=None,
transformer_num_patches_xy=None,
## if task=classification
f1_threshold_classification=None,
classification_classes_name=None,
## if task=cnn-rnn-ocr
characters_txt_file=None,
color_padding_rotation=False,
thetha_padd=None,
bin_deg=False,
image_inversion=False,
white_noise_strap=False,
textline_skewing=False,
textline_skewing_bin=False,
textline_left_in_depth=False,
textline_left_in_depth_bin=False,
textline_right_in_depth=False,
textline_right_in_depth_bin=False,
textline_up_in_depth=False,
textline_up_in_depth_bin=False,
textline_down_in_depth=False,
textline_down_in_depth_bin=False,
pepper_aug=False,
pepper_bin_aug=False,
pepper_indexes=None,
padd_colors=None,
white_padds=None,
skewing_amplitudes=None,
max_len=None,
):
if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
@@ -252,11 +310,11 @@ def run(_config,
dir_img, dir_seg = get_dirs_or_files(dir_train)
dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)
imgs_list=np.array(os.listdir(dir_img))
segs_list=np.array(os.listdir(dir_seg))
imgs_list = list(os.listdir(dir_img))
segs_list = list(os.listdir(dir_seg))
imgs_list_test=np.array(os.listdir(dir_img_val))
segs_list_test=np.array(os.listdir(dir_seg_val))
imgs_list_test = list(os.listdir(dir_img_val))
segs_list_test = list(os.listdir(dir_seg_val))
# writing patches into a sub-folder in order to be flowed from directory.
preprocess_imgs(_config,
@@ -356,6 +414,7 @@ def run(_config,
model_builder.logger = _log
model = model_builder(num_patches)
assert model is not None
#if you want to see the model structure just uncomment model summary.
#model.summary()
@@ -412,7 +471,80 @@ def run(_config,
#os.system('rm -rf '+dir_eval_flowing)
#model.save(dir_output+'/'+'model'+'.h5')
elif task=="cnn-rnn-ocr":
dir_img, dir_lab = get_dirs_or_files(dir_train)
dir_img_val, dir_lab_val = get_dirs_or_files(dir_eval)
imgs_list = list(os.listdir(dir_img))
labs_list = list(os.listdir(dir_lab))
imgs_list_val = list(os.listdir(dir_img_val))
labs_list_val = list(os.listdir(dir_lab_val))
with open(characters_txt_file, 'r') as char_txt_f:
characters = json.load(char_txt_f)
padding_token = len(characters) + 5
# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
# Mapping integers back to original characters.
##num_to_char = StringLookup(
##vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
##)
n_classes = len(char_to_num.get_vocabulary()) + 2
if continue_training:
model = load_model(dir_of_start_model)
else:
index_start = 0
model = cnn_rnn_ocr_model(image_height=input_height,
image_width=input_width,
n_classes=n_classes,
max_seq=max_len)
#print(model.summary())
# todo: use Dataset.map() on Dataset.list_files()
# todo: test_ds
def gen():
return preprocess_imgs(_config,
imgs_list,
labs_list,
dir_img,
dir_lab,
None, # no file I/O, but in-memory
None, # no file I/O, but in-memory
# extra+overrides
char_to_num=char_to_num,
padding_token=padding_token
)
# from_generator requires an output signature; the dtypes/shapes here are an
# assumption about what the extended preprocess_imgs generator yields:
train_ds = tf.data.Dataset.from_generator(
gen,
output_signature=(tf.TensorSpec(shape=(input_height, input_width, 3), dtype=tf.float32),
tf.TensorSpec(shape=(None,), dtype=tf.int64)))
train_ds = train_ds.padded_batch(n_batch,
padded_shapes=([input_height, input_width, 3], [None]),
# padding values must match the element dtypes:
padding_values=(tf.constant(0, dtype=tf.float32),
tf.constant(padding_token, dtype=tf.int64)),
drop_remainder=True,
)
# shuffle() needs an explicit buffer size (10*n_batch is an arbitrary choice);
# an unbounded repeat() would also require steps_per_epoch in fit() below.
train_ds = train_ds.shuffle(10 * n_batch).prefetch(20)
#initial_learning_rate = 1e-4
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
#alpha = 0.01
#lr_schedule = 1e-4
#tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha)
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=opt) # rs: loss seems to be (ctc_batch_cost) in last layer
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
SaveWeightsAfterSteps(0, dir_output, _config)]
if save_interval:
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
model.fit(
train_ds,
#validation_data=test_ds,
epochs=n_epochs,
callbacks=callbacks,
initial_epoch=index_start)
elif task=='classification':
if continue_training:
model = load_model(dir_of_start_model, compile=False)
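As a side note on the padded_batch call in the OCR branch above, its behavior on ragged label sequences in isolation (self-contained toy, values invented):

    import tensorflow as tf

    labels = tf.data.Dataset.from_generator(
        lambda: iter([[3, 1, 4], [1, 5], [9, 2, 6, 5]]),
        output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int64))
    for batch in labels.padded_batch(2, padded_shapes=[None],
                                     padding_values=tf.constant(0, tf.int64)):
        print(batch.numpy())  # each row padded with 0 to the longest row in its batch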

File diff suppressed because it is too large


@@ -0,0 +1,136 @@
import sys
from glob import glob
from os import environ, devnull
from os.path import join
from warnings import catch_warnings, simplefilter
import os
import numpy as np
from PIL import Image
import cv2
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
sys.stderr = open(devnull, 'w')
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend
sys.stderr = stderr
from tensorflow.keras import layers
import tensorflow.keras.losses
from tensorflow.keras.layers import *
import click
import logging
class Patches(layers.Layer):
def __init__(self, patch_size_x, patch_size_y):
super(Patches, self).__init__()
self.patch_size_x = patch_size_x
self.patch_size_y = patch_size_y
def call(self, images):
#print(tf.shape(images)[1],'images')
#print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
strides=[1, self.patch_size_y, self.patch_size_x, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
#patch_dims = patches.shape[-1]
patch_dims = tf.shape(patches)[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size_x': self.patch_size_x,
'patch_size_y': self.patch_size_y,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim, **kwargs):
super(PatchEncoder, self).__init__(**kwargs)
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
def start_new_session():
###config = tf.compat.v1.ConfigProto()
###config.gpu_options.allow_growth = True
###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
###tensorflow_backend.set_session(self.session)
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
tensorflow_backend.set_session(session)
return session
def run_ensembling(dir_models, out):
    # load every model in dir_models (all must share the same architecture)
    ls_models = os.listdir(dir_models)
    weights = []
    for model_name in ls_models:
        model = load_model(os.path.join(dir_models, model_name), compile=False,
                           custom_objects={'PatchEncoder': PatchEncoder, 'Patches': Patches})
        weights.append(model.get_weights())
    # zip(*weights) groups corresponding weight tensors across models;
    # average each group element-wise to obtain the ensembled tensors
    new_weights = list()
    for weights_list_tuple in zip(*weights):
        new_weights.append(
            [np.array(weights_).mean(axis=0)
             for weights_ in zip(*weights_list_tuple)])
    new_weights = [np.array(x) for x in new_weights]
    # reuse the last loaded model as a container for the averaged weights
    model.set_weights(new_weights)
    model.save(out)
    # copy the (identical) model config alongside the ensembled weights;
    # shutil.copy avoids the quoting pitfalls of shelling out to cp
    shutil.copy(os.path.join(dir_models, model_name, "config.json"), out)
@click.command()
@click.option(
"--dir_models",
"-dm",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--out",
"-o",
help="output directory where ensembled model will be written.",
type=click.Path(exists=False, file_okay=False),
)
def main(dir_models, out):
run_ensembling(dir_models, out)
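
The ensembling above is plain element-wise averaging over corresponding weight tensors; a toy check of that arithmetic (values illustrative):

    import numpy as np

    w_model_a = [np.array([1.0, 3.0]), np.array([[2.0]])]  # weights of model A, one entry per tensor
    w_model_b = [np.array([3.0, 5.0]), np.array([[4.0]])]  # weights of model B, same shapes
    averaged = [np.mean(pair, axis=0) for pair in zip(w_model_a, w_model_b)]
    # averaged == [array([2., 4.]), array([[3.]])]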


@@ -241,7 +241,14 @@ def find_num_col_deskew(regions_without_separators, sigma_, multiplier=3.8):
z = gaussian_filter1d(regions_without_separators_0, sigma_)
return np.std(z)
def find_num_col(regions_without_separators, num_col_classifier, tables, multiplier=3.8, unbalanced=False, vertical_separators=None):
def find_num_col(
regions_without_separators,
num_col_classifier,
tables,
multiplier=3.8,
unbalanced=False,
vertical_separators=None
):
if not regions_without_separators.any():
return 0, []
if vertical_separators is None:


@@ -356,7 +356,7 @@ def join_polygons(polygons: Sequence[Polygon], scale=20) -> Polygon:
assert jointp.geom_type == 'Polygon', jointp.wkt
# follow-up calculations will necessarily be integer;
# so anticipate rounding here and then ensure validity
jointp2 = set_precision(jointp, 1.0)
jointp2 = set_precision(jointp, 1.0, mode="keep_collapsed")
if jointp2.geom_type != 'Polygon' or not jointp2.is_valid:
jointp2 = Polygon(np.round(jointp.exterior.coords))
jointp2 = make_valid(jointp2)
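
Context for the mode="keep_collapsed" change: shapely's set_precision otherwise drops geometry parts that collapse on the precision grid, whereas keep_collapsed retains them as lower-dimensional geometry, which is exactly what the geom_type check above then catches. A minimal illustration (shapely >= 2.0 assumed; the sliver coordinates are made up):

    from shapely import Polygon, set_precision

    sliver = Polygon([(0, 0), (10, 0), (10, 0.2), (0, 0.2)])  # thinner than the 1.0 grid
    snapped = set_precision(sliver, 1.0, mode="keep_collapsed")
    print(snapped.geom_type)  # typically a LineString rather than an empty polygon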


@@ -19,7 +19,6 @@ def adhere_drop_capital_region_into_corresponding_textline(
all_found_textline_polygons_h,
kernel=None,
curved_line=False,
textline_light=False,
):
# print(np.shape(all_found_textline_polygons),np.shape(all_found_textline_polygons[3]),'all_found_textline_polygonsshape')
# print(all_found_textline_polygons[3])
@@ -79,7 +78,7 @@ def adhere_drop_capital_region_into_corresponding_textline(
# region_with_intersected_drop=region_with_intersected_drop/3
region_with_intersected_drop = region_with_intersected_drop.astype(np.uint8)
# print(np.unique(img_con_all_copy[:,:,0]))
if curved_line or textline_light:
if curved_line:
if len(region_with_intersected_drop) > 1:
sum_pixels_of_intersection = []


@@ -0,0 +1,16 @@
# cannot use stdlib importlib.resources until we move to 3.9+ for importlib.resources.files
import sys
from PIL import ImageFont
if sys.version_info < (3, 10):
import importlib_resources
else:
import importlib.resources as importlib_resources
def get_font():
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
with importlib_resources.as_file(font) as font:
return ImageFont.truetype(font=font, size=40)


@@ -6,7 +6,7 @@ from .contour import find_new_features_of_contours, return_contours_of_intereste
from .resize import resize_image
from .rotate import rotate_image
def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_version=False, kernel=None):
def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=None):
mask_marginals=np.zeros((text_with_lines.shape[0],text_with_lines.shape[1]))
mask_marginals=mask_marginals.astype(np.uint8)
@@ -27,9 +27,8 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
text_with_lines=resize_image(text_with_lines,text_with_lines_eroded.shape[0],text_with_lines_eroded.shape[1])
if light_version:
kernel_hor = np.ones((1, 5), dtype=np.uint8)
text_with_lines = cv2.erode(text_with_lines,kernel_hor,iterations=6)
kernel_hor = np.ones((1, 5), dtype=np.uint8)
text_with_lines = cv2.erode(text_with_lines,kernel_hor,iterations=6)
text_with_lines_y=text_with_lines.sum(axis=0)
text_with_lines_y_eroded=text_with_lines_eroded.sum(axis=0)
@@ -43,10 +42,7 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
elif thickness_along_y_percent>=30 and thickness_along_y_percent<50:
min_textline_thickness=20
else:
if light_version:
min_textline_thickness=45
else:
min_textline_thickness=40
min_textline_thickness=45
if thickness_along_y_percent>=14:
@@ -128,92 +124,39 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
if max_point_of_right_marginal>=text_regions.shape[1]:
max_point_of_right_marginal=text_regions.shape[1]-1
if light_version:
text_regions_org = np.copy(text_regions)
text_regions[text_regions[:,:]==1]=4
pixel_img=4
min_area_text=0.00001
polygon_mask_marginals_rotated = return_contours_of_interested_region(mask_marginals,1,min_area_text)
polygon_mask_marginals_rotated = polygon_mask_marginals_rotated[0]
text_regions_org = np.copy(text_regions)
text_regions[text_regions[:,:]==1]=4
pixel_img=4
min_area_text=0.00001
polygon_mask_marginals_rotated = return_contours_of_interested_region(mask_marginals,1,min_area_text)
polygon_mask_marginals_rotated = polygon_mask_marginals_rotated[0]
polygons_of_marginals=return_contours_of_interested_region(text_regions,pixel_img,min_area_text)
polygons_of_marginals=return_contours_of_interested_region(text_regions,pixel_img,min_area_text)
cx_text_only,cy_text_only ,x_min_text_only,x_max_text_only, y_min_text_only ,y_max_text_only,y_cor_x_min_main=find_new_features_of_contours(polygons_of_marginals)
cx_text_only,cy_text_only ,x_min_text_only,x_max_text_only, y_min_text_only ,y_max_text_only,y_cor_x_min_main=find_new_features_of_contours(polygons_of_marginals)
text_regions[(text_regions[:,:]==4)]=1
text_regions[(text_regions[:,:]==4)]=1
marginlas_should_be_main_text=[]
marginlas_should_be_main_text=[]
x_min_marginals_left=[]
x_min_marginals_right=[]
x_min_marginals_left=[]
x_min_marginals_right=[]
for i in range(len(cx_text_only)):
results = cv2.pointPolygonTest(polygon_mask_marginals_rotated, (cx_text_only[i], cy_text_only[i]), False)
for i in range(len(cx_text_only)):
results = cv2.pointPolygonTest(polygon_mask_marginals_rotated, (cx_text_only[i], cy_text_only[i]), False)
if results == -1:
marginlas_should_be_main_text.append(polygons_of_marginals[i])
if results == -1:
marginlas_should_be_main_text.append(polygons_of_marginals[i])
text_regions_org=cv2.fillPoly(text_regions_org, pts =marginlas_should_be_main_text, color=(4,4))
text_regions = np.copy(text_regions_org)
text_regions_org=cv2.fillPoly(text_regions_org, pts =marginlas_should_be_main_text, color=(4,4))
text_regions = np.copy(text_regions_org)
else:
text_regions[(mask_marginals_rotated[:,:]!=1) & (text_regions[:,:]==1)]=4
pixel_img=4
min_area_text=0.00001
polygons_of_marginals=return_contours_of_interested_region(text_regions,pixel_img,min_area_text)
cx_text_only,cy_text_only ,x_min_text_only,x_max_text_only, y_min_text_only ,y_max_text_only,y_cor_x_min_main=find_new_features_of_contours(polygons_of_marginals)
text_regions[(text_regions[:,:]==4)]=1
marginlas_should_be_main_text=[]
x_min_marginals_left=[]
x_min_marginals_right=[]
for i in range(len(cx_text_only)):
x_width_mar=abs(x_min_text_only[i]-x_max_text_only[i])
y_height_mar=abs(y_min_text_only[i]-y_max_text_only[i])
if x_width_mar>16 and y_height_mar/x_width_mar<18:
marginlas_should_be_main_text.append(polygons_of_marginals[i])
if x_min_text_only[i]<(mid_point-one_third_left):
x_min_marginals_left_new=x_min_text_only[i]
if len(x_min_marginals_left)==0:
x_min_marginals_left.append(x_min_marginals_left_new)
else:
x_min_marginals_left[0]=min(x_min_marginals_left[0],x_min_marginals_left_new)
else:
x_min_marginals_right_new=x_min_text_only[i]
if len(x_min_marginals_right)==0:
x_min_marginals_right.append(x_min_marginals_right_new)
else:
x_min_marginals_right[0]=min(x_min_marginals_right[0],x_min_marginals_right_new)
if len(x_min_marginals_left)==0:
x_min_marginals_left=[0]
if len(x_min_marginals_right)==0:
x_min_marginals_right=[text_regions.shape[1]-1]
text_regions=cv2.fillPoly(text_regions, pts =marginlas_should_be_main_text, color=(4,4))
#text_regions[:,:int(x_min_marginals_left[0])][text_regions[:,:int(x_min_marginals_left[0])]==1]=0
#text_regions[:,int(x_min_marginals_right[0]):][text_regions[:,int(x_min_marginals_right[0]):]==1]=0
text_regions[:,:int(min_point_of_left_marginal)][text_regions[:,:int(min_point_of_left_marginal)]==1]=0
text_regions[:,int(max_point_of_right_marginal):][text_regions[:,int(max_point_of_right_marginal):]==1]=0
###text_regions[:,0:point_left][text_regions[:,0:point_left]==1]=4


@@ -5,8 +5,6 @@ import numpy as np
import cv2
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
from multiprocessing import Process, Queue, cpu_count
from multiprocessing import Pool
from .rotate import rotate_image
from .resize import resize_image
from .contour import (
@@ -20,9 +18,7 @@ from .contour import (
from .shm import share_ndarray, wrap_ndarray_shared
from . import (
find_num_col_deskew,
crop_image_inside_box,
box2rect,
box2slice,
)
def dedup_separate_lines(img_patch, contour_text_interest, thetha, axis):
@@ -1593,65 +1589,6 @@ def get_smallest_skew(img, sigma_des, angles, logger=None, plotter=None, map=map
var = 0
return angle, var
@wrap_ndarray_shared(kw='textline_mask_tot_ea')
def do_work_of_slopes_new(
box_text, contour, contour_par,
textline_mask_tot_ea=None, slope_deskew=0.0,
logger=None, MAX_SLOPE=999, KERNEL=None, plotter=None
):
if KERNEL is None:
KERNEL = np.ones((5, 5), np.uint8)
if logger is None:
logger = getLogger(__package__)
logger.debug('enter do_work_of_slopes_new')
x, y, w, h = box_text
crop_coor = box2rect(box_text)
mask_textline = np.zeros(textline_mask_tot_ea.shape)
mask_textline = cv2.fillPoly(mask_textline, pts=[contour], color=(1,1,1))
all_text_region_raw = textline_mask_tot_ea * mask_textline
all_text_region_raw = all_text_region_raw[y: y + h, x: x + w].astype(np.uint8)
img_int_p = all_text_region_raw[:,:]
img_int_p = cv2.erode(img_int_p, KERNEL, iterations=2)
if not np.prod(img_int_p.shape) or img_int_p.shape[0] /img_int_p.shape[1] < 0.1:
slope = 0
slope_for_all = slope_deskew
all_text_region_raw = textline_mask_tot_ea[y: y + h, x: x + w]
cnt_clean_rot = textline_contours_postprocessing(all_text_region_raw, slope_for_all, contour_par, box_text, 0)
else:
try:
textline_con, hierarchy = return_contours_of_image(img_int_p)
textline_con_fil = filter_contours_area_of_image(img_int_p, textline_con,
hierarchy,
max_area=1, min_area=0.00008)
y_diff_mean = find_contours_mean_y_diff(textline_con_fil) if len(textline_con_fil) > 1 else np.NaN
if np.isnan(y_diff_mean):
slope_for_all = MAX_SLOPE
else:
sigma_des = max(1, int(y_diff_mean * (4.0 / 40.0)))
img_int_p[img_int_p > 0] = 1
slope_for_all = return_deskew_slop(img_int_p, sigma_des, logger=logger, plotter=plotter)
if abs(slope_for_all) <= 0.5:
slope_for_all = slope_deskew
except:
logger.exception("cannot determine angle of contours")
slope_for_all = MAX_SLOPE
if slope_for_all == MAX_SLOPE:
slope_for_all = slope_deskew
slope = slope_for_all
mask_only_con_region = np.zeros(textline_mask_tot_ea.shape)
mask_only_con_region = cv2.fillPoly(mask_only_con_region, pts=[contour_par], color=(1, 1, 1))
all_text_region_raw = textline_mask_tot_ea[y: y + h, x: x + w].copy()
mask_only_con_region = mask_only_con_region[y: y + h, x: x + w]
all_text_region_raw[mask_only_con_region == 0] = 0
cnt_clean_rot = textline_contours_postprocessing(all_text_region_raw, slope_for_all, contour_par, box_text)
return cnt_clean_rot, crop_coor, slope
@wrap_ndarray_shared(kw='textline_mask_tot_ea')
@wrap_ndarray_shared(kw='mask_texts_only')
def do_work_of_slopes_new_curved(
@@ -1751,7 +1688,7 @@ def do_work_of_slopes_new_curved(
@wrap_ndarray_shared(kw='textline_mask_tot_ea')
def do_work_of_slopes_new_light(
box_text, contour, contour_par,
textline_mask_tot_ea=None, slope_deskew=0, textline_light=True,
textline_mask_tot_ea=None, slope_deskew=0,
logger=None
):
if logger is None:
@@ -1768,16 +1705,10 @@ def do_work_of_slopes_new_light(
mask_only_con_region = np.zeros(textline_mask_tot_ea.shape)
mask_only_con_region = cv2.fillPoly(mask_only_con_region, pts=[contour_par], color=(1, 1, 1))
if textline_light:
all_text_region_raw = np.copy(textline_mask_tot_ea)
all_text_region_raw[mask_only_con_region == 0] = 0
cnt_clean_rot_raw, hir_on_cnt_clean_rot = return_contours_of_image(all_text_region_raw)
cnt_clean_rot = filter_contours_area_of_image(all_text_region_raw, cnt_clean_rot_raw, hir_on_cnt_clean_rot,
max_area=1, min_area=0.00001)
else:
all_text_region_raw = np.copy(textline_mask_tot_ea[y: y + h, x: x + w])
mask_only_con_region = mask_only_con_region[y: y + h, x: x + w]
all_text_region_raw[mask_only_con_region == 0] = 0
cnt_clean_rot = textline_contours_postprocessing(all_text_region_raw, slope_deskew, contour_par, box_text)
all_text_region_raw = np.copy(textline_mask_tot_ea)
all_text_region_raw[mask_only_con_region == 0] = 0
cnt_clean_rot_raw, hir_on_cnt_clean_rot = return_contours_of_image(all_text_region_raw)
cnt_clean_rot = filter_contours_area_of_image(all_text_region_raw, cnt_clean_rot_raw, hir_on_cnt_clean_rot,
max_area=1, min_area=0.00001)
return cnt_clean_rot, crop_coor, slope_deskew


@@ -128,6 +128,7 @@ def return_textlines_split_if_needed(textline_image, textline_image_bin=None):
return [image1, image2], None
else:
return None, None
def preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width):
if img.shape[0]==0 or img.shape[1]==0:
img_fin = np.ones((image_height, image_width, 3))
@@ -379,7 +380,6 @@ def return_rnn_cnn_ocr_of_given_textlines(image,
all_box_coord,
prediction_model,
b_s_ocr, num_to_char,
textline_light=False,
curved_line=False):
max_len = 512
padding_token = 299
@@ -404,7 +404,7 @@
else:
for indexing2, ind_poly in enumerate(ind_poly_first):
cropped_lines_region_indexer.append(indexer_text_region)
if not (textline_light or curved_line):
if not curved_line:
ind_poly = copy.deepcopy(ind_poly)
box_ind = all_box_coord[indexing]


@@ -88,3 +88,7 @@ def order_and_id_of_texts(found_polygons_text_region, found_polygons_text_region
order_of_texts.append(interest)
return order_of_texts, id_of_texts
def etree_namespace_for_element_tag(tag: str):
    """Return the namespace URI of a Clark-notation element tag like '{uri}local'."""
    right = tag.find('}')
    return tag[1:right]
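
For illustration, the helper slices the namespace URI out of a Clark-notation tag (the PAGE namespace below is just an example value):

    tag = "{http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15}TextLine"
    ns = etree_namespace_for_element_tag(tag)
    # ns == "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15"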


@@ -2,15 +2,15 @@
# pylint: disable=import-error
from pathlib import Path
import os.path
import xml.etree.ElementTree as ET
import logging
from typing import Optional
import numpy as np
from shapely import affinity, clip_by_rect
from ocrd_utils import getLogger, points_from_polygon
from ocrd_utils import points_from_polygon
from ocrd_models.ocrd_page import (
BorderType,
CoordsType,
PcGtsType,
TextLineType,
TextEquivType,
TextRegionType,
@@ -26,19 +26,18 @@ from .utils.contour import contour2polygon, make_valid
class EynollahXmlWriter:
def __init__(self, *, dir_out, image_filename, curved_line,textline_light, pcgts=None):
self.logger = getLogger('eynollah.writer')
def __init__(self, *, dir_out, image_filename, curved_line, pcgts=None):
self.logger = logging.getLogger('eynollah.writer')
self.counter = EynollahIdCounter()
self.dir_out = dir_out
self.image_filename = image_filename
self.output_filename = os.path.join(self.dir_out or "", self.image_filename_stem) + ".xml"
self.curved_line = curved_line
self.textline_light = textline_light
self.pcgts = pcgts
self.scale_x = None # XXX set outside __init__
self.scale_y = None # XXX set outside __init__
self.height_org = None # XXX set outside __init__
self.width_org = None # XXX set outside __init__
self.scale_x: Optional[float] = None # XXX set outside __init__
self.scale_y: Optional[float] = None # XXX set outside __init__
self.height_org: Optional[int] = None # XXX set outside __init__
self.width_org: Optional[int] = None # XXX set outside __init__
@property
def image_filename_stem(self):
@@ -65,8 +64,8 @@ class EynollahXmlWriter:
text_region.set_orientation(-slopes[region_idx])
region_bboxes = all_box_coord[region_idx]
offset = [page_coord[2], page_coord[0]]
# FIXME: or actually... not self.textline_light and not self.curved_line or np.abs(slopes[region_idx]) > 45?
if not self.textline_light and not (self.curved_line and np.abs(slopes[region_idx]) <= 45):
# FIXME: or actually... self.curved_line or np.abs(slopes[region_idx]) > 45?
if self.curved_line and np.abs(slopes[region_idx]) > 45:
offset[0] += region_bboxes[2]
offset[1] += region_bboxes[0]
coords.set_points(self.calculate_points(polygon_textline, offset))
@@ -77,48 +76,88 @@ class EynollahXmlWriter:
f.write(to_xml(pcgts))
def build_pagexml_no_full_layout(
self, found_polygons_text_region,
page_coord, order_of_texts,
all_found_textline_polygons,
all_box_coord,
found_polygons_text_region_img,
found_polygons_marginals_left, found_polygons_marginals_right,
all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left, all_box_coord_marginals_right,
slopes, slopes_marginals_left, slopes_marginals_right,
cont_page, polygons_seplines,
found_polygons_tables,
**kwargs):
self,
*,
found_polygons_text_region,
page_coord,
order_of_texts,
all_found_textline_polygons,
all_box_coord,
found_polygons_text_region_img,
found_polygons_marginals_left,
found_polygons_marginals_right,
all_found_textline_polygons_marginals_left,
all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left,
all_box_coord_marginals_right,
slopes,
slopes_marginals_left,
slopes_marginals_right,
cont_page,
polygons_seplines,
found_polygons_tables,
):
return self.build_pagexml_full_layout(
found_polygons_text_region, [],
page_coord, order_of_texts,
all_found_textline_polygons, [],
all_box_coord, [],
found_polygons_text_region_img, found_polygons_tables, [],
found_polygons_marginals_left, found_polygons_marginals_right,
all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left, all_box_coord_marginals_right,
slopes, [], slopes_marginals_left, slopes_marginals_right,
cont_page, polygons_seplines,
**kwargs)
found_polygons_text_region=found_polygons_text_region,
found_polygons_text_region_h=[],
page_coord=page_coord,
order_of_texts=order_of_texts,
all_found_textline_polygons=all_found_textline_polygons,
all_found_textline_polygons_h=[],
all_box_coord=all_box_coord,
all_box_coord_h=[],
found_polygons_text_region_img=found_polygons_text_region_img,
found_polygons_tables=found_polygons_tables,
found_polygons_drop_capitals=[],
found_polygons_marginals_left=found_polygons_marginals_left,
found_polygons_marginals_right=found_polygons_marginals_right,
all_found_textline_polygons_marginals_left=all_found_textline_polygons_marginals_left,
all_found_textline_polygons_marginals_right=all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left=all_box_coord_marginals_left,
all_box_coord_marginals_right=all_box_coord_marginals_right,
slopes=slopes,
slopes_h=[],
slopes_marginals_left=slopes_marginals_left,
slopes_marginals_right=slopes_marginals_right,
cont_page=cont_page,
polygons_seplines=polygons_seplines,
)
def build_pagexml_full_layout(
self,
found_polygons_text_region, found_polygons_text_region_h,
page_coord, order_of_texts,
all_found_textline_polygons, all_found_textline_polygons_h,
all_box_coord, all_box_coord_h,
found_polygons_text_region_img, found_polygons_tables, found_polygons_drop_capitals,
found_polygons_marginals_left,found_polygons_marginals_right,
all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left, all_box_coord_marginals_right,
slopes, slopes_h, slopes_marginals_left, slopes_marginals_right,
cont_page, polygons_seplines,
ocr_all_textlines=None, ocr_all_textlines_h=None,
ocr_all_textlines_marginals_left=None, ocr_all_textlines_marginals_right=None,
ocr_all_textlines_drop=None,
conf_contours_textregions=None, conf_contours_textregions_h=None,
skip_layout_reading_order=False):
self,
*,
found_polygons_text_region,
found_polygons_text_region_h,
page_coord,
order_of_texts,
all_found_textline_polygons,
all_found_textline_polygons_h,
all_box_coord,
all_box_coord_h,
found_polygons_text_region_img,
found_polygons_tables,
found_polygons_drop_capitals,
found_polygons_marginals_left,
found_polygons_marginals_right,
all_found_textline_polygons_marginals_left,
all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left,
all_box_coord_marginals_right,
slopes,
slopes_h,
slopes_marginals_left,
slopes_marginals_right,
cont_page,
polygons_seplines,
ocr_all_textlines=None,
ocr_all_textlines_h=None,
ocr_all_textlines_marginals_left=None,
ocr_all_textlines_marginals_right=None,
ocr_all_textlines_drop=None,
conf_contours_textregions=None,
conf_contours_textregions_h=None,
skip_layout_reading_order=False,
):
self.logger.debug('enter build_pagexml')
# create the file structure
@@ -145,6 +184,7 @@ class EynollahXmlWriter:
id=counter.next_region_id, type_='paragraph',
Coords=CoordsType(points=self.calculate_points(region_contour, offset))
)
assert textregion.Coords
if conf_contours_textregions:
textregion.Coords.set_conf(conf_contours_textregions[mm])
page.add_TextRegion(textregion)
@@ -161,6 +201,7 @@ class EynollahXmlWriter:
id=counter.next_region_id, type_='heading',
Coords=CoordsType(points=self.calculate_points(region_contour, offset))
)
assert textregion.Coords
if conf_contours_textregions_h:
textregion.Coords.set_conf(conf_contours_textregions_h[mm])
page.add_TextRegion(textregion)