mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
Merge 9801129aa6 into 1df32eba87
This commit is contained in:
commit
bcec0c4a55
27 changed files with 348 additions and 459 deletions
|
|
@ -1,7 +1,3 @@
|
||||||
# NOTE: For predictable order of imports of torch/shapely/tensorflow
|
|
||||||
# this must be the first import of the CLI!
|
|
||||||
from ..eynollah_imports import imported_libs
|
|
||||||
|
|
||||||
from .cli import main
|
from .cli import main
|
||||||
from .cli_binarize import binarize_cli
|
from .cli_binarize import binarize_cli
|
||||||
from .cli_enhance import enhance_cli
|
from .cli_enhance import enhance_cli
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ class EynollahCliCtx:
|
||||||
Holds options relevant for all eynollah subcommands
|
Holds options relevant for all eynollah subcommands
|
||||||
"""
|
"""
|
||||||
model_zoo: EynollahModelZoo
|
model_zoo: EynollahModelZoo
|
||||||
|
device: str = ''
|
||||||
log_level : Union[str, None] = 'INFO'
|
log_level : Union[str, None] = 'INFO'
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -35,6 +36,11 @@ class EynollahCliCtx:
|
||||||
type=(str, str, str),
|
type=(str, str, str),
|
||||||
multiple=True,
|
multiple=True,
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--device",
|
||||||
|
"-D",
|
||||||
|
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--log_level",
|
"--log_level",
|
||||||
"-l",
|
"-l",
|
||||||
|
|
@ -42,7 +48,7 @@ class EynollahCliCtx:
|
||||||
help="Override log level globally to this",
|
help="Override log level globally to this",
|
||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def main(ctx, model_basedir, model_overrides, log_level):
|
def main(ctx, model_basedir, model_overrides, device, log_level):
|
||||||
"""
|
"""
|
||||||
eynollah - Document Layout Analysis, Image Enhancement, OCR
|
eynollah - Document Layout Analysis, Image Enhancement, OCR
|
||||||
"""
|
"""
|
||||||
|
|
@ -58,6 +64,7 @@ def main(ctx, model_basedir, model_overrides, log_level):
|
||||||
# Initialize CLI context
|
# Initialize CLI context
|
||||||
ctx.obj = EynollahCliCtx(
|
ctx.obj = EynollahCliCtx(
|
||||||
model_zoo=model_zoo,
|
model_zoo=model_zoo,
|
||||||
|
device=device,
|
||||||
log_level=log_level,
|
log_level=log_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@click.command()
|
@click.command(context_settings=dict(
|
||||||
|
help_option_names=['-h', '--help'],
|
||||||
|
show_default=True))
|
||||||
@click.option(
|
@click.option(
|
||||||
'--patches/--no-patches',
|
'--patches/--no-patches',
|
||||||
default=True,
|
default=True,
|
||||||
|
|
@ -31,11 +33,6 @@ import click
|
||||||
help="overwrite (instead of skipping) if output xml exists",
|
help="overwrite (instead of skipping) if output xml exists",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
)
|
)
|
||||||
@click.option(
|
|
||||||
"--device",
|
|
||||||
"-D",
|
|
||||||
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
|
||||||
)
|
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def binarize_cli(
|
def binarize_cli(
|
||||||
ctx,
|
ctx,
|
||||||
|
|
@ -44,14 +41,14 @@ def binarize_cli(
|
||||||
dir_in,
|
dir_in,
|
||||||
output,
|
output,
|
||||||
overwrite,
|
overwrite,
|
||||||
device,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Binarize images with a ML model
|
Binarize images with a ML model
|
||||||
"""
|
"""
|
||||||
from ..sbb_binarize import SbbBinarizer
|
from ..sbb_binarize import SbbBinarizer
|
||||||
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||||
binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo, device=device)
|
binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo,
|
||||||
|
device=ctx.obj.device)
|
||||||
binarizer.run(
|
binarizer.run(
|
||||||
image_filename=input_image,
|
image_filename=input_image,
|
||||||
use_patches=patches,
|
use_patches=patches,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@click.command()
|
@click.command(context_settings=dict(
|
||||||
|
help_option_names=['-h', '--help'],
|
||||||
|
show_default=True))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--image",
|
"--image",
|
||||||
"-i",
|
"-i",
|
||||||
|
|
@ -46,13 +48,8 @@ import click
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="save the enhanced image in original image size",
|
help="save the enhanced image in original image size",
|
||||||
)
|
)
|
||||||
@click.option(
|
|
||||||
"--device",
|
|
||||||
"-D",
|
|
||||||
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
|
||||||
)
|
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale, device):
|
def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale):
|
||||||
"""
|
"""
|
||||||
Enhance image
|
Enhance image
|
||||||
"""
|
"""
|
||||||
|
|
@ -60,10 +57,10 @@ def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower
|
||||||
from ..image_enhancer import Enhancer
|
from ..image_enhancer import Enhancer
|
||||||
enhancer = Enhancer(
|
enhancer = Enhancer(
|
||||||
model_zoo=ctx.obj.model_zoo,
|
model_zoo=ctx.obj.model_zoo,
|
||||||
|
device=ctx.obj.device,
|
||||||
num_col_upper=num_col_upper,
|
num_col_upper=num_col_upper,
|
||||||
num_col_lower=num_col_lower,
|
num_col_lower=num_col_lower,
|
||||||
save_org_scale=save_org_scale,
|
save_org_scale=save_org_scale,
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
enhancer.run(overwrite=overwrite,
|
enhancer.run(overwrite=overwrite,
|
||||||
dir_in=dir_in,
|
dir_in=dir_in,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@click.command()
|
@click.command(context_settings=dict(
|
||||||
|
help_option_names=['-h', '--help'],
|
||||||
|
show_default=True))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--image",
|
"--image",
|
||||||
"-i",
|
"-i",
|
||||||
|
|
@ -30,36 +32,40 @@ import click
|
||||||
@click.option(
|
@click.option(
|
||||||
"--save_images",
|
"--save_images",
|
||||||
"-si",
|
"-si",
|
||||||
help="if a directory is given, images in documents will be cropped and saved there",
|
help="if a directory is given, cropped images of pages will be saved there",
|
||||||
type=click.Path(exists=True, file_okay=False),
|
type=click.Path(exists=True, file_okay=False),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--enable-plotting/--disable-plotting",
|
"--enable-plotting",
|
||||||
"-ep/-noep",
|
"-ep",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="If set, will plot intermediary files and images",
|
help="plot intermediary diagnostic images to files",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--input_binary/--input-RGB",
|
"--input_binary",
|
||||||
"-ib/-irgb",
|
"-ib",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="In general, eynollah uses RGB as input but if the input document is very dark, very bright or for any other reason you can turn on input binarization. When this flag is set, eynollah will binarize the RGB input document, you should always provide RGB images to eynollah.",
|
help="In general, eynollah uses RGB as input, but if the input document is very dark, very bright or for any other reason you can turn on internal binarization here. When set, eynollah will binarize the RGB input document first.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--ignore_page_extraction/--extract_page_included",
|
"--ignore_page_extraction",
|
||||||
"-ipe/-epi",
|
"-ipe",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="if this parameter set to true, this tool would ignore page extraction",
|
help="ignore page extraction (cropping via page frame detection model)",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--num_col_upper",
|
"--num_col_upper",
|
||||||
"-ncu",
|
"-ncu",
|
||||||
help="lower limit of columns in document image",
|
default=0,
|
||||||
|
type=click.IntRange(min=0),
|
||||||
|
help="lower limit of columns in document image; 0 means autodetected from model",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--num_col_lower",
|
"--num_col_lower",
|
||||||
"-ncl",
|
"-ncl",
|
||||||
help="upper limit of columns in document image",
|
default=0,
|
||||||
|
type=click.IntRange(min=0),
|
||||||
|
help="upper limit of columns in document image; 0 means autodetected from model",
|
||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def extract_images_cli(
|
def extract_images_cli(
|
||||||
|
|
|
||||||
|
|
@ -172,11 +172,6 @@ import click
|
||||||
type=click.FloatRange(min=0),
|
type=click.FloatRange(min=0),
|
||||||
help="abort when number of failed images exceeds this value (if >=1) or ratio of failed over total images exceeds this value (if <1); 0 means ignore failures",
|
help="abort when number of failed images exceeds this value (if >=1) or ratio of failed over total images exceeds this value (if <1); 0 means ignore failures",
|
||||||
)
|
)
|
||||||
@click.option(
|
|
||||||
"--device",
|
|
||||||
"-D",
|
|
||||||
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
|
||||||
)
|
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def layout_cli(
|
def layout_cli(
|
||||||
ctx,
|
ctx,
|
||||||
|
|
@ -207,7 +202,6 @@ def layout_cli(
|
||||||
ignore_page_extraction,
|
ignore_page_extraction,
|
||||||
num_jobs,
|
num_jobs,
|
||||||
halt_fail,
|
halt_fail,
|
||||||
device,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Detect Layout (with optional image enhancement and reading order detection)
|
Detect Layout (with optional image enhancement and reading order detection)
|
||||||
|
|
@ -223,7 +217,7 @@ def layout_cli(
|
||||||
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||||
eynollah = Eynollah(
|
eynollah = Eynollah(
|
||||||
model_zoo=ctx.obj.model_zoo,
|
model_zoo=ctx.obj.model_zoo,
|
||||||
device=device,
|
device=ctx.obj.device,
|
||||||
enable_plotting=enable_plotting,
|
enable_plotting=enable_plotting,
|
||||||
allow_enhancement=allow_enhancement,
|
allow_enhancement=allow_enhancement,
|
||||||
curved_line=curved_line,
|
curved_line=curved_line,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@click.command()
|
@click.command(context_settings=dict(
|
||||||
|
help_option_names=['-h', '--help'],
|
||||||
|
show_default=True))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--image",
|
"--image",
|
||||||
"-i",
|
"-i",
|
||||||
|
|
@ -16,7 +18,7 @@ import click
|
||||||
@click.option(
|
@click.option(
|
||||||
"--dir_in_bin",
|
"--dir_in_bin",
|
||||||
"-dib",
|
"-dib",
|
||||||
help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png' \n Perform prediction using both RGB and binary images. (This does not necessarily improve results, however it may be beneficial for certain document images."),
|
help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png'. \n Perform prediction using both RGB and binary images. (This may improve results for certain document images.)"),
|
||||||
type=click.Path(exists=True, file_okay=False),
|
type=click.Path(exists=True, file_okay=False),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
|
|
@ -47,25 +49,29 @@ import click
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--tr_ocr",
|
"--tr_ocr",
|
||||||
"-trocr/-notrocr",
|
"-trocr",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.",
|
help="use transformer OCR (instead of classic CNN-RNN) model",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--do_not_mask_with_textline_contour",
|
"--do_not_mask_with_textline_contour",
|
||||||
"-nmtc/-mtc",
|
"-nmtc",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
|
help="skip masking each cropped textline image with its corresponding textline contour",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--batch_size",
|
"--batch_size",
|
||||||
"-bs",
|
"-bs",
|
||||||
|
default=0,
|
||||||
|
type=click.IntRange(min=0),
|
||||||
help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
|
help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--min_conf_value_of_textline_text",
|
"--min_conf_value_of_textline_text",
|
||||||
"-min_conf",
|
"-min_conf",
|
||||||
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
|
default=0.3,
|
||||||
|
type=click.FloatRange(min=0.0, max=1.0),
|
||||||
|
help="minimum OCR confidence threshold. Text lines with a lower confidence value will not be included in the output XML file.",
|
||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def ocr_cli(
|
def ocr_cli(
|
||||||
|
|
@ -85,14 +91,16 @@ def ocr_cli(
|
||||||
"""
|
"""
|
||||||
Recognize text with a CNN/RNN or transformer ML model.
|
Recognize text with a CNN/RNN or transformer ML model.
|
||||||
"""
|
"""
|
||||||
assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
|
assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
|
||||||
from ..eynollah_ocr import Eynollah_ocr
|
from ..eynollah_ocr import Eynollah_ocr
|
||||||
eynollah_ocr = Eynollah_ocr(
|
eynollah_ocr = Eynollah_ocr(
|
||||||
model_zoo=ctx.obj.model_zoo,
|
model_zoo=ctx.obj.model_zoo,
|
||||||
|
device=ctx.obj.device,
|
||||||
tr_ocr=tr_ocr,
|
tr_ocr=tr_ocr,
|
||||||
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
|
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
min_conf_value_of_textline_text=min_conf_value_of_textline_text)
|
min_conf_value_of_textline_text=min_conf_value_of_textline_text,
|
||||||
|
)
|
||||||
eynollah_ocr.run(overwrite=overwrite,
|
eynollah_ocr.run(overwrite=overwrite,
|
||||||
dir_in=dir_in,
|
dir_in=dir_in,
|
||||||
dir_in_bin=dir_in_bin,
|
dir_in_bin=dir_in_bin,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@click.command()
|
@click.command(context_settings=dict(
|
||||||
|
help_option_names=['-h', '--help'],
|
||||||
|
show_default=True))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--input",
|
"--input",
|
||||||
"-i",
|
"-i",
|
||||||
|
|
@ -25,9 +27,10 @@ def readingorder_cli(ctx, input, dir_in, out):
|
||||||
"""
|
"""
|
||||||
Generate ReadingOrder with a ML model
|
Generate ReadingOrder with a ML model
|
||||||
"""
|
"""
|
||||||
from ..mb_ro_on_layout import machine_based_reading_order_on_layout
|
from ..mb_ro_on_layout import Reorder
|
||||||
assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||||
orderer = machine_based_reading_order_on_layout(model_zoo=ctx.obj.model_zoo)
|
orderer = Reorder(model_zoo=ctx.obj.model_zoo,
|
||||||
|
device=ctx.obj.device)
|
||||||
orderer.run(xml_filename=input,
|
orderer.run(xml_filename=input,
|
||||||
dir_in=dir_in,
|
dir_in=dir_in,
|
||||||
dir_out=out,
|
dir_out=out,
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ import os
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tensorflow as tf
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
|
@ -64,12 +63,6 @@ class EynollahImageExtractor(Eynollah):
|
||||||
|
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
|
|
||||||
try:
|
|
||||||
for device in tf.config.list_physical_devices('GPU'):
|
|
||||||
tf.config.experimental.set_memory_growth(device, True)
|
|
||||||
except:
|
|
||||||
self.logger.warning("no GPU device available")
|
|
||||||
|
|
||||||
self.logger.info("Loading models...")
|
self.logger.info("Loading models...")
|
||||||
self.setup_models()
|
self.setup_models()
|
||||||
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
|
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
|
||||||
|
|
|
||||||
|
|
@ -1148,7 +1148,6 @@ class Eynollah:
|
||||||
boxes,
|
boxes,
|
||||||
textline_mask_tot
|
textline_mask_tot
|
||||||
):
|
):
|
||||||
assert np.any(textline_mask_tot)
|
|
||||||
self.logger.debug("enter do_order_of_regions")
|
self.logger.debug("enter do_order_of_regions")
|
||||||
contours_only_text_parent = ensure_array(contours_only_text_parent)
|
contours_only_text_parent = ensure_array(contours_only_text_parent)
|
||||||
contours_only_text_parent_h = ensure_array(contours_only_text_parent_h)
|
contours_only_text_parent_h = ensure_array(contours_only_text_parent_h)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +0,0 @@
|
||||||
"""
|
|
||||||
Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
|
||||||
|
|
||||||
from ocrd_utils import tf_disable_interactive_logs
|
|
||||||
from torch import *
|
|
||||||
tf_disable_interactive_logs()
|
|
||||||
import tensorflow.keras
|
|
||||||
from shapely import *
|
|
||||||
imported_libs = True
|
|
||||||
__all__ = ['imported_libs']
|
|
||||||
|
|
@ -14,16 +14,14 @@ from cv2.typing import MatLike
|
||||||
from xml.etree import ElementTree as ET
|
from xml.etree import ElementTree as ET
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from eynollah.model_zoo import EynollahModelZoo
|
from ocrd_utils import polygon_from_points, xywh_from_polygon
|
||||||
from eynollah.utils.font import get_font
|
|
||||||
from eynollah.utils.xml import etree_namespace_for_element_tag
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
except ImportError:
|
|
||||||
torch = None
|
|
||||||
|
|
||||||
|
|
||||||
|
from .eynollah import Eynollah
|
||||||
|
from .model_zoo import EynollahModelZoo
|
||||||
from .utils import is_image_filename
|
from .utils import is_image_filename
|
||||||
|
from .utils.font import get_font
|
||||||
|
from .utils.xml import etree_namespace_for_element_tag
|
||||||
from .utils.resize import resize_image
|
from .utils.resize import resize_image
|
||||||
from .utils.utils_ocr import (
|
from .utils.utils_ocr import (
|
||||||
break_curved_line_into_small_pieces_and_then_merge,
|
break_curved_line_into_small_pieces_and_then_merge,
|
||||||
|
|
@ -34,6 +32,7 @@ from .utils.utils_ocr import (
|
||||||
preprocess_and_resize_image_for_ocrcnn_model,
|
preprocess_and_resize_image_for_ocrcnn_model,
|
||||||
return_textlines_split_if_needed,
|
return_textlines_split_if_needed,
|
||||||
rotate_image_with_padding,
|
rotate_image_with_padding,
|
||||||
|
batched,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: refine typing
|
# TODO: refine typing
|
||||||
|
|
@ -44,45 +43,44 @@ class EynollahOcrResult:
|
||||||
cropped_lines_region_indexer: List
|
cropped_lines_region_indexer: List
|
||||||
total_bb_coordinates:List
|
total_bb_coordinates:List
|
||||||
|
|
||||||
class Eynollah_ocr:
|
class Eynollah_ocr(Eynollah):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
model_zoo: EynollahModelZoo,
|
model_zoo: EynollahModelZoo,
|
||||||
tr_ocr=False,
|
tr_ocr=False,
|
||||||
batch_size: Optional[int]=None,
|
batch_size: int=0,
|
||||||
do_not_mask_with_textline_contour: bool=False,
|
do_not_mask_with_textline_contour: bool=False,
|
||||||
min_conf_value_of_textline_text : Optional[float]=None,
|
min_conf_value_of_textline_text : float=0.3,
|
||||||
logger: Optional[Logger]=None,
|
logger: Optional[Logger]=None,
|
||||||
|
device: str = '',
|
||||||
):
|
):
|
||||||
self.tr_ocr = tr_ocr
|
self.tr_ocr = tr_ocr
|
||||||
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
|
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
|
||||||
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
|
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
|
||||||
self.logger = logger if logger else getLogger('eynollah.ocr')
|
self.logger = logger if logger else getLogger('eynollah.ocr')
|
||||||
|
|
||||||
|
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text
|
||||||
|
self.b_s = batch_size or 2 if tr_ocr else 8
|
||||||
|
|
||||||
self.model_zoo = model_zoo
|
self.model_zoo = model_zoo
|
||||||
|
self.setup_models(device=device)
|
||||||
|
|
||||||
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3
|
def setup_models(self, device=''):
|
||||||
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
|
if self.tr_ocr:
|
||||||
|
self.model_zoo.load_models('trocr_processor',
|
||||||
if tr_ocr:
|
('ocr', 'tr'),
|
||||||
self.model_zoo.load_models('trocr_processor')
|
device=device)
|
||||||
self.model_zoo.load_models(['ocr', 'tr'])
|
|
||||||
self.model_zoo.get('ocr').to(self.device)
|
|
||||||
else:
|
else:
|
||||||
self.model_zoo.load_models('ocr')
|
self.model_zoo.load_models('ocr',
|
||||||
self.model_zoo.load_models('num_to_char')
|
'num_to_char',
|
||||||
self.model_zoo.load_models('characters')
|
'characters',
|
||||||
|
device=device)
|
||||||
self.end_character = len(self.model_zoo.get('characters')) + 2
|
self.end_character = len(self.model_zoo.get('characters')) + 2
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
assert torch
|
return self.model_zoo.get('ocr').device
|
||||||
if torch.cuda.is_available():
|
|
||||||
self.logger.info("Using GPU acceleration")
|
|
||||||
return torch.device("cuda:0")
|
|
||||||
else:
|
|
||||||
self.logger.info("Using CPU processing")
|
|
||||||
return torch.device("cpu")
|
|
||||||
|
|
||||||
def run_trocr(
|
def run_trocr(
|
||||||
self,
|
self,
|
||||||
|
|
@ -94,174 +92,94 @@ class Eynollah_ocr:
|
||||||
) -> EynollahOcrResult:
|
) -> EynollahOcrResult:
|
||||||
|
|
||||||
total_bb_coordinates = []
|
total_bb_coordinates = []
|
||||||
|
|
||||||
|
|
||||||
cropped_lines = []
|
cropped_lines = []
|
||||||
cropped_lines_region_indexer = []
|
cropped_lines_region_indexer = []
|
||||||
cropped_lines_meging_indexing = []
|
cropped_lines_meging_indexing = []
|
||||||
|
|
||||||
extracted_texts = []
|
extracted_texts = []
|
||||||
|
extracted_confs = []
|
||||||
|
|
||||||
indexer_text_region = 0
|
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
|
||||||
indexer_b_s = 0
|
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
|
||||||
|
cropped_lines_region_indexer.append(n_region)
|
||||||
|
|
||||||
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
|
coords = line.find('{%s}Coords' % page_ns)
|
||||||
for child_textregion in nn:
|
if coords is None:
|
||||||
if child_textregion.tag.endswith("TextLine"):
|
self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id'])
|
||||||
|
continue
|
||||||
|
poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int)
|
||||||
|
cont = poly[:, np.newaxis]
|
||||||
|
xywh = xywh_from_polygon(poly)
|
||||||
|
x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h']
|
||||||
|
|
||||||
for child_textlines in child_textregion:
|
total_bb_coordinates.append([x, y, w, h])
|
||||||
if child_textlines.tag.endswith("Coords"):
|
|
||||||
cropped_lines_region_indexer.append(indexer_text_region)
|
|
||||||
p_h=child_textlines.attrib['points'].split(' ')
|
|
||||||
textline_coords = np.array( [ [int(x.split(',')[0]),
|
|
||||||
int(x.split(',')[1]) ]
|
|
||||||
for x in p_h] )
|
|
||||||
x,y,w,h = cv2.boundingRect(textline_coords)
|
|
||||||
|
|
||||||
total_bb_coordinates.append([x,y,w,h])
|
img_crop = img[y: y + h, x: x + w]
|
||||||
|
if not self.do_not_mask_with_textline_contour:
|
||||||
|
mask_poly = np.zeros(img_crop.shape[:2], dtype=np.uint8)
|
||||||
|
mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1)
|
||||||
|
img_crop[mask_poly == 0] = 255 # FIXME: or median color?
|
||||||
|
|
||||||
h2w_ratio = h/float(w)
|
if h > 0.1 * w:
|
||||||
|
cropped_lines.append(resize_image(img_crop,
|
||||||
img_poly_on_img = np.copy(img)
|
tr_ocr_input_height_and_width,
|
||||||
mask_poly = np.zeros(img.shape)
|
tr_ocr_input_height_and_width) )
|
||||||
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
|
cropped_lines_meging_indexing.append(0)
|
||||||
|
else:
|
||||||
mask_poly = mask_poly[y:y+h, x:x+w, :]
|
splited_images, _ = return_textlines_split_if_needed(img_crop, None)
|
||||||
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
|
if splited_images:
|
||||||
img_crop[mask_poly==0] = 255
|
cropped_lines.append(resize_image(splited_images[0],
|
||||||
|
tr_ocr_input_height_and_width,
|
||||||
self.logger.debug("processing %d lines for '%s'",
|
tr_ocr_input_height_and_width))
|
||||||
len(cropped_lines), nn.attrib['id'])
|
cropped_lines_meging_indexing.append(1)
|
||||||
if h2w_ratio > 0.1:
|
cropped_lines.append(resize_image(splited_images[1],
|
||||||
cropped_lines.append(resize_image(img_crop,
|
tr_ocr_input_height_and_width,
|
||||||
tr_ocr_input_height_and_width,
|
tr_ocr_input_height_and_width))
|
||||||
tr_ocr_input_height_and_width) )
|
cropped_lines_meging_indexing.append(-1)
|
||||||
cropped_lines_meging_indexing.append(0)
|
else:
|
||||||
indexer_b_s+=1
|
cropped_lines.append(img_crop)
|
||||||
if indexer_b_s==self.b_s:
|
cropped_lines_meging_indexing.append(0)
|
||||||
imgs = cropped_lines[:]
|
|
||||||
cropped_lines = []
|
|
||||||
indexer_b_s = 0
|
|
||||||
|
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
|
||||||
pixel_values_merged.to(self.device))
|
|
||||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
generated_ids_merged, skip_special_tokens=True)
|
|
||||||
|
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
else:
|
|
||||||
splited_images, _ = return_textlines_split_if_needed(img_crop, None)
|
|
||||||
#print(splited_images)
|
|
||||||
if splited_images:
|
|
||||||
cropped_lines.append(resize_image(splited_images[0],
|
|
||||||
tr_ocr_input_height_and_width,
|
|
||||||
tr_ocr_input_height_and_width))
|
|
||||||
cropped_lines_meging_indexing.append(1)
|
|
||||||
indexer_b_s+=1
|
|
||||||
|
|
||||||
if indexer_b_s==self.b_s:
|
|
||||||
imgs = cropped_lines[:]
|
|
||||||
cropped_lines = []
|
|
||||||
indexer_b_s = 0
|
|
||||||
|
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
|
||||||
pixel_values_merged.to(self.device))
|
|
||||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
generated_ids_merged, skip_special_tokens=True)
|
|
||||||
|
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
|
|
||||||
cropped_lines.append(resize_image(splited_images[1],
|
self.logger.debug("processing %d lines for %d regions",
|
||||||
tr_ocr_input_height_and_width,
|
len(cropped_lines), len(set(cropped_lines_region_indexer)))
|
||||||
tr_ocr_input_height_and_width))
|
for imgs in batched(cropped_lines, self.b_s):
|
||||||
cropped_lines_meging_indexing.append(-1)
|
pixel_values = self.model_zoo.get('trocr_processor')(
|
||||||
indexer_b_s+=1
|
imgs, return_tensors="pt").pixel_values
|
||||||
|
output = self.model_zoo.get('ocr').generate(
|
||||||
if indexer_b_s==self.b_s:
|
pixel_values.to(self.device),
|
||||||
imgs = cropped_lines[:]
|
# beam search instead of greedy decoding:
|
||||||
cropped_lines = []
|
num_beams=4,
|
||||||
indexer_b_s = 0
|
# also return probability
|
||||||
|
output_scores=True,
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
return_dict_in_generate=True)
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
if output.sequences_scores is not None:
|
||||||
pixel_values_merged.to(self.device))
|
# log-prob averaged over length
|
||||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
|
||||||
generated_ids_merged, skip_special_tokens=True)
|
else:
|
||||||
|
conf = [1.0] * len(output.sequences)
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
text = self.model_zoo.get('trocr_processor').batch_decode(
|
||||||
|
output.sequences,
|
||||||
else:
|
skip_special_tokens=True,
|
||||||
cropped_lines.append(img_crop)
|
clean_up_tokenization_spaces=False)
|
||||||
cropped_lines_meging_indexing.append(0)
|
extracted_confs.extend(conf)
|
||||||
indexer_b_s+=1
|
extracted_texts.extend(text)
|
||||||
|
|
||||||
if indexer_b_s==self.b_s:
|
|
||||||
imgs = cropped_lines[:]
|
|
||||||
cropped_lines = []
|
|
||||||
indexer_b_s = 0
|
|
||||||
|
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
|
||||||
pixel_values_merged.to(self.device))
|
|
||||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
generated_ids_merged, skip_special_tokens=True)
|
|
||||||
|
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
indexer_text_region = indexer_text_region +1
|
|
||||||
|
|
||||||
if indexer_b_s!=0:
|
|
||||||
imgs = cropped_lines[:]
|
|
||||||
cropped_lines = []
|
|
||||||
indexer_b_s = 0
|
|
||||||
|
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device))
|
|
||||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True)
|
|
||||||
|
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
####extracted_texts = []
|
|
||||||
####n_iterations = math.ceil(len(cropped_lines) / self.b_s)
|
|
||||||
|
|
||||||
####for i in range(n_iterations):
|
|
||||||
####if i==(n_iterations-1):
|
|
||||||
####n_start = i*self.b_s
|
|
||||||
####imgs = cropped_lines[n_start:]
|
|
||||||
####else:
|
|
||||||
####n_start = i*self.b_s
|
|
||||||
####n_end = (i+1)*self.b_s
|
|
||||||
####imgs = cropped_lines[n_start:n_end]
|
|
||||||
####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
|
||||||
####generated_ids_merged = self.model_ocr.generate(
|
|
||||||
#### pixel_values_merged.to(self.device))
|
|
||||||
####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
#### generated_ids_merged, skip_special_tokens=True)
|
|
||||||
|
|
||||||
####extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
del cropped_lines
|
del cropped_lines
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
extracted_texts_merged = [extracted_texts[ind]
|
extracted_texts_merged = [extracted_texts[ind]
|
||||||
if cropped_lines_meging_indexing[ind]==0
|
if cropped_lines_meging_indexing[ind] == 0
|
||||||
else extracted_texts[ind]+" "+extracted_texts[ind+1]
|
else extracted_texts[ind] + " " + extracted_texts[ind + 1]
|
||||||
if cropped_lines_meging_indexing[ind]==1
|
for ind in range(len(cropped_lines_meging_indexing))
|
||||||
else None
|
if cropped_lines_meging_indexing[ind] >= 0]
|
||||||
for ind in range(len(cropped_lines_meging_indexing))]
|
extracted_confs_merged = [extracted_confs[ind]
|
||||||
|
if cropped_lines_meging_indexing[ind] == 0
|
||||||
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
|
else 0.5 * (extracted_confs[ind] + extracted_confs[ind + 1])
|
||||||
#print(extracted_texts_merged, len(extracted_texts_merged))
|
for ind in range(len(cropped_lines_meging_indexing))
|
||||||
|
if cropped_lines_meging_indexing[ind] >= 0]
|
||||||
|
|
||||||
return EynollahOcrResult(
|
return EynollahOcrResult(
|
||||||
extracted_texts_merged=extracted_texts_merged,
|
extracted_texts_merged=extracted_texts_merged,
|
||||||
extracted_conf_value_merged=None,
|
extracted_conf_value_merged=extracted_confs_merged,
|
||||||
cropped_lines_region_indexer=cropped_lines_region_indexer,
|
cropped_lines_region_indexer=cropped_lines_region_indexer,
|
||||||
total_bb_coordinates=total_bb_coordinates,
|
total_bb_coordinates=total_bb_coordinates,
|
||||||
)
|
)
|
||||||
|
|
@ -717,6 +635,7 @@ class Eynollah_ocr:
|
||||||
|
|
||||||
has_textline = False
|
has_textline = False
|
||||||
for child_textregion in nn:
|
for child_textregion in nn:
|
||||||
|
# FIXME: should remove Word level, if it already exists
|
||||||
if child_textregion.tag.endswith("TextLine"):
|
if child_textregion.tag.endswith("TextLine"):
|
||||||
|
|
||||||
is_textline_text = False
|
is_textline_text = False
|
||||||
|
|
@ -754,6 +673,7 @@ class Eynollah_ocr:
|
||||||
indexer_textregion = indexer_textregion + 1
|
indexer_textregion = indexer_textregion + 1
|
||||||
|
|
||||||
ET.register_namespace("",page_ns)
|
ET.register_namespace("",page_ns)
|
||||||
|
self.logger.info("output filename: '%s'", out_file_ocr)
|
||||||
page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None)
|
page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None)
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,7 @@ import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import statistics
|
import statistics
|
||||||
|
|
||||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
from .eynollah import Eynollah
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
from .model_zoo import EynollahModelZoo
|
from .model_zoo import EynollahModelZoo
|
||||||
from .utils.resize import resize_image
|
from .utils.resize import resize_image
|
||||||
from .utils.contour import (
|
from .utils.contour import (
|
||||||
|
|
@ -33,23 +31,27 @@ DPI_THRESHOLD = 298
|
||||||
KERNEL = np.ones((5, 5), np.uint8)
|
KERNEL = np.ones((5, 5), np.uint8)
|
||||||
|
|
||||||
|
|
||||||
class machine_based_reading_order_on_layout:
|
class Reorder(Eynollah):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
model_zoo: EynollahModelZoo,
|
model_zoo: EynollahModelZoo,
|
||||||
logger : Optional[logging.Logger] = None,
|
logger : Optional[logging.Logger] = None,
|
||||||
|
device: str = '',
|
||||||
):
|
):
|
||||||
self.logger = logger or logging.getLogger('eynollah.mbreorder')
|
self.logger = logger or logging.getLogger('eynollah.mbreorder')
|
||||||
self.model_zoo = model_zoo
|
self.model_zoo = model_zoo
|
||||||
|
|
||||||
try:
|
self.model_zoo.load_model('reading_order')
|
||||||
for device in tf.config.list_physical_devices('GPU'):
|
self.setup_models(device=device)
|
||||||
tf.config.experimental.set_memory_growth(device, True)
|
|
||||||
except:
|
def setup_models(self, device=''):
|
||||||
self.logger.warning("no GPU device available")
|
loadable = ['reading_order']
|
||||||
|
self.model_zoo.load_models(*loadable, device=device)
|
||||||
|
for model in loadable:
|
||||||
|
self.logger.debug("model %s has input shape %s", model,
|
||||||
|
self.model_zoo.get(model).input_shape)
|
||||||
|
|
||||||
self.model_zoo.load_models('reading_order')
|
|
||||||
|
|
||||||
def read_xml(self, xml_file):
|
def read_xml(self, xml_file):
|
||||||
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
||||||
|
|
@ -675,7 +677,7 @@ class machine_based_reading_order_on_layout:
|
||||||
tot_counter += 1
|
tot_counter += 1
|
||||||
batch.append(j)
|
batch.append(j)
|
||||||
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
|
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
|
||||||
y_pr = self.model_zoo.get('reading_order').predict(input_1 , verbose='0')
|
y_pr = self.model_zoo.get('reading_order').predict(input_1, verbose=0)
|
||||||
for jb, j in enumerate(batch):
|
for jb, j in enumerate(batch):
|
||||||
if y_pr[jb][0]>=0.5:
|
if y_pr[jb][0]>=0.5:
|
||||||
post_list.append(j)
|
post_list.append(j)
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -35,7 +35,7 @@ class EynollahModelZoo:
|
||||||
self._overrides = []
|
self._overrides = []
|
||||||
if model_overrides:
|
if model_overrides:
|
||||||
self.override_models(*model_overrides)
|
self.override_models(*model_overrides)
|
||||||
self._loaded: Dict[str, Predictor] = {}
|
self._loaded: Dict[str, Union[Predictor, AnyModel]] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_overrides(self):
|
def model_overrides(self):
|
||||||
|
|
@ -70,6 +70,9 @@ class EynollahModelZoo:
|
||||||
model_path = Path(self.model_basedir).joinpath(spec.filename)
|
model_path = Path(self.model_basedir).joinpath(spec.filename)
|
||||||
else:
|
else:
|
||||||
model_path = Path(spec.filename)
|
model_path = Path(spec.filename)
|
||||||
|
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
||||||
|
# prefer SavedModel over HDF5 format if it exists
|
||||||
|
model_path = Path(model_path.stem)
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
def load_models(
|
def load_models(
|
||||||
|
|
@ -82,32 +85,50 @@ class EynollahModelZoo:
|
||||||
"""
|
"""
|
||||||
ret = {} # cannot use self._loaded here, yet – first spawn all predictors
|
ret = {} # cannot use self._loaded here, yet – first spawn all predictors
|
||||||
for load_args in all_load_args:
|
for load_args in all_load_args:
|
||||||
|
load_kwargs = dict(device=device)
|
||||||
if isinstance(load_args, str):
|
if isinstance(load_args, str):
|
||||||
model_category = load_args
|
model_category, model_variant = load_args, ""
|
||||||
load_args = [model_category]
|
elif len(load_args) > 2:
|
||||||
|
# for calls to self.model_path
|
||||||
|
self.override_models(load_args)
|
||||||
|
# for calls to Predictor.load_model
|
||||||
|
model_category, model_variant, model_path = load_args
|
||||||
|
load_kwargs["model_variant"] = model_variant
|
||||||
|
load_kwargs["model_path_override"] = model_path
|
||||||
else:
|
else:
|
||||||
model_category = load_args[0]
|
model_category, model_variant = load_args
|
||||||
load_kwargs = {}
|
load_kwargs["model_variant"] = model_variant
|
||||||
|
|
||||||
if model_category.endswith('_resized'):
|
if model_category.endswith('_resized'):
|
||||||
load_args[0] = model_category[:-8]
|
model_category = model_category[:-8]
|
||||||
load_kwargs["resized"] = True
|
load_kwargs["resized"] = True
|
||||||
elif model_category.endswith('_patched'):
|
elif model_category.endswith('_patched'):
|
||||||
load_args[0] = model_category[:-8]
|
model_category = model_category[:-8]
|
||||||
load_kwargs["patched"] = True
|
load_kwargs["patched"] = True
|
||||||
spec = self.specs.get(model_category, load_args[1] if len(load_args) > 1 else '')
|
|
||||||
if spec.type in ['Keras'] and spec.category != 'ocr':
|
if model_category == 'ocr':
|
||||||
ret[model_category] = Predictor(self.logger, self)
|
model = self._load_ocr_model(variant=model_variant, device=device)
|
||||||
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
|
elif model_category == 'num_to_char':
|
||||||
|
model = self._load_num_to_char()
|
||||||
|
elif model_category == 'characters':
|
||||||
|
model = self._load_characters()
|
||||||
|
elif model_category == 'trocr_processor':
|
||||||
|
from transformers import TrOCRProcessor
|
||||||
|
model_path = self.model_path(model_category, model_variant)
|
||||||
|
model = TrOCRProcessor.from_pretrained(model_path)
|
||||||
else:
|
else:
|
||||||
ret[model_category] = self.load_model(*load_args, **load_kwargs, device=device)
|
model = Predictor(self.logger, self)
|
||||||
|
model.load_model(model_category, **load_kwargs)
|
||||||
|
|
||||||
|
ret[model_category] = model
|
||||||
self._loaded.update(ret)
|
self._loaded.update(ret)
|
||||||
return self._loaded
|
return self._loaded
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
self,
|
self,
|
||||||
model_category: str,
|
model_category: str,
|
||||||
model_variant: str = '',
|
model_variant: str = '',
|
||||||
model_path_override: Optional[str] = None,
|
model_path_override: Optional[str] = None,
|
||||||
patched: bool = False,
|
patched: bool = False,
|
||||||
resized: bool = False,
|
resized: bool = False,
|
||||||
device: str = '',
|
device: str = '',
|
||||||
|
|
@ -121,6 +142,7 @@ class EynollahModelZoo:
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
|
from tensorflow.keras.models import Model as KerasModel
|
||||||
|
|
||||||
from ..patch_encoder import (
|
from ..patch_encoder import (
|
||||||
PatchEncoder,
|
PatchEncoder,
|
||||||
|
|
@ -132,7 +154,7 @@ class EynollahModelZoo:
|
||||||
try:
|
try:
|
||||||
gpus = tf.config.list_physical_devices('GPU')
|
gpus = tf.config.list_physical_devices('GPU')
|
||||||
if device:
|
if device:
|
||||||
if ',' in device:
|
if ':' in device:
|
||||||
for spec in device.split(','):
|
for spec in device.split(','):
|
||||||
cat, dev = spec.split(':')
|
cat, dev = spec.split(':')
|
||||||
if fnmatchcase(model_category, cat):
|
if fnmatchcase(model_category, cat):
|
||||||
|
|
@ -147,7 +169,24 @@ class EynollahModelZoo:
|
||||||
gpus = gpus[:1] # TF will always use first allowable
|
gpus = gpus[:1] # TF will always use first allowable
|
||||||
tf.config.set_visible_devices(gpus, 'GPU')
|
tf.config.set_visible_devices(gpus, 'GPU')
|
||||||
for device in gpus:
|
for device in gpus:
|
||||||
tf.config.experimental.set_memory_growth(device, True)
|
# tf.config.experimental.set_memory_growth(device, True)
|
||||||
|
# dynamic growth never frees memory (to avoid fragmentation),
|
||||||
|
# so the VRAM requirements end up much larger than feasible
|
||||||
|
# (for small GPUs); so try hard (calibrated) limits instead:
|
||||||
|
tf.config.set_logical_device_configuration(
|
||||||
|
device,
|
||||||
|
[tf.config.LogicalDeviceConfiguration(memory_limit={
|
||||||
|
"binarization": 868, # due to bs 5
|
||||||
|
"enhancement": 980, # due to bs 3
|
||||||
|
"col_classifier": 210,
|
||||||
|
"page": 618,
|
||||||
|
"textline": 1680, # 954 for bs 1
|
||||||
|
"region_1_2": 1580,
|
||||||
|
"region_fl_np": 1756,
|
||||||
|
"table": 1818,
|
||||||
|
"reading_order": 632,
|
||||||
|
"ocr": 850,
|
||||||
|
}[model_category])])
|
||||||
vendor_name = (
|
vendor_name = (
|
||||||
tf.config.experimental.get_device_details(device)
|
tf.config.experimental.get_device_details(device)
|
||||||
.get('device_name', 'unknown'))
|
.get('device_name', 'unknown'))
|
||||||
|
|
@ -166,65 +205,76 @@ class EynollahModelZoo:
|
||||||
if model_path_override:
|
if model_path_override:
|
||||||
self.override_models((model_category, model_variant, model_path_override))
|
self.override_models((model_category, model_variant, model_path_override))
|
||||||
model_path = self.model_path(model_category, model_variant)
|
model_path = self.model_path(model_category, model_variant)
|
||||||
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
|
try:
|
||||||
# prefer SavedModel over HDF5 format if it exists
|
if model_path.is_dir() and not (model_path / "keras_metadata.pb").exists():
|
||||||
model_path = Path(model_path.stem)
|
# short-cut to avoid warning for exported models
|
||||||
if model_category == 'ocr':
|
raise ValueError()
|
||||||
model = self._load_ocr_model(variant=model_variant)
|
model = load_model(model_path, compile=False)
|
||||||
elif model_category == 'num_to_char':
|
|
||||||
model = self._load_num_to_char()
|
|
||||||
elif model_category == 'characters':
|
|
||||||
model = self._load_characters()
|
|
||||||
elif model_category == 'trocr_processor':
|
|
||||||
from transformers import TrOCRProcessor
|
|
||||||
model = TrOCRProcessor.from_pretrained(model_path)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
# avoid wasting VRAM on non-transformer models
|
|
||||||
model = load_model(model_path, compile=False)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(e)
|
|
||||||
model = load_model(
|
|
||||||
model_path, compile=False,
|
|
||||||
custom_objects=dict(PatchEncoder=PatchEncoder,
|
|
||||||
Patches=Patches))
|
|
||||||
model._name = model_category
|
|
||||||
if resized:
|
|
||||||
model = wrap_layout_model_resized(model)
|
|
||||||
model._name = model_category + '_resized'
|
|
||||||
elif patched:
|
|
||||||
model = wrap_layout_model_patched(model)
|
|
||||||
model._name = model_category + '_patched'
|
|
||||||
else:
|
|
||||||
model.jit_compile = True
|
|
||||||
model.make_predict_function()
|
model.make_predict_function()
|
||||||
|
except (AttributeError, ValueError):
|
||||||
|
model = tf.saved_model.load(model_path)
|
||||||
|
model.predict_on_batch = model.serve
|
||||||
|
model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape)
|
||||||
|
model._name = model_category
|
||||||
|
if resized:
|
||||||
|
model = wrap_layout_model_resized(model)
|
||||||
|
model._name = model_category + '_resized'
|
||||||
|
elif patched:
|
||||||
|
model = wrap_layout_model_patched(model)
|
||||||
|
model._name = model_category + '_patched'
|
||||||
|
else:
|
||||||
|
# increases required VRAM, does not always work
|
||||||
|
# (depending on CUDA/libcudnn/TF version):
|
||||||
|
#model.jit_compile = True
|
||||||
|
pass
|
||||||
|
|
||||||
|
if model_category == 'ocr':
|
||||||
|
model = KerasModel(
|
||||||
|
model.get_layer(name="image").input, # type: ignore
|
||||||
|
model.get_layer(name="dense2").output, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get(self, model_category: str) -> Predictor:
|
def get(self, model_category: str) -> Union[Predictor, AnyModel]:
|
||||||
if model_category not in self._loaded:
|
if model_category not in self._loaded:
|
||||||
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
|
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
|
||||||
return self._loaded[model_category]
|
return self._loaded[model_category]
|
||||||
|
|
||||||
def _load_ocr_model(self, variant: str) -> AnyModel:
|
def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel:
|
||||||
"""
|
"""
|
||||||
Load OCR model
|
Load OCR model
|
||||||
"""
|
"""
|
||||||
from tensorflow.keras.models import Model as KerasModel
|
model_dir = self.model_path('ocr', variant)
|
||||||
from tensorflow.keras.models import load_model
|
|
||||||
|
|
||||||
ocr_model_dir = self.model_path('ocr', variant)
|
|
||||||
if variant == 'tr':
|
if variant == 'tr':
|
||||||
from transformers import VisionEncoderDecoderModel
|
from transformers import VisionEncoderDecoderModel
|
||||||
ret = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
|
import torch
|
||||||
assert isinstance(ret, VisionEncoderDecoderModel)
|
model = VisionEncoderDecoderModel.from_pretrained(model_dir)
|
||||||
return ret
|
assert isinstance(model, VisionEncoderDecoderModel)
|
||||||
else:
|
device0 = torch.device('cpu')
|
||||||
ocr_model = load_model(ocr_model_dir, compile=False)
|
if not device and torch.cuda.is_available():
|
||||||
assert isinstance(ocr_model, KerasModel)
|
device = 'GPU' # try
|
||||||
return KerasModel(
|
if device and ':' in device:
|
||||||
ocr_model.get_layer(name="image").input, # type: ignore
|
for spec in device.split(','):
|
||||||
ocr_model.get_layer(name="dense2").output, # type: ignore
|
cat, dev = spec.split(':')
|
||||||
)
|
if fnmatchcase('ocr', cat):
|
||||||
|
device = dev
|
||||||
|
break
|
||||||
|
if device and device.startswith('GPU'):
|
||||||
|
try:
|
||||||
|
device0 = torch.device('cuda', int(device[3:] or 0))
|
||||||
|
name = torch.cuda.get_device_name(device0)
|
||||||
|
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
|
||||||
|
except:
|
||||||
|
self.logger.exception("cannot configure GPU device")
|
||||||
|
device0 = torch.device('cpu')
|
||||||
|
if device0.type == 'cuda':
|
||||||
|
model.to(device0)
|
||||||
|
else:
|
||||||
|
self.logger.warning("no GPU device available")
|
||||||
|
return model
|
||||||
|
|
||||||
|
return self.load_model('ocr', model_variant=variant, device=device)
|
||||||
|
|
||||||
def _load_characters(self) -> List[str]:
|
def _load_characters(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -237,6 +287,10 @@ class EynollahModelZoo:
|
||||||
"""
|
"""
|
||||||
Load decoder for OCR
|
Load decoder for OCR
|
||||||
"""
|
"""
|
||||||
|
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||||
|
from ocrd_utils import tf_disable_interactive_logs
|
||||||
|
tf_disable_interactive_logs()
|
||||||
|
|
||||||
from tensorflow.keras.layers import StringLookup
|
from tensorflow.keras.layers import StringLookup
|
||||||
|
|
||||||
characters = self._load_characters()
|
characters = self._load_characters()
|
||||||
|
|
@ -277,5 +331,5 @@ class EynollahModelZoo:
|
||||||
"""
|
"""
|
||||||
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
|
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
|
||||||
for needle in list(self._loaded.keys()):
|
for needle in list(self._loaded.keys()):
|
||||||
self._loaded[needle].shutdown()
|
if isinstance(self._loaded[needle], Predictor):
|
||||||
del self._loaded[needle]
|
self._loaded[needle].shutdown()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
# NOTE: For predictable order of imports of torch/shapely/tensorflow
|
|
||||||
# this must be the first import of the CLI!
|
|
||||||
from .eynollah_imports import imported_libs
|
|
||||||
from .processor import EynollahProcessor
|
|
||||||
from click import command
|
from click import command
|
||||||
from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor
|
from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor
|
||||||
|
|
||||||
|
from .processor import EynollahProcessor
|
||||||
|
|
||||||
@command()
|
@command()
|
||||||
@ocrd_cli_options
|
@ocrd_cli_options
|
||||||
def main(*args, **kwargs):
|
def main(*args, **kwargs):
|
||||||
|
|
|
||||||
|
|
@ -194,17 +194,18 @@ class Predictor(mp.context.SpawnProcess):
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
# do not terminate from forked processor instances
|
# do not terminate from forked processor instances
|
||||||
if mp.parent_process() is None:
|
if not hasattr(self, 'model'):
|
||||||
self.stopped.set()
|
self.stopped.set()
|
||||||
|
self.join()
|
||||||
self.taskq.close()
|
self.taskq.close()
|
||||||
self.taskq.cancel_join_thread()
|
self.taskq.cancel_join_thread()
|
||||||
self.resultq.close()
|
self.resultq.close()
|
||||||
self.resultq.cancel_join_thread()
|
self.resultq.cancel_join_thread()
|
||||||
self.logq.close()
|
self.logq.close()
|
||||||
self.terminate()
|
#self.terminate()
|
||||||
else:
|
else:
|
||||||
del self.model
|
del self.model
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
#self.logger.debug(f"deinit of {self} in {mp.current_process().name}")
|
#self.logger.debug(f"deinit of {self.name} in {mp.current_process().name}")
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
|
|
|
||||||
|
|
@ -309,11 +309,10 @@ def transformer_block(img,
|
||||||
# Skip connection 2.
|
# Skip connection 2.
|
||||||
encoded_patches = Add()([x3, x2])
|
encoded_patches = Add()([x3, x2])
|
||||||
|
|
||||||
encoded_patches = tf.reshape(encoded_patches,
|
encoded_patches = Reshape(target_shape=(img.shape[1],
|
||||||
[-1,
|
img.shape[2],
|
||||||
img.shape[1],
|
projection_dim // (patchsize_x * patchsize_y)),
|
||||||
img.shape[2],
|
name="reshape_patches")(encoded_patches)
|
||||||
projection_dim // (patchsize_x * patchsize_y)])
|
|
||||||
return encoded_patches
|
return encoded_patches
|
||||||
|
|
||||||
def vit_resnet50_unet(num_patches,
|
def vit_resnet50_unet(num_patches,
|
||||||
|
|
|
||||||
|
|
@ -26,16 +26,17 @@ RELOADABLE_MODELS = \
|
||||||
all: $(RELOADABLE_MODELS)
|
all: $(RELOADABLE_MODELS)
|
||||||
|
|
||||||
$(MODELS_DST)/%: $(MODELS_SRC)/%
|
$(MODELS_DST)/%: $(MODELS_SRC)/%
|
||||||
mkdir -p $@
|
|
||||||
test -e $</config.json || exit 1
|
test -e $</config.json || exit 1
|
||||||
eynollah-training train --force \
|
{ mkdir -p $@ \
|
||||||
|
&& eynollah-training train --force \
|
||||||
with $</config.json \
|
with $</config.json \
|
||||||
reload_weights=True \
|
reload_weights=True \
|
||||||
continue_training=False \
|
continue_training=False \
|
||||||
dir_output=$(dir $@) \
|
dir_output=$(dir $@) \
|
||||||
dir_of_start_model=$< \
|
dir_of_start_model=$< \
|
||||||
|
&& cp $</config.json $@/config.json \
|
||||||
|
|| { rm -rf $@; false; }; } \
|
||||||
2>&1 | tee $(notdir $<).log
|
2>&1 | tee $(notdir $<).log
|
||||||
cp $</config.json $@/config.json
|
|
||||||
|
|
||||||
compare:
|
compare:
|
||||||
for i in `find $(MODELS_DST) -mindepth 2`;do \
|
for i in `find $(MODELS_DST) -mindepth 2`;do \
|
||||||
|
|
|
||||||
|
|
@ -562,7 +562,8 @@ def run(_config,
|
||||||
if reload_weights:
|
if reload_weights:
|
||||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||||
model.save(dir_save, include_optimizer=False)
|
#model.save(dir_save, include_optimizer=False)
|
||||||
|
model.export(dir_save)
|
||||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||||
|
|
@ -725,7 +726,8 @@ def run(_config,
|
||||||
if reload_weights:
|
if reload_weights:
|
||||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||||
model.save(dir_save, include_optimizer=False)
|
#model.save(dir_save, include_optimizer=False)
|
||||||
|
model.export(dir_save)
|
||||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||||
|
|
@ -843,7 +845,8 @@ def run(_config,
|
||||||
if reload_weights:
|
if reload_weights:
|
||||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
||||||
model.save(dir_save, include_optimizer=False)
|
#model.save(dir_save, include_optimizer=False)
|
||||||
|
model.export(dir_save)
|
||||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from shapely.geometry.polygon import orient
|
||||||
from shapely import set_precision, affinity
|
from shapely import set_precision, affinity
|
||||||
from shapely.ops import unary_union, nearest_points
|
from shapely.ops import unary_union, nearest_points
|
||||||
|
|
||||||
from .rotate import rotate_image, rotation_image_new
|
from .rotate import rotate_image
|
||||||
|
|
||||||
def contours_in_same_horizon(cy_main_hor):
|
def contours_in_same_horizon(cy_main_hor):
|
||||||
"""
|
"""
|
||||||
|
|
@ -120,94 +120,6 @@ def return_contours_of_interested_region(region_pre_p, label, min_area=0.0002, d
|
||||||
dilate=dilate)
|
dilate=dilate)
|
||||||
return contours_imgs
|
return contours_imgs
|
||||||
|
|
||||||
def do_work_of_contours_in_image(contour, index_r_con, img, slope_first):
|
|
||||||
img_copy = np.zeros(img.shape[:2], dtype=np.uint8)
|
|
||||||
img_copy = cv2.fillPoly(img_copy, pts=[contour], color=1)
|
|
||||||
|
|
||||||
img_copy = rotation_image_new(img_copy, -slope_first)
|
|
||||||
_, thresh = cv2.threshold(img_copy, 0, 255, 0)
|
|
||||||
|
|
||||||
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
|
||||||
|
|
||||||
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
|
||||||
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
|
|
||||||
|
|
||||||
return cont_int[0], index_r_con
|
|
||||||
|
|
||||||
def get_textregion_contours_in_org_image_multi(cnts, img, slope_first, map=map):
|
|
||||||
if not len(cnts):
|
|
||||||
return [], []
|
|
||||||
results = map(partial(do_work_of_contours_in_image,
|
|
||||||
img=img,
|
|
||||||
slope_first=slope_first,
|
|
||||||
),
|
|
||||||
cnts, range(len(cnts)))
|
|
||||||
return tuple(zip(*results))
|
|
||||||
|
|
||||||
def get_textregion_contours_in_org_image(cnts, img, slope_first):
|
|
||||||
cnts_org = []
|
|
||||||
# print(cnts,'cnts')
|
|
||||||
for i in range(len(cnts)):
|
|
||||||
img_copy = np.zeros(img.shape[:2], dtype=np.uint8)
|
|
||||||
img_copy = cv2.fillPoly(img_copy, pts=[cnts[i]], color=1)
|
|
||||||
|
|
||||||
# plt.imshow(img_copy)
|
|
||||||
# plt.show()
|
|
||||||
|
|
||||||
# print(img.shape,'img')
|
|
||||||
img_copy = rotation_image_new(img_copy, -slope_first)
|
|
||||||
##print(img_copy.shape,'img_copy')
|
|
||||||
# plt.imshow(img_copy)
|
|
||||||
# plt.show()
|
|
||||||
|
|
||||||
_, thresh = cv2.threshold(img_copy, 0, 255, 0)
|
|
||||||
|
|
||||||
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
|
||||||
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
|
||||||
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
|
|
||||||
# print(np.shape(cont_int[0]))
|
|
||||||
cnts_org.append(cont_int[0])
|
|
||||||
|
|
||||||
return cnts_org
|
|
||||||
|
|
||||||
def get_textregion_confidences_old(cnts, img, slope_first):
|
|
||||||
zoom = 3
|
|
||||||
img = cv2.resize(img, (img.shape[1] // zoom,
|
|
||||||
img.shape[0] // zoom),
|
|
||||||
interpolation=cv2.INTER_NEAREST)
|
|
||||||
cnts_org = []
|
|
||||||
for cnt in cnts:
|
|
||||||
img_copy = np.zeros(img.shape[:2], dtype=np.uint8)
|
|
||||||
img_copy = cv2.fillPoly(img_copy, pts=[cnt // zoom], color=1)
|
|
||||||
|
|
||||||
img_copy = rotation_image_new(img_copy, -slope_first).astype(np.uint8)
|
|
||||||
_, thresh = cv2.threshold(img_copy, 0, 255, 0)
|
|
||||||
|
|
||||||
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
|
||||||
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
|
||||||
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
|
|
||||||
cnts_org.append(cont_int[0] * zoom)
|
|
||||||
|
|
||||||
return cnts_org
|
|
||||||
|
|
||||||
def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first, confidence_matrix):
|
|
||||||
img_copy = np.zeros(img.shape[:2], dtype=np.uint8)
|
|
||||||
img_copy = cv2.fillPoly(img_copy, pts=[contour_par], color=1)
|
|
||||||
confidence_matrix_mapped_with_contour = confidence_matrix * img_copy
|
|
||||||
confidence_contour = np.sum(confidence_matrix_mapped_with_contour) / float(np.sum(img_copy))
|
|
||||||
|
|
||||||
img_copy = rotation_image_new(img_copy, -slope_first).astype(np.uint8)
|
|
||||||
_, thresh = cv2.threshold(img_copy, 0, 255, 0)
|
|
||||||
|
|
||||||
cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
|
||||||
if len(cont_int)==0:
|
|
||||||
cont_int = [contour_par]
|
|
||||||
confidence_contour = 0
|
|
||||||
else:
|
|
||||||
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
|
|
||||||
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
|
|
||||||
return cont_int[0], index_r_con, confidence_contour
|
|
||||||
|
|
||||||
def get_region_confidences(cnts, confidence_matrix):
|
def get_region_confidences(cnts, confidence_matrix):
|
||||||
if not len(cnts):
|
if not len(cnts):
|
||||||
return []
|
return []
|
||||||
|
|
@ -418,7 +330,7 @@ def estimate_skew_contours(contours):
|
||||||
if not np.any(usable):
|
if not np.any(usable):
|
||||||
raise ValueError("not enough contours with consistent length")
|
raise ValueError("not enough contours with consistent length")
|
||||||
if np.count_nonzero(usable) == 1:
|
if np.count_nonzero(usable) == 1:
|
||||||
return angle_in[usable]
|
return angle_in[usable][0]
|
||||||
# 4. there is no way to distinguish between +90 and -89.9 here,
|
# 4. there is no way to distinguish between +90 and -89.9 here,
|
||||||
# so map to [0,180] when calculating averages, then map back to [-90,90]
|
# so map to [0,180] when calculating averages, then map back to [-90,90]
|
||||||
# (we don't want -90 and +89 to average zero, or +1 and +179 to average 90)
|
# (we don't want -90 and +89 to average zero, or +1 and +179 to average 90)
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,6 @@ import math
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
def rotation_image_new(img, thetha):
|
|
||||||
rotated = rotate_image(img, thetha)
|
|
||||||
return rotate_max_area_new(img, rotated, thetha)
|
|
||||||
|
|
||||||
def rotate_image(img_patch, slope):
|
def rotate_image(img_patch, slope):
|
||||||
(h, w) = img_patch.shape[:2]
|
(h, w) = img_patch.shape[:2]
|
||||||
center = (w // 2, h // 2)
|
center = (w // 2, h // 2)
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import math
|
import math
|
||||||
import copy
|
import copy
|
||||||
|
from itertools import islice
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
import tensorflow as tf
|
# avoid module-level import:
|
||||||
|
# import tensorflow as tf
|
||||||
|
# (wait for tf-keras and logging setup in ModelZoo.load_model)
|
||||||
from scipy.signal import find_peaks
|
from scipy.signal import find_peaks
|
||||||
from scipy.ndimage import gaussian_filter1d
|
from scipy.ndimage import gaussian_filter1d
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
@ -12,6 +15,8 @@ from .resize import resize_image
|
||||||
|
|
||||||
|
|
||||||
def decode_batch_predictions(pred, num_to_char, max_len = 128):
|
def decode_batch_predictions(pred, num_to_char, max_len = 128):
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
# input_len is the product of the batch size and the
|
# input_len is the product of the batch size and the
|
||||||
# number of time steps.
|
# number of time steps.
|
||||||
input_len = np.ones(pred.shape[0]) * pred.shape[1]
|
input_len = np.ones(pred.shape[0]) * pred.shape[1]
|
||||||
|
|
@ -39,6 +44,8 @@ def decode_batch_predictions(pred, num_to_char, max_len = 128):
|
||||||
|
|
||||||
|
|
||||||
def distortion_free_resize(image, img_size):
|
def distortion_free_resize(image, img_size):
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
w, h = img_size
|
w, h = img_size
|
||||||
image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)
|
image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)
|
||||||
|
|
||||||
|
|
@ -502,3 +509,8 @@ def return_rnn_cnn_ocr_of_given_textlines(image,
|
||||||
ocr_textline_in_textregion.append(text_textline)
|
ocr_textline_in_textregion.append(text_textline)
|
||||||
ocr_all_textlines.append(ocr_textline_in_textregion)
|
ocr_all_textlines.append(ocr_textline_in_textregion)
|
||||||
return ocr_all_textlines
|
return ocr_all_textlines
|
||||||
|
|
||||||
|
def batched(iterable, n):
|
||||||
|
iterator = iter(iterable)
|
||||||
|
while batch := tuple(islice(iterator, n)):
|
||||||
|
yield batch
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
@ -31,6 +32,8 @@ def run_eynollah_ok_and_check_logs(
|
||||||
subcommand,
|
subcommand,
|
||||||
*args
|
*args
|
||||||
]
|
]
|
||||||
|
if 'EYNOLLAH_OPTIONS' in os.environ:
|
||||||
|
args = os.environ['EYNOLLAH_OPTIONS'].split() + args
|
||||||
if pytestconfig.getoption('verbose') > 0:
|
if pytestconfig.getoption('verbose') > 0:
|
||||||
args = ['-l', 'DEBUG'] + args
|
args = ['-l', 'DEBUG'] + args
|
||||||
caplog.set_level(logging.INFO)
|
caplog.set_level(logging.INFO)
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,12 @@ from ocrd_models.constants import NAMESPACES as NS
|
||||||
"options",
|
"options",
|
||||||
[
|
[
|
||||||
[], # defaults
|
[], # defaults
|
||||||
#["--allow_scaling", "--curved-line"],
|
#["--curved-line"],
|
||||||
["--allow_scaling", "--curved-line", "--full-layout"],
|
["--curved-line", "--full-layout"],
|
||||||
["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based"],
|
["--curved-line", "--full-layout", "--reading_order_machine_based"],
|
||||||
# -ep ...
|
# -ep ...
|
||||||
# -eoi ...
|
# --input_binary
|
||||||
|
# --ignore_page_extraction
|
||||||
# --skip_layout_and_reading_order
|
# --skip_layout_and_reading_order
|
||||||
], ids=str)
|
], ids=str)
|
||||||
def test_run_eynollah_layout_filename(
|
def test_run_eynollah_layout_filename(
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ def test_run_eynollah_ocr_filename(
|
||||||
'-o', str(outfile.parent),
|
'-o', str(outfile.parent),
|
||||||
] + options,
|
] + options,
|
||||||
[
|
[
|
||||||
# FIXME: ocr has no logging!
|
'output filename:'
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert outfile.exists()
|
assert outfile.exists()
|
||||||
|
|
@ -57,7 +57,7 @@ def test_run_eynollah_ocr_directory(
|
||||||
'-o', str(outdir),
|
'-o', str(outdir),
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
# FIXME: ocr has no logging!
|
'output filename:'
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert len(list(outdir.iterdir())) == 2
|
assert len(list(outdir.iterdir())) == 2
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,10 @@ def test_trocr1(
|
||||||
model_zoo = EynollahModelZoo(model_dir)
|
model_zoo = EynollahModelZoo(model_dir)
|
||||||
try:
|
try:
|
||||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||||
model_zoo.load_models('trocr_processor')
|
model_zoo.load_models('trocr_processor',
|
||||||
|
('ocr', 'tr'))
|
||||||
proc = model_zoo.get('trocr_processor')
|
proc = model_zoo.get('trocr_processor')
|
||||||
assert isinstance(proc, TrOCRProcessor)
|
assert isinstance(proc, TrOCRProcessor)
|
||||||
model_zoo.load_models(['ocr', 'tr'])
|
|
||||||
model = model_zoo.get('ocr')
|
model = model_zoo.get('ocr')
|
||||||
assert isinstance(model, VisionEncoderDecoderModel)
|
assert isinstance(model, VisionEncoderDecoderModel)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue