Merge remote-tracking branch 'bertsky/fix-0.8-modelzoo-and-predictor' into integrating_trocr_and_torch_ensembling_and_updating_characters_list

# Conflicts:
#	src/eynollah/eynollah_ocr.py
This commit is contained in:
kba 2026-06-11 19:07:19 +02:00
commit 28a559c710
32 changed files with 1132 additions and 1418 deletions

View file

@ -1,7 +1,3 @@
# NOTE: For predictable order of imports of torch/shapely/tensorflow
# this must be the first import of the CLI!
from ..eynollah_imports import imported_libs
from .cli import main
from .cli_binarize import binarize_cli
from .cli_enhance import enhance_cli

View file

@ -15,6 +15,7 @@ class EynollahCliCtx:
Holds options relevant for all eynollah subcommands
"""
model_zoo: EynollahModelZoo
device: str = ''
log_level : Union[str, None] = 'INFO'
@ -35,6 +36,11 @@ class EynollahCliCtx:
type=(str, str, str),
multiple=True,
)
@click.option(
"--device",
"-D",
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
)
@click.option(
"--log_level",
"-l",
@ -42,7 +48,7 @@ class EynollahCliCtx:
help="Override log level globally to this",
)
@click.pass_context
def main(ctx, model_basedir, model_overrides, log_level):
def main(ctx, model_basedir, model_overrides, device, log_level):
"""
eynollah - Document Layout Analysis, Image Enhancement, OCR
"""
@ -58,6 +64,7 @@ def main(ctx, model_basedir, model_overrides, log_level):
# Initialize CLI context
ctx.obj = EynollahCliCtx(
model_zoo=model_zoo,
device=device,
log_level=log_level,
)

View file

@ -1,6 +1,8 @@
import click
@click.command()
@click.command(context_settings=dict(
help_option_names=['-h', '--help'],
show_default=True))
@click.option(
'--patches/--no-patches',
default=True,
@ -31,11 +33,6 @@ import click
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--device",
"-D",
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
)
@click.pass_context
def binarize_cli(
ctx,
@ -44,14 +41,14 @@ def binarize_cli(
dir_in,
output,
overwrite,
device,
):
"""
Binarize images with a ML model
"""
from ..sbb_binarize import SbbBinarizer
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo, device=device)
binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo,
device=ctx.obj.device)
binarizer.run(
image_filename=input_image,
use_patches=patches,

View file

@ -1,6 +1,8 @@
import click
@click.command()
@click.command(context_settings=dict(
help_option_names=['-h', '--help'],
show_default=True))
@click.option(
"--image",
"-i",
@ -46,13 +48,8 @@ import click
is_flag=True,
help="save the enhanced image in original image size",
)
@click.option(
"--device",
"-D",
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
)
@click.pass_context
def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale, device):
def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale):
"""
Enhance image
"""
@ -60,10 +57,10 @@ def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower
from ..image_enhancer import Enhancer
enhancer = Enhancer(
model_zoo=ctx.obj.model_zoo,
device=ctx.obj.device,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
save_org_scale=save_org_scale,
device=device,
)
enhancer.run(overwrite=overwrite,
dir_in=dir_in,

View file

@ -1,6 +1,8 @@
import click
@click.command()
@click.command(context_settings=dict(
help_option_names=['-h', '--help'],
show_default=True))
@click.option(
"--image",
"-i",
@ -30,36 +32,40 @@ import click
@click.option(
"--save_images",
"-si",
help="if a directory is given, images in documents will be cropped and saved there",
help="if a directory is given, cropped images of pages will be saved there",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--enable-plotting/--disable-plotting",
"-ep/-noep",
"--enable-plotting",
"-ep",
is_flag=True,
help="If set, will plot intermediary files and images",
help="plot intermediary diagnostic images to files",
)
@click.option(
"--input_binary/--input-RGB",
"-ib/-irgb",
"--input_binary",
"-ib",
is_flag=True,
help="In general, eynollah uses RGB as input but if the input document is very dark, very bright or for any other reason you can turn on input binarization. When this flag is set, eynollah will binarize the RGB input document, you should always provide RGB images to eynollah.",
help="In general, eynollah uses RGB as input, but if the input document is very dark, very bright or for any other reason you can turn on internal binarization here. When set, eynollah will binarize the RGB input document first.",
)
@click.option(
"--ignore_page_extraction/--extract_page_included",
"-ipe/-epi",
"--ignore_page_extraction",
"-ipe",
is_flag=True,
help="if this parameter set to true, this tool would ignore page extraction",
help="ignore page extraction (cropping via page frame detection model)",
)
@click.option(
"--num_col_upper",
"-ncu",
help="lower limit of columns in document image",
default=0,
type=click.IntRange(min=0),
help="upper limit of columns in document image; 0 means autodetected from model",
)
@click.option(
"--num_col_lower",
"-ncl",
help="upper limit of columns in document image",
default=0,
type=click.IntRange(min=0),
help="lower limit of columns in document image; 0 means autodetected from model",
)
@click.pass_context
def extract_images_cli(

View file

@ -172,11 +172,6 @@ import click
type=click.FloatRange(min=0),
help="abort when number of failed images exceeds this value (if >=1) or ratio of failed over total images exceeds this value (if <1); 0 means ignore failures",
)
@click.option(
"--device",
"-D",
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
)
@click.pass_context
def layout_cli(
ctx,
@ -207,7 +202,6 @@ def layout_cli(
ignore_page_extraction,
num_jobs,
halt_fail,
device,
):
"""
Detect Layout (with optional image enhancement and reading order detection)
@ -223,7 +217,7 @@ def layout_cli(
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
eynollah = Eynollah(
model_zoo=ctx.obj.model_zoo,
device=device,
device=ctx.obj.device,
enable_plotting=enable_plotting,
allow_enhancement=allow_enhancement,
curved_line=curved_line,

View file

@ -1,6 +1,8 @@
import click
@click.command()
@click.command(context_settings=dict(
help_option_names=['-h', '--help'],
show_default=True))
@click.option(
"--image",
"-i",
@ -16,7 +18,7 @@ import click
@click.option(
"--dir_in_bin",
"-dib",
help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png' \n Perform prediction using both RGB and binary images. (This does not necessarily improve results, however it may be beneficial for certain document images."),
help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png'. \n Perform prediction using both RGB and binary images. (This may improve results for certain document images.)"),
type=click.Path(exists=True, file_okay=False),
)
@click.option(
@ -47,25 +49,29 @@ import click
)
@click.option(
"--tr_ocr",
"-trocr/-notrocr",
"-trocr",
is_flag=True,
help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.",
help="use transformer OCR (instead of classic CNN-RNN) model",
)
@click.option(
"--do_not_mask_with_textline_contour",
"-nmtc/-mtc",
"-nmtc",
is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
help="skip masking each cropped textline image with its corresponding textline contour",
)
@click.option(
"--batch_size",
"-bs",
default=0,
type=click.IntRange(min=0),
help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
)
@click.option(
"--min_conf_value_of_textline_text",
"-min_conf",
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
default=0.3,
type=click.FloatRange(min=0.0, max=1.0),
help="minimum OCR confidence threshold. Text lines with a lower confidence value will not be included in the output XML file.",
)
@click.pass_context
def ocr_cli(
@ -85,14 +91,16 @@ def ocr_cli(
"""
Recognize text with a CNN/RNN or transformer ML model.
"""
assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
from ..eynollah_ocr import Eynollah_ocr
eynollah_ocr = Eynollah_ocr(
model_zoo=ctx.obj.model_zoo,
device=ctx.obj.device,
tr_ocr=tr_ocr,
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
batch_size=batch_size,
min_conf_value_of_textline_text=min_conf_value_of_textline_text)
min_conf_value_of_textline_text=min_conf_value_of_textline_text,
)
eynollah_ocr.run(overwrite=overwrite,
dir_in=dir_in,
dir_in_bin=dir_in_bin,

View file

@ -1,6 +1,8 @@
import click
@click.command()
@click.command(context_settings=dict(
help_option_names=['-h', '--help'],
show_default=True))
@click.option(
"--input",
"-i",
@ -25,9 +27,10 @@ def readingorder_cli(ctx, input, dir_in, out):
"""
Generate ReadingOrder with a ML model
"""
from ..mb_ro_on_layout import machine_based_reading_order_on_layout
from ..mb_ro_on_layout import Reorder
assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
orderer = machine_based_reading_order_on_layout(model_zoo=ctx.obj.model_zoo)
orderer = Reorder(model_zoo=ctx.obj.model_zoo,
device=ctx.obj.device)
orderer.run(xml_filename=input,
dir_in=dir_in,
dir_out=out,

View file

@ -9,7 +9,6 @@ import os
import time
from typing import Optional
from pathlib import Path
import tensorflow as tf
import numpy as np
import cv2
@ -64,12 +63,6 @@ class EynollahImageExtractor(Eynollah):
t_start = time.time()
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except:
self.logger.warning("no GPU device available")
self.logger.info("Loading models...")
self.setup_models()
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")

View file

@ -1148,7 +1148,6 @@ class Eynollah:
boxes,
textline_mask_tot
):
assert np.any(textline_mask_tot)
self.logger.debug("enter do_order_of_regions")
contours_only_text_parent = ensure_array(contours_only_text_parent)
contours_only_text_parent_h = ensure_array(contours_only_text_parent_h)

View file

@ -1,13 +0,0 @@
"""
Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
"""
import os
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs
from torch import *
tf_disable_interactive_logs()
import tensorflow.keras
from shapely import *
imported_libs = True
__all__ = ['imported_libs']

View file

@ -14,20 +14,21 @@ from cv2.typing import MatLike
from xml.etree import ElementTree as ET
from PIL import Image, ImageDraw
import numpy as np
from eynollah.model_zoo import EynollahModelZoo
from eynollah.utils.font import get_font
from eynollah.utils.xml import etree_namespace_for_element_tag
try:
import torch
except ImportError:
torch = None
from ocrd_utils import polygon_from_points, xywh_from_polygon
from .utils import is_image_filename
from .eynollah import Eynollah
from .model_zoo import EynollahModelZoo
from .utils import (
is_image_filename,
batched,
pairwise,
)
from .utils.font import get_font
from .utils.xml import etree_namespace_for_element_tag
from .utils.resize import resize_image
from .utils.utils_ocr import (
break_curved_line_into_small_pieces_and_then_merge,
decode_batch_predictions,
fit_text_single_line,
get_contours_and_bounding_boxes,
get_orientation_moments,
@ -40,50 +41,45 @@ from .utils.utils_ocr import (
@dataclass
class EynollahOcrResult:
extracted_texts_merged: List
extracted_conf_value_merged: Optional[List]
extracted_confs_merged: List
cropped_lines_region_indexer: List
total_bb_coordinates:List
class Eynollah_ocr:
class Eynollah_ocr(Eynollah):
def __init__(
self,
*,
model_zoo: EynollahModelZoo,
tr_ocr=False,
batch_size: Optional[int]=None,
batch_size: int=0,
do_not_mask_with_textline_contour: bool=False,
min_conf_value_of_textline_text : Optional[float]=None,
min_conf_value_of_textline_text : float=0.3,
logger: Optional[Logger]=None,
device: str = '',
):
self.tr_ocr = tr_ocr
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
self.logger = logger if logger else getLogger('eynollah.ocr')
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text
# An explicit (non-zero) batch size always wins; otherwise fall back to the
# per-model default (2 for TrOCR, 8 for CNN-RNN).  The parentheses matter:
# without them the conditional expression binds looser than `or`, i.e.
# `(batch_size or 2) if tr_ocr else 8`, which ignores batch_size for CNN-RNN.
self.b_s = batch_size or (2 if tr_ocr else 8)
self.model_zoo = model_zoo
self.setup_models(device=device)
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
if tr_ocr:
self.model_zoo.load_models('trocr_processor')
self.model_zoo.load_models(['ocr', 'tr'])
self.model_zoo.get('ocr').to(self.device)
def setup_models(self, device=''):
if self.tr_ocr:
self.model_zoo.load_models(('ocr', 'tr'),
device=device)
else:
self.model_zoo.load_model('ocr', '')
self.input_shape = self.model_zoo.get('ocr').input_shape[1:3]
self.model_zoo.load_model('num_to_char')
self.model_zoo.load_model('characters')
self.end_character = len(self.model_zoo.get('characters', list)) + 2
self.model_zoo.load_models('ocr',
'binarization',
device=device)
@property
def device(self):
assert torch
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
return torch.device("cuda:0")
else:
self.logger.info("Using CPU processing")
return torch.device("cpu")
return self.model_zoo.get('ocr').device
def run_trocr(
self,
@ -91,178 +87,73 @@ class Eynollah_ocr:
img: MatLike,
page_tree: ET.ElementTree,
page_ns,
tr_ocr_input_height_and_width,
) -> EynollahOcrResult:
total_bb_coordinates = []
cropped_lines = []
cropped_lines_region_indexer = []
cropped_lines_meging_indexing = []
extracted_texts = []
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
cropped_lines_region_indexer.append(n_region)
indexer_text_region = 0
indexer_b_s = 0
coords = line.find('{%s}Coords' % page_ns)
if coords is None:
self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id'])
continue
poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int)
cont = poly[:, np.newaxis]
xywh = xywh_from_polygon(poly)
x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h']
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
total_bb_coordinates.append([x, y, w, h])
for child_textlines in child_textregion:
if child_textlines.tag.endswith("Coords"):
cropped_lines_region_indexer.append(indexer_text_region)
p_h=child_textlines.attrib['points'].split(' ')
textline_coords = np.array( [ [int(x.split(',')[0]),
int(x.split(',')[1]) ]
for x in p_h] )
x,y,w,h = cv2.boundingRect(textline_coords)
img_crop = img[y: y + h, x: x + w]
if not self.do_not_mask_with_textline_contour:
mask_poly = np.zeros(img_crop.shape[:2], dtype=np.uint8)
mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1)
img_crop[mask_poly == 0] = 255 # FIXME: or median color?
total_bb_coordinates.append([x,y,w,h])
h2w_ratio = h/float(w)
img_poly_on_img = np.copy(img)
mask_poly = np.zeros(img.shape)
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
mask_poly = mask_poly[y:y+h, x:x+w, :]
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
img_crop[mask_poly==0] = 255
self.logger.debug("processing %d lines for '%s'",
len(cropped_lines), nn.attrib['id'])
if h2w_ratio > 0.1:
cropped_lines.append(resize_image(img_crop,
tr_ocr_input_height_and_width,
tr_ocr_input_height_and_width) )
if h > 0.1 * w:
cropped_lines.append(img_crop)
cropped_lines_meging_indexing.append(0)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
else:
splited_images, _ = return_textlines_split_if_needed(img_crop, None)
#print(splited_images)
if splited_images:
cropped_lines.append(resize_image(splited_images[0],
tr_ocr_input_height_and_width,
tr_ocr_input_height_and_width))
cropped_lines.append(splited_images[0])
cropped_lines.append(splited_images[1])
cropped_lines_meging_indexing.append(1)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
cropped_lines.append(resize_image(splited_images[1],
tr_ocr_input_height_and_width,
tr_ocr_input_height_and_width))
cropped_lines_meging_indexing.append(-1)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
else:
cropped_lines.append(img_crop)
cropped_lines_meging_indexing.append(0)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
indexer_text_region = indexer_text_region +1
if indexer_b_s!=0:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
####extracted_texts = []
####n_iterations = math.ceil(len(cropped_lines) / self.b_s)
####for i in range(n_iterations):
####if i==(n_iterations-1):
####n_start = i*self.b_s
####imgs = cropped_lines[n_start:]
####else:
####n_start = i*self.b_s
####n_end = (i+1)*self.b_s
####imgs = cropped_lines[n_start:n_end]
####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
####generated_ids_merged = self.model_ocr.generate(
#### pixel_values_merged.to(self.device))
####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
#### generated_ids_merged, skip_special_tokens=True)
####extracted_texts = extracted_texts + generated_text_merged
extracted_texts = []
extracted_confs = []
self.logger.debug("processing %d lines for %d regions",
len(cropped_lines), len(set(cropped_lines_region_indexer)))
for imgs in batched(cropped_lines, self.b_s):
text, conf = self.model_zoo.get('ocr').predict(imgs)
extracted_confs.extend(conf)
extracted_texts.extend(text)
del cropped_lines
gc.collect()
extracted_texts_merged = [extracted_texts[ind]
if cropped_lines_meging_indexing[ind]==0
else extracted_texts[ind]+" "+extracted_texts[ind+1]
if cropped_lines_meging_indexing[ind]==1
else None
for ind in range(len(cropped_lines_meging_indexing))]
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
#print(extracted_texts_merged, len(extracted_texts_merged))
if cropped_lines_meging_indexing[ind] == 0
else extracted_texts[ind] + " " + extracted_texts[ind + 1]
for ind in range(len(cropped_lines_meging_indexing))
if cropped_lines_meging_indexing[ind] >= 0]
extracted_confs_merged = [extracted_confs[ind]
if cropped_lines_meging_indexing[ind] == 0
else 0.5 * (extracted_confs[ind] + extracted_confs[ind + 1])
for ind in range(len(cropped_lines_meging_indexing))
if cropped_lines_meging_indexing[ind] >= 0]
return EynollahOcrResult(
extracted_texts_merged=extracted_texts_merged,
extracted_conf_value_merged=None,
extracted_confs_merged=extracted_confs_merged,
cropped_lines_region_indexer=cropped_lines_region_indexer,
total_bb_coordinates=total_bb_coordinates,
)
@ -274,367 +165,163 @@ class Eynollah_ocr:
img_bin: Optional[MatLike],
page_tree: ET.ElementTree,
page_ns,
image_width,
image_height,
) -> EynollahOcrResult:
_, image_height, image_width, _ = self.model_zoo.get('ocr').input_shape
total_bb_coordinates = []
cropped_lines = []
img_crop_bin = None
imgs_bin = None
imgs_bin_ver_flipped = None
cropped_lines_rgb = []
cropped_lines_bin = []
cropped_lines_ver_index = []
cropped_lines_region_indexer = []
cropped_lines_meging_indexing = []
indexer_text_region = 0
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
try:
type_textregion = nn.attrib['type']
except:
type_textregion = 'paragraph'
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
for child_textlines in child_textregion:
if child_textlines.tag.endswith("Coords"):
cropped_lines_region_indexer.append(indexer_text_region)
p_h=child_textlines.attrib['points'].split(' ')
textline_coords = np.array( [ [int(x.split(',')[0]),
int(x.split(',')[1]) ]
for x in p_h] )
img_rgb = img # cosmetic
if img_bin is None:
# run ad-hoc binarization
self.logger.info("running binarization for ensemble input")
img_bin = self.do_prediction(True, img, self.model_zoo.get("binarization"),
n_batch_inference=5)
img_bin = np.repeat(img_bin[:, :, np.newaxis], 3, axis=2)
img_bin = 255 * (img_bin == 0).astype(np.uint8)
x,y,w,h = cv2.boundingRect(textline_coords)
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
type_textregion = region.attrib.get('type', 'paragraph')
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
cropped_lines_region_indexer.append(n_region)
coords = line.find('{%s}Coords' % page_ns)
if coords is None:
self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id'])
continue
poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int)
cont = poly[:, np.newaxis]
xywh = xywh_from_polygon(poly)
x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h']
angle_radians = math.atan2(h, w)
# Convert to degrees
angle_degrees = math.degrees(angle_radians)
if type_textregion=='drop-capital':
angle_degrees = 0
total_bb_coordinates.append([x,y,w,h])
total_bb_coordinates.append([x, y, w, h])
w_scaled = w * image_height/float(h)
w_scaled = w * image_height / float(h)
img_poly_on_img = np.copy(img)
if img_bin:
img_poly_on_img_bin = np.copy(img_bin)
img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :]
img_crop_rgb = img_rgb[y: y + h, x: x + w]
img_crop_bin = img_bin[y: y + h, x: x + w]
mask_poly = np.zeros(img.shape)
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
mask_poly = mask_poly[y:y+h, x:x+w, :]
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
# print(file_name, angle_degrees, w*h,
# mask_poly[:,:,0].sum(),
# mask_poly[:,:,0].sum() /float(w*h) ,
# 'didi')
mask_poly = np.zeros(img_crop_rgb.shape[:2], dtype=np.uint8)
mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1)
if angle_degrees > 3:
better_des_slope = get_orientation_moments(textline_coords)
img_crop = rotate_image_with_padding(img_crop, better_des_slope)
if img_bin:
better_des_slope = get_orientation_moments(cont)
img_crop_rgb = rotate_image_with_padding(img_crop_rgb, better_des_slope)
img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope)
mask_poly = rotate_image_with_padding(mask_poly, better_des_slope)
mask_poly = mask_poly.astype('uint8')
#new bounding box
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0])
mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :]
img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :]
if not self.do_not_mask_with_textline_contour:
img_crop[mask_poly==0] = 255
if img_bin:
img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :]
if not self.do_not_mask_with_textline_contour:
img_crop_bin[mask_poly==0] = 255
if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90:
if img_bin:
img_crop, img_crop_bin = \
break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly, img_crop_bin)
else:
img_crop, _ = \
break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly)
# get new bounding box
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly)
img_crop_rgb = img_crop_rgb[y_n: y_n + h_n, x_n: x_n + w_n]
img_crop_bin = img_crop_bin[y_n: y_n + h_n, x_n: x_n + w_n]
mask_poly = mask_poly[y_n: y_n + h_n, x_n: x_n + w_n]
else:
better_des_slope = 0
if not self.do_not_mask_with_textline_contour:
img_crop[mask_poly==0] = 255
if img_bin:
if not self.do_not_mask_with_textline_contour:
img_crop_bin[mask_poly==0] = 255
if type_textregion=='drop-capital':
pass
else:
if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90:
if img_bin:
img_crop, img_crop_bin = \
img_crop_rgb[mask_poly == 0] = 255 # FIXME: or median color?
img_crop_bin[mask_poly == 0] = 255
if (type_textregion !='drop-capital' and
mask_poly.sum() < 0.50 * mask_poly.size and
w_scaled > 90):
img_crop_rgb, img_crop_bin = \
break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly, img_crop_bin)
else:
img_crop, _ = \
break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly)
img_crop_rgb, img_crop_bin, mask_poly)
if w_scaled < 750:#1.5*image_width:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop, image_height, image_width)
cropped_lines.append(img_fin)
img_crop_split_rgb = img_crop_split_bin = None
else:
img_crop_split_rgb, img_crop_split_bin = return_textlines_split_if_needed(
img_crop_rgb, img_crop_bin)
if img_crop_split_rgb:
cropped_lines_rgb.extend(img_crop_split_rgb)
cropped_lines_bin.extend(img_crop_split_bin)
if abs(better_des_slope) > 45:
cropped_lines_ver_index.append(1)
cropped_lines_ver_index.append(1)
else:
cropped_lines_ver_index.append(0)
cropped_lines_meging_indexing.append(0)
if img_bin:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop_bin, image_height, image_width)
cropped_lines_bin.append(img_fin)
else:
splited_images, splited_images_bin = return_textlines_split_if_needed(
img_crop, img_crop_bin if img_bin else None)
if splited_images:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images[0], image_height, image_width)
cropped_lines.append(img_fin)
cropped_lines_ver_index.append(0)
cropped_lines_meging_indexing.append(1)
if abs(better_des_slope) > 45:
cropped_lines_ver_index.append(1)
else:
cropped_lines_ver_index.append(0)
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images[1], image_height, image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(-1)
else:
cropped_lines_rgb.append(img_crop_rgb)
cropped_lines_bin.append(img_crop_bin)
if abs(better_des_slope) > 45:
cropped_lines_ver_index.append(1)
else:
cropped_lines_ver_index.append(0)
if img_bin:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images_bin[0], image_height, image_width)
cropped_lines_bin.append(img_fin)
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images_bin[1], image_height, image_width)
cropped_lines_bin.append(img_fin)
else:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop, image_height, image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(0)
if abs(better_des_slope) > 45:
cropped_lines_ver_index.append(1)
else:
cropped_lines_ver_index.append(0)
if img_bin:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop_bin, image_height, image_width)
cropped_lines_bin.append(img_fin)
indexer_text_region = indexer_text_region +1
cropped_lines_rgb = [preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width)
for img in cropped_lines_rgb]
cropped_lines_bin = [preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width)
for img in cropped_lines_bin]
extracted_texts = []
extracted_conf_value = []
extracted_confs = []
self.logger.debug("processing %d lines for %d regions",
len(cropped_lines_rgb), len(set(cropped_lines_region_indexer)))
cropped_lines = zip(cropped_lines_rgb, cropped_lines_bin, cropped_lines_ver_index)
for batch in batched(cropped_lines, self.b_s):
imgs_rgb, imgs_bin, ver_index = zip(*batch)
ver_index = np.array(ver_index)
imgs_rgb = np.stack(imgs_rgb)
imgs_bin = np.stack(imgs_bin)
imgs_rgb_ver = imgs_rgb[ver_index > 0, ::-1, ::-1]
imgs_bin_ver = imgs_bin[ver_index > 0, ::-1, ::-1]
n_iterations = math.ceil(len(cropped_lines) / self.b_s)
# inference model now yields (char-bytes, line-prob) instead of vocidx-softmax
# (so ctc_decode and inverse StringLookup are included)
# also, the model now expects a secondary binary input image
preds, probs = self.model_zoo.get('ocr').predict((imgs_rgb, imgs_bin), verbose=0)
# FIXME: copy pasta
for i in range(n_iterations):
if i==(n_iterations-1):
n_start = i*self.b_s
imgs = cropped_lines[n_start:]
imgs = np.array(imgs)
imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
if ver_index.any():
    # Re-run the strongly skewed ("vertical") lines rotated by 180 degrees and
    # keep whichever orientation the model is more confident about.
    preds_ver, probs_ver = self.model_zoo.get('ocr').predict((imgs_rgb_ver, imgs_bin_ver), verbose=0)
    # Positions (within this batch) of the lines that were flipped.
    ver_positions = np.flatnonzero(ver_index > 0)
    flipped_ver_is_better = np.flatnonzero(probs_ver > probs[ver_positions])
    if len(flipped_ver_is_better):
        self.logger.info("%d skewed lines perform better when flipped", len(flipped_ver_is_better))
        # Assign through integer positions in a single indexing step:
        # `preds[mask][idx] = ...` would write into the temporary copy that
        # boolean-mask indexing returns and leave `preds`/`probs` untouched.
        # NOTE(review): assumes preds/probs are writable numpy arrays, as the
        # mask indexing above already implies — confirm against the predictor.
        better_positions = ver_positions[flipped_ver_is_better]
        preds[better_positions] = preds_ver[flipped_ver_is_better]
        probs[better_positions] = probs_ver[flipped_ver_is_better]
ver_imgs = np.array( cropped_lines_ver_index[n_start:] )
indices_ver = np.where(ver_imgs == 1)[0]
#print(indices_ver, 'indices_ver')
if len(indices_ver)>0:
imgs_ver_flipped = imgs[indices_ver, : ,: ,:]
imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:]
#print(imgs_ver_flipped, 'imgs_ver_flipped')
else:
imgs_ver_flipped = None
if img_bin:
imgs_bin = cropped_lines_bin[n_start:]
imgs_bin = np.array(imgs_bin)
imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3)
if len(indices_ver)>0:
imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:]
imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:]
#print(imgs_ver_flipped, 'imgs_ver_flipped')
else:
imgs_bin_ver_flipped = None
else:
n_start = i*self.b_s
n_end = (i+1)*self.b_s
imgs = cropped_lines[n_start:n_end]
imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3)
ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] )
indices_ver = np.where(ver_imgs == 1)[0]
#print(indices_ver, 'indices_ver')
if len(indices_ver)>0:
imgs_ver_flipped = imgs[indices_ver, : ,: ,:]
imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:]
#print(imgs_ver_flipped, 'imgs_ver_flipped')
else:
imgs_ver_flipped = None
if img_bin:
imgs_bin = cropped_lines_bin[n_start:n_end]
imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3)
if len(indices_ver)>0:
imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:]
imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:]
#print(imgs_ver_flipped, 'imgs_ver_flipped')
else:
imgs_bin_ver_flipped = None
self.logger.debug("processing next %d lines", len(imgs))
preds = self.model_zoo.get('ocr').predict(imgs, verbose=0)
if len(indices_ver)>0:
preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0)
preds_max_fliped = np.max(preds_flipped, axis=2 )
preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
masked_means_flipped = \
np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \
np.sum(pred_max_not_unk_mask_bool_flipped, axis=1)
masked_means_flipped[np.isnan(masked_means_flipped)] = 0
preds_max = np.max(preds, axis=2 )
preds_max_args = np.argmax(preds, axis=2 )
pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
masked_means = \
np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
np.sum(pred_max_not_unk_mask_bool, axis=1)
masked_means[np.isnan(masked_means)] = 0
masked_means_ver = masked_means[indices_ver]
#print(masked_means_ver, 'pred_max_not_unk')
indices_where_flipped_conf_value_is_higher = \
np.where(masked_means_flipped > masked_means_ver)[0]
#print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher')
if len(indices_where_flipped_conf_value_is_higher)>0:
indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher]
preds[indices_to_be_replaced,:,:] = \
preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]
if img_bin:
preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0)
if len(indices_ver)>0:
preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0)
preds_max_fliped = np.max(preds_flipped, axis=2 )
preds_max_args_flipped = np.argmax(preds_flipped, axis=2 )
pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character
masked_means_flipped = \
np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \
np.sum(pred_max_not_unk_mask_bool_flipped, axis=1)
masked_means_flipped[np.isnan(masked_means_flipped)] = 0
preds_max = np.max(preds, axis=2 )
preds_max_args = np.argmax(preds, axis=2 )
pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
masked_means = \
np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
np.sum(pred_max_not_unk_mask_bool, axis=1)
masked_means[np.isnan(masked_means)] = 0
masked_means_ver = masked_means[indices_ver]
#print(masked_means_ver, 'pred_max_not_unk')
indices_where_flipped_conf_value_is_higher = \
np.where(masked_means_flipped > masked_means_ver)[0]
#print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher')
if len(indices_where_flipped_conf_value_is_higher)>0:
indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher]
preds_bin[indices_to_be_replaced,:,:] = \
preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]
preds = (preds + preds_bin) / 2.
pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char'))
preds_max = np.max(preds, axis=2 )
preds_max_args = np.argmax(preds, axis=2 )
pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character
masked_means = \
np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \
np.sum(pred_max_not_unk_mask_bool, axis=1)
for ib in range(imgs.shape[0]):
pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
if masked_means[ib] >= self.min_conf_value_of_textline_text:
extracted_texts.append(pred_texts_ib)
extracted_conf_value.append(masked_means[ib])
else:
extracted_texts.append("")
extracted_conf_value.append(0)
del cropped_lines
def nooov(x):
    """Return True unless *x* is the out-of-vocabulary marker ``b'[UNK]'``."""
    is_known = x != b'[UNK]'
    return is_known
for pred, prob in zip(preds, probs):
text = b''.join(
filter(nooov,
map(bytes,
(filter(None, char)
for char in pred.tolist())))).decode('utf-8')
extracted_texts.append(text)
extracted_confs.append(prob)
del cropped_lines_rgb
del cropped_lines_bin
gc.collect()
extracted_texts_merged = [extracted_texts[ind]
if cropped_lines_meging_indexing[ind]==0
else extracted_texts[ind]+" "+extracted_texts[ind+1]
if cropped_lines_meging_indexing[ind]==1
else None
for ind in range(len(cropped_lines_meging_indexing))]
extracted_conf_value_merged = [extracted_conf_value[ind] # type: ignore
if cropped_lines_meging_indexing[ind]==0
else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2.
if cropped_lines_meging_indexing[ind]==1
else None
for ind in range(len(cropped_lines_meging_indexing))]
extracted_conf_value_merged: List[float] = [extracted_conf_value_merged[ind_cfm]
for ind_cfm in range(len(extracted_texts_merged))
if extracted_texts_merged[ind_cfm] is not None]
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
if cropped_lines_meging_indexing[ind] == 0
else extracted_texts[ind] + " " + extracted_texts[ind + 1]
for ind in range(len(cropped_lines_meging_indexing))
if cropped_lines_meging_indexing[ind] >= 0]
extracted_confs_merged = [extracted_confs[ind]
if cropped_lines_meging_indexing[ind] == 0
else 0.5 * (extracted_confs[ind] + extracted_confs[ind + 1])
for ind in range(len(cropped_lines_meging_indexing))
if cropped_lines_meging_indexing[ind] >= 0]
return EynollahOcrResult(
extracted_texts_merged=extracted_texts_merged,
extracted_conf_value_merged=extracted_conf_value_merged,
extracted_confs_merged=extracted_confs_merged,
cropped_lines_region_indexer=cropped_lines_region_indexer,
total_bb_coordinates=total_bb_coordinates,
)
@ -652,9 +339,8 @@ class Eynollah_ocr:
cropped_lines_region_indexer = result.cropped_lines_region_indexer
total_bb_coordinates = result.total_bb_coordinates
extracted_texts_merged = result.extracted_texts_merged
extracted_conf_value_merged = result.extracted_conf_value_merged
extracted_confs_merged = result.extracted_confs_merged
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text)
@ -682,79 +368,53 @@ class Eynollah_ocr:
draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font)
image_text.save(out_image_with_text)
text_by_textregion = []
for ind in unique_cropped_lines_region_indexer:
ind = np.array(cropped_lines_region_indexer)==ind
extracted_texts_merged_un = np.array(extracted_texts_merged)[ind]
if len(extracted_texts_merged_un)>1:
text_by_textregion_ind = ""
cropped_lines_region_indexer = np.array(cropped_lines_region_indexer)
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
lines_indexer = np.flatnonzero(cropped_lines_region_indexer == n_region)
if not len(lines_indexer):
continue
text_region = ""
next_glue = ""
for indt in range(len(extracted_texts_merged_un)):
if (extracted_texts_merged_un[indt].endswith('') or
extracted_texts_merged_un[indt].endswith('-') or
extracted_texts_merged_un[indt].endswith('¬')):
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1]
for line_idx in lines_indexer:
if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text:
continue
text_line = extracted_texts_merged[line_idx]
if (text_line.endswith(('', '-', '¬')) and
# last line of a region can still be wrapped
# around columns or pages
line_idx < len(lines_indexer) - 1):
text_region += next_glue + text_line[:-1]
next_glue = ""
else:
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt]
text_region += next_glue + text_line
next_glue = " "
text_by_textregion.append(text_by_textregion_ind)
region_textequiv = region.find('{%s}TextEquiv' % page_ns)
if region_textequiv is None:
region_textequiv = ET.SubElement(region, 'TextEquiv')
region_teunicode = region_textequiv.find('{%s}Unicode' % page_ns)
if region_teunicode is None:
region_teunicode = ET.SubElement(region_textequiv, 'Unicode')
region_teunicode.text = text_region
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
line_textequiv = line.find('{%s}TextEquiv' % page_ns)
if line_textequiv is None:
line_textequiv = ET.SubElement(line, 'TextEquiv')
line_teunicode = line_textequiv.find('{%s}Unicode' % page_ns)
if line_teunicode is None:
line_teunicode = ET.SubElement(line_textequiv, 'Unicode')
line_idx = lines_indexer[n_line]
if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text:
line.remove(line_textequiv)
else:
text_by_textregion.append(" ".join(extracted_texts_merged_un))
indexer = 0
indexer_textregion = 0
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
is_textregion_text = False
for childtest in nn:
if childtest.tag.endswith("TextEquiv"):
is_textregion_text = True
if not is_textregion_text:
text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')
has_textline = False
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
is_textline_text = False
for childtest2 in child_textregion:
if childtest2.tag.endswith("TextEquiv"):
is_textline_text = True
if not is_textline_text:
text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
if extracted_conf_value_merged:
text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
unicode_textline = ET.SubElement(text_subelement, 'Unicode')
unicode_textline.text = extracted_texts_merged[indexer]
else:
for childtest3 in child_textregion:
if childtest3.tag.endswith("TextEquiv"):
for child_uc in childtest3:
if child_uc.tag.endswith("Unicode"):
if extracted_conf_value_merged:
childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
child_uc.text = extracted_texts_merged[indexer]
indexer = indexer + 1
has_textline = True
if has_textline:
if is_textregion_text:
for child4 in nn:
if child4.tag.endswith("TextEquiv"):
for childtr_uc in child4:
if childtr_uc.tag.endswith("Unicode"):
childtr_uc.text = text_by_textregion[indexer_textregion]
else:
unicode_textregion.text = text_by_textregion[indexer_textregion]
indexer_textregion = indexer_textregion + 1
line_textequiv.set('conf', str(round(extracted_confs_merged[line_idx], 2)))
line_teunicode.text = extracted_texts_merged[line_idx]
ET.register_namespace("",page_ns)
self.logger.info("output filename: '%s'", out_file_ocr)
page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None)
def run(
@ -814,18 +474,13 @@ class Eynollah_ocr:
img=img,
page_tree=page_tree,
page_ns=page_ns,
tr_ocr_input_height_and_width = 384
)
else:
result = self.run_cnn(
img=img,
page_tree=page_tree,
page_ns=page_ns,
img_bin=img_bin,
image_width=self.input_shape[1],
image_height=self.input_shape[0],
)
self.write_ocr(

View file

@ -17,9 +17,7 @@ import cv2
import numpy as np
import statistics
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
from .eynollah import Eynollah
from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image
from .utils.contour import (
@ -33,23 +31,27 @@ DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)
class machine_based_reading_order_on_layout:
class Reorder(Eynollah):
def __init__(
self,
*,
model_zoo: EynollahModelZoo,
logger : Optional[logging.Logger] = None,
device: str = '',
):
self.logger = logger or logging.getLogger('eynollah.mbreorder')
self.model_zoo = model_zoo
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except:
self.logger.warning("no GPU device available")
self.model_zoo.load_model('reading_order')
self.setup_models(device=device)
def setup_models(self, device=''):
loadable = ['reading_order']
self.model_zoo.load_models(*loadable, device=device)
for model in loadable:
self.logger.debug("model %s has input shape %s", model,
self.model_zoo.get(model).input_shape)
self.model_zoo.load_models('reading_order')
def read_xml(self, xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
@ -675,7 +677,7 @@ class machine_based_reading_order_on_layout:
tot_counter += 1
batch.append(j)
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
y_pr = self.model_zoo.get('reading_order').predict(input_1 , verbose='0')
y_pr = self.model_zoo.get('reading_order').predict(input_1, verbose=0)
for jb, j in enumerate(batch):
if y_pr[jb][0]>=0.5:
post_list.append(j)

View file

@ -208,22 +208,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
type='Keras',
),
EynollahModelSpec(
category="num_to_char",
variant='',
filename="characters_org.txt",
dist_url=dist_url("ocr"),
type='decoder',
),
EynollahModelSpec(
category="characters",
variant='',
filename="characters_org.txt",
dist_url=dist_url("ocr"),
type='List[str]',
),
EynollahModelSpec(
category="ocr",
variant='tr',
@ -233,20 +217,4 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
type='Keras',
),
EynollahModelSpec(
category="trocr_processor",
variant='',
filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
dist_url=dist_url("ocr"),
type='TrOCRProcessor',
),
EynollahModelSpec(
category="trocr_processor",
variant='htr',
filename="models_eynollah/microsoft/trocr-base-handwritten",
dist_url=dist_url("extra"),
type='TrOCRProcessor',
),
])

View file

@ -14,6 +14,19 @@ from .default_specs import DEFAULT_MODEL_SPECS
from .types import AnyModel, T
# Per-model-category GPU memory budget, keyed by model category name.
# Values are in MiB: consumers multiply by 1024 * 1024 to obtain bytes.
MODEL_VRAM_LIMITS = dict(
    binarization=868,  # due to bs 5
    enhancement=980,  # due to bs 3
    col_classifier=210,
    page=618,
    textline=1680,  # 954 for bs 1
    region_1_2=1580,
    region_fl_np=1756,
    table=1818,
    reading_order=632,
    ocr=850,
)
class EynollahModelZoo:
"""
Wrapper class that handles storage and loading of models for all eynollah runners.
@ -35,7 +48,7 @@ class EynollahModelZoo:
self._overrides = []
if model_overrides:
self.override_models(*model_overrides)
self._loaded: Dict[str, Predictor] = {}
self._loaded: Dict[str, Union[Predictor, AnyModel]] = {}
@property
def model_overrides(self):
@ -70,6 +83,13 @@ class EynollahModelZoo:
model_path = Path(self.model_basedir).joinpath(spec.filename)
else:
model_path = Path(spec.filename)
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
# prefer SavedModel over HDF5 format if it exists
model_path = Path(model_path.stem)
if model_path.with_suffix('.onnx').exists():
# prefer ONNX over SavedModel format if it exists
model_path = model_path.with_suffix('.onnx')
return model_path
def load_models(
@ -82,24 +102,31 @@ class EynollahModelZoo:
"""
ret = {} # cannot use self._loaded here, yet first spawn all predictors
for load_args in all_load_args:
load_kwargs = dict(device=device)
if isinstance(load_args, str):
model_category = load_args
load_args = [model_category]
model_category, model_variant = load_args, ""
elif len(load_args) > 2:
# for calls to self.model_path
self.override_models(load_args)
# for calls to Predictor.load_model
model_category, model_variant, model_path = load_args
load_kwargs["model_variant"] = model_variant
load_kwargs["model_path_override"] = model_path
else:
model_category = load_args[0]
load_kwargs = {}
if model_category.endswith('_resized'):
load_args[0] = model_category[:-8]
load_kwargs["resized"] = True
elif model_category.endswith('_patched'):
load_args[0] = model_category[:-8]
load_kwargs["patched"] = True
spec = self.specs.get(model_category, load_args[1] if len(load_args) > 1 else '')
if spec.type in ['Keras'] and spec.category != 'ocr':
ret[model_category] = Predictor(self.logger, self)
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
else:
ret[model_category] = self.load_model(*load_args, **load_kwargs, device=device)
model_category, model_variant = load_args
load_kwargs["model_variant"] = model_variant
# if model_category.endswith('_resized'):
# model_category = model_category[:-8]
# load_kwargs["resized"] = True
# elif model_category.endswith('_patched'):
# model_category = model_category[:-8]
# load_kwargs["patched"] = True
model = Predictor(self.logger, self)
model.load_model(model_category, **load_kwargs)
ret[model_category] = model
self._loaded.update(ret)
return self._loaded
@ -108,31 +135,48 @@ class EynollahModelZoo:
model_category: str,
model_variant: str = '',
model_path_override: Optional[str] = None,
patched: bool = False,
resized: bool = False,
# patched: bool = False,
# resized: bool = False,
device: str = '',
) -> AnyModel:
"""
Load any model
"""
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
if model_path_override:
self.override_models((model_category, model_variant, model_path_override))
model_path = self.model_path(model_category, model_variant)
if model_category == 'ocr' and model_variant == 'tr':
model = self._load_trocr_model(model_path, device=device)
elif model_path.is_dir() and (model_path / "keras_metadata.pb").exists():
# Keras model
model = self._load_keras_model(model_category, model_path, device=device)
elif model_path.is_dir():
# TF-Serving model
model = self._load_serving_model(model_category, model_path, device=device)
elif model_path.suffix == '.onnx':
# ONNX model
model = self._load_onnx_model(model_category, model_path, device=device)
else:
raise ValueError("unknown model type for '%s'" % str(model_path))
model._name = model_category
return model
def get(self, model_category: str) -> Union[Predictor, AnyModel]:
    """
    Return the already loaded model (or predictor) for ``model_category``.

    Raises:
        ValueError: if nothing has been loaded under that category yet.
    """
    if model_category in self._loaded:
        return self._loaded[model_category]
    raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
def _configure_tf_device(self, model_category, device=''):
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
import tensorflow as tf
from tensorflow.keras.models import load_model
from ..patch_encoder import (
PatchEncoder,
Patches,
wrap_layout_model_patched,
wrap_layout_model_resized,
)
cuda = False
try:
gpus = tf.config.list_physical_devices('GPU')
if device:
if ',' in device:
if ':' in device:
for spec in device.split(','):
cat, dev = spec.split(':')
if fnmatchcase(model_category, cat):
@ -147,7 +191,14 @@ class EynollahModelZoo:
gpus = gpus[:1] # TF will always use first allowable
tf.config.set_visible_devices(gpus, 'GPU')
for device in gpus:
tf.config.experimental.set_memory_growth(device, True)
# tf.config.experimental.set_memory_growth(device, True)
# dynamic growth never frees memory (to avoid fragmentation),
# so the VRAM requirements end up much larger than feasible
# (for small GPUs); so try hard (calibrated) limits instead:
tf.config.set_logical_device_configuration(
device,
[tf.config.LogicalDeviceConfiguration(
memory_limit=MODEL_VRAM_LIMITS[model_category])])
vendor_name = (
tf.config.experimental.get_device_details(device)
.get('device_name', 'unknown'))
@ -155,95 +206,195 @@ class EynollahModelZoo:
self.logger.info("using GPU %s (%s) for model %s",
device.name,
vendor_name,
model_category + (
"_patched" if patched else
"_resized" if resized else ""))
model_category # + (
# "_patched" if patched else
# "_resized" if resized else "")
)
except RuntimeError:
self.logger.exception("cannot configure GPU devices")
if not cuda:
self.logger.warning("no GPU device available")
if model_path_override:
self.override_models((model_category, model_variant, model_path_override))
model_path = self.model_path(model_category, model_variant)
if model_path.suffix == '.h5' and Path(model_path.stem).exists():
# prefer SavedModel over HDF5 format if it exists
model_path = Path(model_path.stem)
if model_category == 'ocr':
model = self._load_ocr_model(variant=model_variant)
elif model_category == 'num_to_char':
model = self._load_num_to_char()
elif model_category == 'characters':
model = self._load_characters()
elif model_category == 'trocr_processor':
from transformers import TrOCRProcessor
model = TrOCRProcessor.from_pretrained(model_path)
else:
def _configure_torch_device(self, model_category, device=''):
import torch
device0 = torch.device('cpu')
if not device and torch.cuda.is_available():
device = 'GPU' # try
if device and ':' in device:
for spec in device.split(','):
cat, dev = spec.split(':')
if fnmatchcase('ocr', cat):
device = dev
break
if device and device.startswith('GPU'):
try:
# avoid wasting VRAM on non-transformer models
device0 = torch.device('cuda', int(device[3:] or 0))
name = torch.cuda.get_device_name(device0)
self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name)
except:
self.logger.exception("cannot configure GPU device")
device0 = torch.device('cpu')
if device0.type != 'cuda':
self.logger.warning("no GPU device available")
return device0
def _load_keras_model(self, model_category, model_path, device=''):
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
from tensorflow.keras.models import load_model
from tensorflow.keras.models import Model as KerasModel
from ..training.models import cnn_rnn_ocr_model4inference
self._configure_tf_device(model_category, device=device)
model = load_model(model_path, compile=False)
except Exception as e:
self.logger.error(e)
model = load_model(
model_path, compile=False,
custom_objects=dict(PatchEncoder=PatchEncoder,
Patches=Patches))
model._name = model_category
if resized:
model = wrap_layout_model_resized(model)
model._name = model_category + '_resized'
elif patched:
model = wrap_layout_model_patched(model)
model._name = model_category + '_patched'
else:
model.jit_compile = True
assert isinstance(model, KerasModel)
# from ..patch_encoder import (
# wrap_layout_model_patched,
# wrap_layout_model_resized,
# )
# if resized:
# model = wrap_layout_model_resized(model)
# model._name = model_category + '_resized'
# elif patched:
# model = wrap_layout_model_patched(model)
# model._name = model_category + '_patched'
if model_category == 'ocr':
# cnn-rnn-ocr task model may not be in inference mode, yet
model = cnn_rnn_ocr_model4inference(model, model_path)
model.make_predict_function()
return model
def get(self, model_category: str) -> Predictor:
if model_category not in self._loaded:
raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"')
return self._loaded[model_category]
def _load_serving_model(self, model_category, model_path, device=''):
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
import tensorflow as tf
def _load_ocr_model(self, variant: str) -> AnyModel:
self._configure_tf_device(model_category, device=device)
model = tf.saved_model.load(model_path)
model.predict_on_batch = model.serve
model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape)
return model
def _load_onnx_model(self, model_category, model_path, device=''):
    """
    Load an ONNX model into an onnxruntime session and give it the
    ``predict_on_batch`` / ``input_shape`` attributes used by the predictors.

    Args:
        model_category: category name; matched against the patterns of a
            device mapping and used to look up the memory budget in
            ``MODEL_VRAM_LIMITS``.
        model_path: location of the ``.onnx`` file.
        device: empty (try the first GPU), ``'CPU'``, ``'GPU'``/``'GPU<n>'``,
            or a comma-separated mapping of category patterns to such names
            (e.g. ``'page:GPU0,*:CPU'``). The first matching entry wins; if no
            entry matches, the default (first GPU) is used.

    Returns:
        The ``onnxruntime.InferenceSession`` with the extra attributes set.

    Raises:
        ValueError: if the resolved device name is neither ``'CPU'`` nor
            starts with ``'GPU'``, or if the GPU index is not an integer.
    """
    import onnxruntime as ort
    import numpy as np

    providers = ort.get_available_providers()
    if device and ':' in device:
        for spec in device.split(','):
            # partition() tolerates entries without ':'
            cat, _, dev = spec.partition(':')
            if fnmatchcase(model_category, cat):
                device = dev
                break
        else:
            # category not covered by the mapping: use the default below
            # (instead of tripping over the raw mapping string)
            device = ''
    if not device:
        gpu = 0 # try first allowable
    elif device == 'CPU':
        gpu = -1
    elif device.startswith('GPU'):
        gpu = int(device[3:] or "0")
    else:
        raise ValueError("unknown device '%s' for model '%s'" % (device, model_category))
    # configure and prioritise: GPU providers are dropped entirely for CPU,
    # otherwise re-inserted at the front with explicit device and memory limit
    # (MODEL_VRAM_LIMITS is scaled by 1024 * 1024 to get bytes)
    if 'CUDAExecutionProvider' in providers:
        providers.remove('CUDAExecutionProvider')
        if gpu >= 0:
            providers = [('CUDAExecutionProvider', {
                'device_id': gpu,
                # 'arena_extend_strategy': 'kNextPowerOfTwo',
                'gpu_mem_limit': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
                # 'cudnn_conv_algo_search': 'EXHAUSTIVE',
                # 'do_copy_in_default_stream': True,
                # ...
            })] + providers
    if 'TensorrtExecutionProvider' in providers:
        providers.remove('TensorrtExecutionProvider')
        if gpu >= 0:
            # inserted last, so TensorRT takes precedence over CUDA
            providers = [('TensorrtExecutionProvider', {
                'device_id': gpu,
                'trt_max_workspace_size': MODEL_VRAM_LIMITS[model_category] * 1024 * 1024,
                # 'trt_fp16_enable': True,
                # 'trt_engine_cache_enable': True,
                # 'trt_timing_cache_enable': True,
                # ...
            })] + providers
    model = ort.InferenceSession(
        str(model_path),
        providers=providers)
    self.logger.info("using providers %s for model %s",
                     model.get_providers(), model_category)
    model_inputs = [model_input.name
                    for model_input in model.get_inputs()]
    model_outputs = [model_output.name
                     for model_output in model.get_outputs()]

    def predict_onnx(inputs):
        # single-input models get the bare array, multi-input models a sequence
        if len(model_inputs) == 1:
            inputs = [inputs]
        outputs = model.run(model_outputs, {
            model_input:
            input_data.astype(
                # models expect data_type() == 'tensor(float)', but np.float16 is 'tensor(float16)'
                # FIXME: do this dynamically (but how to convert .type to np.dtype?)
                np.float32 if input_data.dtype in [np.float16, np.float64] else
                input_data.dtype)
            for model_input, input_data in zip(model_inputs, inputs)
        })
        # mirror the input convention: unwrap single-output results
        if len(model_outputs) == 1:
            outputs = outputs[0]
        return outputs

    model.predict_on_batch = predict_onnx
    model.input_shape = model.get_inputs()[0].shape
    return model
def _load_trocr_model(self, model_path, device: str = "") -> AnyModel:
"""
Load OCR model
"""
from tensorflow.keras.models import Model as KerasModel
from tensorflow.keras.models import load_model
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
import numpy as np
ocr_model_dir = self.model_path('ocr', variant)
if variant == 'tr':
from transformers import VisionEncoderDecoderModel
ret = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
assert isinstance(ret, VisionEncoderDecoderModel)
return ret
device = self._configure_torch_device('ocr', device=device)
proc = TrOCRProcessor.from_pretrained(model_path)
model = VisionEncoderDecoderModel.from_pretrained(model_path)
assert isinstance(model, VisionEncoderDecoderModel)
model.to(device)
def predict_torch(inputs):
output = model.generate(
proc(inputs, return_tensors="pt").pixel_values.to(device),
# beam search instead of greedy decoding:
num_beams=4,
# also return probability
output_scores=True,
return_dict_in_generate=True)
if output.sequences_scores is not None:
# log-prob averaged over length
conf = output.sequences_scores.exp().clamp(0.0, 1.0).cpu().numpy()
else:
ocr_model = load_model(ocr_model_dir, compile=False)
assert isinstance(ocr_model, KerasModel)
return KerasModel(
ocr_model.get_layer(name="image").input, # type: ignore
ocr_model.get_layer(name="dense2").output, # type: ignore
)
def _load_characters(self) -> List[str]:
"""
Load encoding for OCR
"""
with open(self.model_path('num_to_char'), "r") as config_file:
return json.load(config_file)
def _load_num_to_char(self) -> 'StringLookup':
"""
Load decoder for OCR
"""
from tensorflow.keras.layers import StringLookup
characters = self._load_characters()
# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=characters, mask_token=None)
# Mapping integers back to original characters.
return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)
conf = np.ones(len(output.sequences), dtype=float)
text = proc.batch_decode(
output.sequences,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
# we must convert to ndarray for Predictor resultq to work
text = np.array(text)
return text, conf
model.predict_on_batch = predict_torch
# not actually needed (image processor does resize itself)
# no batch dimension (images passed as list w/ varying shapes)
model.input_shape = (None,
None,
len(proc.image_processor.image_mean))
return model
def __str__(self):
return tabulate(
@ -277,5 +428,5 @@ class EynollahModelZoo:
"""
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
for needle in list(self._loaded.keys()):
if isinstance(self._loaded[needle], Predictor):
self._loaded[needle].shutdown()
del self._loaded[needle]

View file

@ -1,10 +1,8 @@
# NOTE: For predictable order of imports of torch/shapely/tensorflow
# this must be the first import of the CLI!
from .eynollah_imports import imported_libs
from .processor import EynollahProcessor
from click import command
from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor
from .processor import EynollahProcessor
@command()
@ocrd_cli_options
def main(*args, **kwargs):

View file

@ -6,8 +6,8 @@ from tensorflow.keras import layers, models
class PatchEncoder(layers.Layer):
# 441=21*21 # 14*14 # 28*28
def __init__(self, num_patches=441, projection_dim=64):
super().__init__()
def __init__(self, num_patches=441, projection_dim=64, name='encode_patches'):
super().__init__(name=name)
self.num_patches = num_patches
self.projection_dim = projection_dim
self.projection = layers.Dense(self.projection_dim)
@ -24,8 +24,8 @@ class PatchEncoder(layers.Layer):
**super().get_config())
class Patches(layers.Layer):
def __init__(self, patch_size_x=1, patch_size_y=1):
super().__init__()
def __init__(self, patch_size_x=1, patch_size_y=1, name='extract_patches'):
super().__init__(name=name)
self.patch_size_x = patch_size_x
self.patch_size_y = patch_size_y

View file

@ -1,5 +1,5 @@
from contextlib import ExitStack
from typing import List, Dict
from typing import List, Dict, Tuple, Union
import logging
import logging.handlers
import multiprocessing as mp
@ -8,6 +8,7 @@ import numpy as np
from .utils.shm import share_ndarray, ndarray_shared
QSIZE = 200
ArrayT = Union[np.ndarray, Tuple[np.ndarray]]
class Predictor(mp.context.SpawnProcess):
@ -40,10 +41,10 @@ class Predictor(mp.context.SpawnProcess):
def input_shape(self):
return self({})
def predict(self, data: dict, verbose=0):
def predict(self, data: ArrayT, verbose=0) -> ArrayT:
return self(data)
def __call__(self, data: dict):
def __call__(self, data: Union[ArrayT, Dict]) -> Union[ArrayT, Tuple]:
# unusable as per python/cpython#79967
#with self.jobid.get_lock():
# would work, but not public:
@ -55,7 +56,15 @@ class Predictor(mp.context.SpawnProcess):
self.taskq.put((jobid, data))
#self.logger.debug("sent shape query task '%d' for model '%s'", jobid, self.name)
return self.result(jobid)
with share_ndarray(data) as shared_data:
with ExitStack() as stack:
if isinstance(data, tuple):
# multi-input
shared_data = []
for data0 in data:
shared_data.append(stack.enter_context(share_ndarray(data0)))
shared_data = tuple(shared_data)
else:
shared_data = stack.enter_context(share_ndarray(data))
self.taskq.put((jobid, shared_data))
#self.logger.debug("sent prediction task '%d' for model '%s': %s", jobid, self.name, shared_data)
return self.result(jobid)
@ -67,6 +76,14 @@ class Predictor(mp.context.SpawnProcess):
result = self.results.pop(jobid)
if isinstance(result, Exception):
raise Exception(f"predictor {self.name} failed for {jobid}") from result
elif isinstance(result, tuple) and isinstance(result[0], dict):
# multi-output
result1 = []
for result0 in result:
with ndarray_shared(result0) as shared_result0:
result1.append(np.copy(shared_result0))
result = result1
self.closable.append(jobid)
elif isinstance(result, dict):
with ndarray_shared(result) as shared_result:
result = np.copy(shared_result)
@ -111,6 +128,8 @@ class Predictor(mp.context.SpawnProcess):
"binarization": 4,
"enhancement": 4,
"reading_order": 4,
"ocr": 8,
"ocr_tr": 2,
# medium size (672x672x3)...
"textline": 2,
# large models...
@ -126,8 +145,20 @@ class Predictor(mp.context.SpawnProcess):
self.resultq.put((jobid, result))
#self.logger.debug("sent result for '%d': %s", jobid, result)
else:
tasks = [(jobid, shared_data)]
if self.name == 'ocr_tr':
# this model takes a list of (image) tensors
# of heterogeneous shape as input,
# resizing them internally;
# so this looks like multi-input
multi_input = True
batch_size = len(shared_data)
elif isinstance(shared_data, tuple):
multi_input = True
batch_size = shared_data[0]['shape'][0]
else:
multi_input = False
batch_size = shared_data['shape'][0]
tasks = [(jobid, shared_data)]
while (not self.taskq.empty() and
# climb to target batch size
batch_size * len(tasks) < REBATCH_SIZE):
@ -136,7 +167,7 @@ class Predictor(mp.context.SpawnProcess):
# add to our batch
tasks.append((jobid0, shared_data0))
else:
# immediately anser
# immediately answer
self.resultq.put((jobid0, self.model.input_shape))
if len(tasks) > 1:
self.logger.debug("rebatching %d '%s' tasks of batch size %d",
@ -147,11 +178,25 @@ class Predictor(mp.context.SpawnProcess):
for jobid, shared_data in tasks:
#self.logger.debug("predicting '%d' with model '%s': %s", jobid, self.name, shared_data)
jobs.append(jobid)
if multi_input:
data.append(tuple(stack.enter_context(ndarray_shared(shared_data0))
for shared_data0 in shared_data))
else:
data.append(stack.enter_context(ndarray_shared(shared_data)))
if multi_input:
data = list(np.concatenate(data0)
for data0 in zip(*data))
else:
data = np.concatenate(data)
#result = self.model.predict(data, verbose=0)
# faster, less VRAM
result = self.model.predict_on_batch(data)
if isinstance(result, tuple):
multi_output = True
results = zip(*(np.split(result0, len(jobs))
for result0 in result))
else:
multi_output = False
results = np.split(result, len(jobs))
#self.logger.debug("sharing result array for '%d'", jobid)
with ExitStack() as stack:
@ -160,6 +205,10 @@ class Predictor(mp.context.SpawnProcess):
# but don't want to wait either, so track closing
# context per job, and wait for closable signal
# from client
if multi_output:
result = tuple(stack.enter_context(share_ndarray(result0))
for result0 in result)
else:
result = stack.enter_context(share_ndarray(result))
closing[jobid] = stack.pop_all()
self.resultq.put((jobid, result))
@ -174,8 +223,11 @@ class Predictor(mp.context.SpawnProcess):
def load_model(self, *load_args, **load_kwargs):
assert len(load_args)
self.name = '_'.join(list(load_args[:1]) +
list(load_kwargs[key] for key in load_kwargs
if key == 'model_variant') +
list(key for key in load_kwargs
if key != 'device'))
if key in ['patched', 'resized']
and load_kwargs[key]))
self.load_args = load_args
self.load_kwargs = load_kwargs
self.start() # call run() in subprocess
@ -194,17 +246,18 @@ class Predictor(mp.context.SpawnProcess):
def shutdown(self):
# do not terminate from forked processor instances
if mp.parent_process() is None:
if not hasattr(self, 'model'):
self.stopped.set()
self.join()
self.taskq.close()
self.taskq.cancel_join_thread()
self.resultq.close()
self.resultq.cancel_join_thread()
self.logq.close()
self.terminate()
#self.terminate()
else:
del self.model
def __del__(self):
#self.logger.debug(f"deinit of {self} in {mp.current_process().name}")
#self.logger.debug(f"deinit of {self.name} in {mp.current_process().name}")
self.shutdown()

View file

@ -7,7 +7,8 @@ import sys
from .build_model_load_pretrained_weights_and_save import build_model_load_pretrained_weights_and_save
from .generate_gt_for_training import main as generate_gt_cli
from .inference import main as inference_cli
from .train import ex
from .train import train_cli
from .convert import convert_cli
from .extract_line_gt import linegt_cli
<<<<<<< HEAD
from .weights_ensembling import ensemble_cli
@ -16,13 +17,6 @@ from .weights_ensembling import main as ensemble_cli
from .generate_or_update_cnn_rnn_ocr_character_list import main as update_ocr_characters_cli
>>>>>>> integrating_trocr_and_torch_ensembling_and_updating_characters_list
@click.command(context_settings=dict(
    ignore_unknown_options=True,
))
@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED)
def train_cli(sacred_args):
    # All arguments are collected unprocessed and handed to the Sacred
    # experiment; the program name is put back in front so that Sacred
    # receives a complete argv.
    argv = [sys.argv[0], *sacred_args]
    ex.run_commandline(argv)
# Click command group named 'training'; it has no behaviour of its own and
# only serves as the parent for the subcommands registered via
# main.add_command(). (No docstring on purpose: click would show it as help.)
@click.group('training')
def main():
    pass
@ -31,6 +25,7 @@ main.add_command(build_model_load_pretrained_weights_and_save)
main.add_command(generate_gt_cli, 'generate-gt')
main.add_command(inference_cli, 'inference')
main.add_command(train_cli, 'train')
main.add_command(convert_cli, 'convert')
main.add_command(linegt_cli, 'export_textline_images_and_text')
main.add_command(ensemble_cli, 'ensembling')
main.add_command(update_ocr_characters_cli, 'generate_or_update_cnn_rnn_ocr_character_list')

View file

@ -0,0 +1,103 @@
import os
from pathlib import Path
from shutil import copy2
import logging
import json
import click
@click.command(context_settings=dict(
    help_option_names=['-h', '--help'],
    show_default=True))
@click.option(
    "--rebuild",
    "-r",
    help="build new model from code and then load existing weights (requires input in SavedModel directory format with config.json present)",
    is_flag=True
)
@click.option(
    "--format",
    "-f",
    "format_",
    help="data format to convert to",
    type=click.Choice(["hdf5", "keras", "tf", "tf-serving", "onnx"]),
    default="tf"
)
@click.option(
    "--in",
    "-i",
    "in_",
    help="path to input model (file in hdf5 / keras format, or directory in tf format)",
    required=True,
    type=click.Path(exists=True, dir_okay=True)
)
@click.option(
    "--out",
    "-o",
    help="path to output model (file in hdf5 / keras / onnx format, or directory in tf / tf-serving format)",
    required=True,
    type=click.Path(exists=False, dir_okay=True)
)
def convert_cli(rebuild, format_, in_, out):
    """
    convert models for inference
    Load model from path, optionally by rebuilding, convert to output format and write model to path.
    """
    # The environment variable is set before TensorFlow gets imported below.
    os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
    from ocrd_utils import tf_disable_interactive_logs
    tf_disable_interactive_logs()
    # NOTE(review): `tf` itself is not referenced anywhere in this function.
    import tensorflow as tf
    from tensorflow.keras.models import load_model
    from tensorflow.keras.models import Model as KerasModel

    model_path = Path(in_)
    # Only meaningful for directory inputs; for a plain file this path
    # simply does not exist and the checks below evaluate to False.
    config_path = model_path / "config.json"
    if model_path.is_dir():
        assert (model_path / "keras_metadata.pb").exists(), (
            "input directory must be Keras model in SavedModel format")
    if rebuild:
        # Rebuild the architecture from code, then load the stored weights.
        from .train import ex
        from .models import get_model
        assert config_path.exists(), (
            "rebuilding requires input model in SavedModel format with config.json")
        # merge defaults with existing config file
        ex.add_config(str(config_path))
        # some models deviate between training and inference
        ex.add_config(inference=True)
        # just retrieve final config (via pseudo-run)
        ex.main(lambda: 0)
        config = ex.run(options={'--loglevel': 'ERROR'}).config
        # use the config to capture the model builder
        model = get_model(config, logging.root)
        model.load_weights(model_path).assert_existing_objects_matched().expect_partial()
    else:
        # Load the serialized model as is (without compiling it).
        from .models import cnn_rnn_ocr_model4inference
        model = load_model(model_path, compile=False)
        if isinstance(model, KerasModel):
            # cnn-rnn-ocr task deviates between training and inference
            model = cnn_rnn_ocr_model4inference(model, model_path)

    if format_ in ["hdf5", "keras", "tf"]:
        # Keras calls the HDF5 format "h5"; the other two names pass through.
        kwargs = {"save_format": {"hdf5": "h5"}.get(format_, format_)}
        if format_ != "keras":
            kwargs["include_optimizer"] = False
        model.save(out, **kwargs)
    elif format_ == "tf-serving":
        model.export(out)
    elif format_ == "onnx":
        # imported only when this output format is requested
        import tf2onnx
        tf2onnx.convert.from_keras(model, opset=18, output_path=out)
    else:
        raise ValueError("unknown output format '%s'" % format_)
    # copy config.json if possible
    # (only the directory-based output formats have a place to put it)
    if config_path.exists() and format_ in ['tf', 'tf-serving']:
        copy2(config_path, Path(out) / config_path.name)

View file

@ -1,4 +1,5 @@
import os
import json
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
@ -23,6 +24,7 @@ from tensorflow.keras.layers import (
Reshape,
UpSampling2D,
ZeroPadding2D,
StringLookup,
add,
concatenate
)
@ -58,6 +60,50 @@ class CTCLayer(Layer):
# At test time, just return the computed predictions.
return y_pred
class CTCDecoder(Layer):
    """
    Layer that decodes per-step class probabilities with CTC beam search.

    `call` returns a pair `(outputs, probs)`: the best decoded label
    sequence per sample as a dense integer matrix (padded with -1), and a
    per-sample score computed from the two best beams.
    """
    def call(self, inputs):
        # assumes inputs is (batch, steps, classes) with probabilities in the
        # last axis - TODO confirm against the producing layer
        n_samples = tf.shape(inputs)[0]
        n_steps = inputs.shape[1]
        # NOTE(review): n_classes is not used below
        n_classes = inputs.shape[2]
        # every sample is decoded over the full (static) number of steps
        lengths = tf.ones(n_samples, dtype=tf.int32) * n_steps
        ## Keras beam search seems to mess with double letters
        ## but Keras greedy sometimes removes arbitrary letters
        # outputs, logits = tf.keras.backend.ctc_decode(inputs,
        #                                               lengths,
        #                                               beam_width=20
        #                                               greedy=False, # True,
        #                                               # backend does not allow these kwargs
        #                                               #merge_repeated=False,
        #                                               #mask_index=inputs.shape[2]-1,
        # )
        # tf.nn.ctc_*_decoder (in contrast to tf.keras.backend.ctc_decode)
        # needs logits instead of probs and time-major (batch 2nd dim)
        # (epsilon keeps the logarithm finite for zero probabilities)
        inputs = tf.math.log(
            tf.transpose(inputs, perm=[1, 0, 2]) + tf.keras.backend.epsilon()
        )
        # tf.nn.ctc_greedy_decoder() is not as precise
        # tf.compat.v1.nn.ctc_beam_search_decoder() also needs merge_repeated=False
        decoded, logits = tf.nn.ctc_beam_search_decoder(
            inputs,
            lengths,
            beam_width=10,
            top_paths=2,
        )
        # get top path for all sequences in batch
        decoded = decoded[0]
        # difference between the scores of the best and the second-best path
        logits = logits[:, 0] - logits[:, 1]
        # convert to dense
        outputs = tf.SparseTensor(decoded.indices, decoded.values,
                                  (n_samples, n_steps))
        outputs = tf.sparse.to_dense(sp_input=outputs, default_value=-1)
        # # drop non-tokens (-1) and OOV (0)
        # result = []
        # for output in outputs:
        #     result.append(tf.gather(output, tf.where(output > 0)))
        # outputs = tf.stack(result)
        # NOTE(review): this is exp(second - best), i.e. it shrinks as the
        # margin of the best path grows - confirm consumers expect that
        # direction
        probs = tf.exp(-logits)
        return outputs, probs
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = Dense(units, activation=tf.nn.gelu)(x)
@ -311,11 +357,10 @@ def transformer_block(img,
#assert isinstance(x, Layer)
encoded_patches = tf.reshape(encoded_patches,
[-1,
img.shape[1],
encoded_patches = Reshape((img.shape[1],
img.shape[2],
projection_dim // (patchsize_x * patchsize_y)])
projection_dim // (patchsize_x * patchsize_y)),
name="reshape_patches")(encoded_patches)
return encoded_patches
def vit_resnet50_unet(num_patches,
@ -425,11 +470,11 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
return model
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None):
input_img = Input(shape=(image_height, image_width, 3), name="image")
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False, characters_txt_file=None):
inputs = Input(shape=(image_height, image_width, 3), name="image")
labels = Input(name="label", shape=(None,))
x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img)
x = Conv2D(64,kernel_size=(3,3),padding="same")(inputs)
x = BatchNormalization(name="bn1")(x)
x = Activation("relu", name="relu1")(x)
x = Conv2D(64,kernel_size=(3,3),padding="same")(x)
@ -462,43 +507,135 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d)
new_shape = (x.shape[1]*x.shape[2], x.shape[3])
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
x = Reshape(target_shape=new_shape, name="reshape")(x)
x2d = Reshape(target_shape=new_shape2, name="reshape2")(x2d)
x4d = Reshape(target_shape=new_shape4, name="reshape4")(x4d)
x = Reshape(new_shape, name="reshape")(x)
x2d = Reshape(new_shape2, name="reshape2")(x2d)
x4d = Reshape(new_shape4, name="reshape4")(x4d)
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
xrnn2d = Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
xrnn4d = Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
xrnn2dup = Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
xrnn4dup = Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
xrnn2dup = Reshape((xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
xrnn4dup = Reshape((xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
out = BatchNormalization(name="bn9")(out)
out = Activation("relu", name="relu9")(out)
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
out = Dense(n_classes, activation="softmax", name="dense2")(out)
if inference:
# add second path for binarization
inputs_bin = Input(shape=(image_height, image_width, 3), name="image_bin")
out_bin = Model(inputs, out)(inputs_bin)
# ensemble raw results
out = 0.5 * (out + out_bin)
# get tf.string batch
out, prob = CTCDecoder()(out)
# decode int to str
with open(characters_txt_file, "r") as voc_file:
voc = json.load(voc_file)
char2num = StringLookup(vocabulary=voc)
voc = char2num.get_vocabulary()
num2char = StringLookup(vocabulary=voc, invert=True)
output = num2char(out)
# avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed)
output = tf.io.decode_raw(output, tf.uint8, fixed_length=max(map(len, voc)))
return Model((inputs, inputs_bin), (output, prob))
# Add CTC layer for calculating CTC loss at each step.
output = CTCLayer(name="ctc_loss")(labels, out)
out = CTCLayer(name="ctc_loss")(labels, out)
model = Model(inputs=(input_img, labels), outputs=output, name="handwriting_recognizer")
return Model((inputs, labels), out)
def cnn_rnn_ocr_model4inference(model, model_path):
    """convert trained cnn-rnn-ocr model to inference model post-hoc

    If `model` has no layer named 'ctc_loss' it is returned unchanged.
    Otherwise a new model is built which

    - takes two inputs (the original 'image' input and a second input
      'image_bin' of the same shape),
    - runs both through the network up to layer 'dense2' and averages
      the two results,
    - decodes the average with `CTCDecoder`,
    - maps the decoded indices back to strings using the vocabulary read
      from `characters_org.txt` inside `model_path`, and
    - returns those strings as fixed-length uint8 arrays together with
      the decoder's score.

    `model_path` is the directory containing `characters_org.txt`
    (`str` or path-like).
    """
    try:
        model.get_layer(name='ctc_loss')
    except ValueError:
        # likely already converted
        return model

    inputs = model.get_layer(name='image').input
    output = model.get_layer(name='dense2').output
    # second path, sharing the weights of the first one
    inputs_bin = Input(inputs.shape[1:], name='image_bin')
    output_bin = Model(inputs, output)(inputs_bin)
    # ensemble raw results
    output = 0.5 * (output + output_bin)
    output, prob = CTCDecoder()(output)
    # The vocabulary is JSON, hence read it as UTF-8 explicitly instead of
    # relying on the locale's default encoding.
    voc_path = os.path.join(model_path, "characters_org.txt")
    with open(voc_path, "r", encoding="utf-8") as voc_file:
        voc = json.load(voc_file)
    # round-trip through a forward lookup to obtain the full vocabulary
    # (as reported by the layer) for the inverse lookup
    char2num = StringLookup(vocabulary=voc)
    voc = char2num.get_vocabulary()
    num2char = StringLookup(vocabulary=voc, invert=True)
    output = num2char(output)
    # avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed)
    # NOTE(review): len() counts characters while decode_raw counts bytes;
    # confirm no vocabulary entry is longer in UTF-8 bytes than this maximum
    output = tf.io.decode_raw(output, tf.uint8, fixed_length=max(map(len, voc)))
    inputs = (inputs, inputs_bin)
    outputs = (output, prob)
    return Model(inputs, outputs)
def get_model(config, logger):
    """
    Build a new model for the task named in `config['task']`.

    The builder function is selected from the task (and, for
    segmentation / enhancement / binarization, from the backbone
    settings), wrapped as a Sacred captured function and called without
    arguments, so that it draws its parameters from `config`.

    For transformer backbones the patch geometry is validated and the
    derived `num_patches` is passed on to the builder. The caller's
    `config` is never modified: the extra key goes into a private copy,
    which also works if the passed config object is read-only.

    Raises ValueError for an unknown task and AssertionError for an
    inconsistent transformer configuration.
    """
    task = config['task']
    if task in ["segmentation", "enhancement", "binarization"]:
        if config['backbone_type'] == 'nontransformer':
            builder = resnet50_unet
        else:
            num_patches_x, num_patches_y = config['transformer_num_patches_xy']
            num_patches = num_patches_x * num_patches_y
            if config['transformer_cnn_first']:
                builder = vit_resnet50_unet
                multiple = 32
            else:
                builder = vit_resnet50_unet_transformer_before_cnn
                multiple = 1

            assert config['input_height'] == (
                num_patches_y * config['transformer_patchsize_y'] * multiple), (
                    "transformer_patchsize_y or transformer_num_patches_xy height value error: "
                    "input_height should be equal to "
                    "(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
            assert config['input_width'] == (
                num_patches_x * config['transformer_patchsize_x'] * multiple), (
                    "transformer_patchsize_x or transformer_num_patches_xy width value error: "
                    "input_width should be equal to "
                    "(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
            assert 0 == (config['transformer_projection_dim'] %
                         (config['transformer_patchsize_y'] *
                          config['transformer_patchsize_x'])), (
                              "transformer_projection_dim error: "
                              "The remainder when parameter transformer_projection_dim is divided by "
                              "(transformer_patchsize_y*transformer_patchsize_x) should be zero")
            # work on a shallow copy instead of writing into the caller's
            # (possibly read-only) configuration
            config = dict(config, num_patches=num_patches)
    elif task == "cnn-rnn-ocr":
        builder = cnn_rnn_ocr_model
    elif task == 'classification':
        builder = resnet50_classifier
    elif task == 'reading_order':
        builder = machine_based_reading_order_model
    else:
        raise ValueError("unknown model task '%s'" % task)

    # imported at the point of use, so that an invalid task is reported
    # without requiring Sacred
    from sacred.config import create_captured_function
    builder = create_captured_function(builder)
    builder.config = config
    builder.logger = logger
    return builder()

View file

@ -4,38 +4,65 @@ MODELS_SRC = models_eynollah
MODELS_DST = reloaded/models_eynollah
# $(MODELS_DST)/eynollah-binarization_20210425 \
# $(MODELS_DST)/eynollah-column-classifier_20210425 \
# $(MODELS_DST)/eynollah-enhancement_20210425 \
# $(MODELS_DST)/eynollah-main-regions-aug-rotation_20210425 \
# $(MODELS_DST)/eynollah-main-regions-aug-scaling_20210425 \
# $(MODELS_DST)/eynollah-main-regions-ensembled_20210425 \
# $(MODELS_DST)/eynollah-main-regions_20220314 \
# $(MODELS_DST)/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18 \
# $(MODELS_DST)/eynollah-tables_20210319 \
# $(MODELS_DST)/model_eynollah_ocr_cnnrnn_20250930 \
# eynollah-main-regions-aug-rotation_20210425
# eynollah-main-regions-aug-scaling_20210425
# eynollah-main-regions-ensembled_20210425
# eynollah-main-regions_20220314
# eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18
# eynollah-tables_20210319
RELOADABLE_MODELS = \
$(MODELS_DST)/model_eynollah_page_extraction_20250915 \
$(MODELS_DST)/model_eynollah_reading_order_20250824 \
$(MODELS_DST)/modelens_e_l_all_sp_0_1_2_3_4_171024 \
$(MODELS_DST)/modelens_full_lay_1__4_3_091124 \
$(MODELS_DST)/modelens_table_0t4_201124 \
$(MODELS_DST)/modelens_textline_0_1__2_4_16092024
CURRENT_MODELS :=
CURRENT_MODELS += model_eynollah_page_extraction_20250915
CURRENT_MODELS += model_eynollah_reading_order_20250824
CURRENT_MODELS += modelens_e_l_all_sp_0_1_2_3_4_171024
CURRENT_MODELS += modelens_full_lay_1__4_3_091124
CURRENT_MODELS += modelens_table_0t4_201124
CURRENT_MODELS += modelens_textline_0_1__2_4_16092024
CURRENT_MODELS += model_eynollah_ocr_cnnrnn_20250930
CURRENT_MODELS += eynollah-binarization_20210425
CURRENT_MODELS += eynollah-column-classifier_20210425
CURRENT_MODELS += eynollah-enhancement_20210425
all: $(RELOADABLE_MODELS)
all: tf-serving
tf-serving: $(CURRENT_MODELS:%=$(MODELS_DST)/%)
keras: $(CURRENT_MODELS:%=$(MODELS_DST)/%.keras)
hdf5: $(CURRENT_MODELS:%=$(MODELS_DST)/%.h5)
onnx: $(CURRENT_MODELS:%=$(MODELS_DST)/%.onnx)
$(MODELS_DST)/%: $(MODELS_SRC)/%
mkdir -p $@
test -e $</config.json || exit 1
eynollah-training train --force \
with $</config.json \
reload_weights=True \
continue_training=False \
dir_output=$(dir $@) \
dir_of_start_model=$< \
2>&1 | tee $(notdir $<).log
cp $</config.json $@/config.json
eynollah-training convert \
$(and $(wildcard $</config.json),--rebuild) \
--in $< \
--format tf-serving \
--out $@ \
2>&1 | tee $(notdir $<).tf-serving.log
$(MODELS_DST)/%.keras: $(MODELS_SRC)/%
eynollah-training convert \
$(and $(wildcard $</config.json),--rebuild) \
--in $< \
--format keras \
--out $@ \
2>&1 | tee $(notdir $<).keras.log
$(MODELS_DST)/%.h5: $(MODELS_SRC)/%
eynollah-training convert \
$(and $(wildcard $</config.json),--rebuild) \
--in $< \
--format hdf5 \
--out $@ \
2>&1 | tee $(notdir $<).hdf5.log
# Convert a model to ONNX, except for segmentation models with transformer
# backbone (detected via config.json), which are skipped with a message.
# The jq output is silenced with the portable `>/dev/null 2>&1`: the bash-only
# `&>/dev/null` would, under a POSIX /bin/sh, run jq in the background and make
# the condition always true (i.e. skip every model).
$(MODELS_DST)/%.onnx: $(MODELS_SRC)/%
	if jq -e '.task == "segmentation" and .backbone_type == "transformer"' $</config.json >/dev/null 2>&1; then \
	echo skipping $@: vision transformer architecture currently does not work with ONNX; else \
	eynollah-training convert \
	$(and $(wildcard $</config.json),--rebuild) \
	--in $< \
	--format onnx \
	--out $@ \
	2>&1 | tee $(notdir $<).onnx.log; fi
compare:
for i in `find $(MODELS_DST) -mindepth 2`;do \
@ -43,6 +70,5 @@ compare:
du -bs $$n $$i ; \
done
clear:
rm -rf $(MODELS_DST)

View file

@ -2,6 +2,7 @@ import os
import sys
import io
import json
import click
from tqdm import tqdm
@ -22,7 +23,6 @@ from tensorflow.keras.layers import StringLookup
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.backend import one_hot
from sacred import Experiment
from sacred.config import create_captured_function
import numpy as np
import cv2
@ -37,16 +37,9 @@ from .metrics import (
connected_components_loss,
)
from .models import (
PatchEncoder,
Patches,
machine_based_reading_order_model,
resnet50_classifier,
resnet50_unet,
vit_resnet50_unet,
vit_resnet50_unet_transformer_before_cnn,
cnn_rnn_ocr_model,
RESNET50_WEIGHTS_PATH,
RESNET50_WEIGHTS_URL
RESNET50_WEIGHTS_URL,
get_model
)
from .utils import (
generate_arrays_from_folder_reading_order,
@ -372,10 +365,9 @@ def config_params():
dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder.
save_interval = None # frequency for writing model checkpoints (positive integer for number of batches saved under "model_step_{batch:04d}", otherwise epoch saved under "model_{epoch:02d}")
reload_weights = False # Set true to build new model from config, load weights from dir_of_start_model, save under dir_output and exit.
continue_training = False # Whether to continue training an existing model.
dir_of_start_model = '' # Directory of model checkpoint to load to continue training or load weights from. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
if continue_training:
dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
@ -396,7 +388,6 @@ def run(_config,
weight_decay,
learning_rate,
continue_training,
reload_weights,
save_interval,
augmentation,
# dependent config keys need a default,
@ -494,58 +485,15 @@ def run(_config,
if task == "enhancement":
assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply"
assert not weighted_loss, "for enhancement, weighted loss does not apply"
if continue_training:
custom_objects = dict()
if is_loss_soft_dice:
custom_objects.update(soft_dice_loss=soft_dice_loss)
elif weighted_loss:
custom_objects.update(loss=weighted_categorical_crossentropy(weights))
if backbone_type == 'transformer':
custom_objects.update(PatchEncoder=PatchEncoder,
Patches=Patches)
model = load_model(dir_of_start_model, compile=False,
custom_objects=custom_objects)
model = load_model(dir_of_start_model, compile=False)
else:
index_start = 0
if backbone_type == 'nontransformer':
model = resnet50_unet(n_classes,
input_height,
input_width,
task,
weight_decay,
pretraining)
else:
num_patches_x = transformer_num_patches_xy[0]
num_patches_y = transformer_num_patches_xy[1]
num_patches = num_patches_x * num_patches_y
if transformer_cnn_first:
model_builder = vit_resnet50_unet
multiple = 32
else:
model_builder = vit_resnet50_unet_transformer_before_cnn
multiple = 1
assert input_height == (
num_patches_y * transformer_patchsize_y * multiple), (
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
"input_height should be equal to "
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
assert input_width == (
num_patches_x * transformer_patchsize_x * multiple), (
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
"input_width should be equal to "
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
assert 0 == (transformer_projection_dim %
(transformer_patchsize_y * transformer_patchsize_x)), (
"transformer_projection_dim error: "
"The remainder when parameter transformer_projection_dim is divided by "
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
model_builder = create_captured_function(model_builder)
model_builder.config = _config
model_builder.logger = _log
model = model_builder(num_patches)
model = get_model(_config, _log)
if dir_of_start_model:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
_log.info("reloaded weights from %s", dir_of_start_model)
assert model is not None
#if you want to see the model structure just uncomment model summary.
@ -576,15 +524,6 @@ def run(_config,
optimizer=Adam(learning_rate=learning_rate),
metrics=metrics)
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
model.save(dir_save, include_optimizer=False)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
return
if not data_is_provided:
# first create a directory in output for both training and evaluations
# in order to flow data from these directories.
@ -725,10 +664,11 @@ def run(_config,
model = load_model(dir_of_start_model)
else:
index_start = 0
model = cnn_rnn_ocr_model(image_height=input_height,
image_width=input_width,
n_classes=n_classes,
max_seq=max_len)
model = get_model(_config, _log)
if dir_of_start_model:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
_log.info("reloaded weights from %s", dir_of_start_model)
#initial_learning_rate = 1e-4
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
#alpha = 0.01
@ -739,15 +679,6 @@ def run(_config,
#print(model.summary())
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
model.save(dir_save, include_optimizer=False)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
return
# todo: use Dataset.map() on Dataset.list_files()
def get_dataset(dir_img, dir_lab):
def gen():
@ -789,25 +720,15 @@ def run(_config,
model = load_model(dir_of_start_model, compile=False)
else:
index_start = 0
model = resnet50_classifier(n_classes,
input_height,
input_width,
weight_decay,
pretraining)
model = get_model(_config, _log)
if dir_of_start_model:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
_log.info("reloaded weights from %s", dir_of_start_model)
model.compile(loss='categorical_crossentropy',
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
metrics=['accuracy', F1Score(average='macro', name='f1')])
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
model.save(dir_save, include_optimizer=False)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
return
list_classes = list(classification_classes_name.values())
data_args = dict(label_mode="categorical",
class_names=list_classes,
@ -962,11 +883,10 @@ def run(_config,
model = load_model(dir_of_start_model, compile=False)
else:
index_start = 0
model = machine_based_reading_order_model(n_classes,
input_height,
input_width,
weight_decay,
pretraining)
model = get_model(_config, _log)
if dir_of_start_model:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
_log.info("reloaded weights from %s", dir_of_start_model)
#f1score_tot = [0]
model.compile(loss="binary_crossentropy",
@ -974,15 +894,6 @@ def run(_config,
optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate?
metrics=['accuracy'])
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
model.save(dir_save, include_optimizer=False)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
return
dir_flow_train_imgs = os.path.join(dir_train, 'images')
dir_flow_train_labels = os.path.join(dir_train, 'labels')
@ -1015,3 +926,23 @@ def run(_config,
model_dir = os.path.join(dir_out,'model_best')
model.save(model_dir)
'''
@click.command(context_settings=dict(
    ignore_unknown_options=True,
))
@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED)
def train_cli(sacred_args):
    """
    train model on extracted GT
    SACRED_ARGS as per CLI interface of Sacred, cf.
    https://sacred.readthedocs.io/en/stable/command_line.html:
    \b
    To configure the learning task, pass the string `with`,
    followed by any number of
    - config JSON file paths
    - parameter overrides in the form of key=value
    (where the later settings will override the former).
    """
    # Sacred parses a complete argv, so the program name goes back in front
    # of the unprocessed arguments collected by click.
    argv = [sys.argv[0], *sacred_args]
    ex.run_commandline(argv)

View file

@ -68,6 +68,7 @@ def run_ensembling(dir_models, out, framework):
@click.option(
"--in",
"-i",
"in_",
help="input directory of checkpoint models to be read",
multiple=True,
required=True,

View file

@ -2,6 +2,7 @@ from typing import Iterable, List, Tuple
from logging import getLogger
import time
import math
from itertools import islice
try:
import matplotlib.pyplot as plt
@ -33,6 +34,11 @@ def pairwise(iterable):
yield a, b
a = b
def batched(iterable, n):
    """
    Yield successive tuples of up to `n` items taken from `iterable`.

    All batches have exactly `n` items, except possibly the last one,
    which holds the remainder. An empty iterable yields nothing.

    Raises ValueError (on first iteration) if `n` is smaller than one;
    a batch size of zero would otherwise silently discard all input.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    iterator = iter(iterable)
    # islice() returns an empty tuple once the iterator is exhausted,
    # which ends the loop
    while batch := tuple(islice(iterator, n)):
        yield batch
def return_multicol_separators_x_start_end(
regions_without_separators, peak_points, top, bot,
x_min_hor_some, x_max_hor_some, cy_hor_some, y_min_hor_some, y_max_hor_some):

View file

@ -11,7 +11,7 @@ from shapely.geometry.polygon import orient
from shapely import set_precision, affinity
from shapely.ops import unary_union, nearest_points
from .rotate import rotate_image, rotation_image_new
from .rotate import rotate_image
def contours_in_same_horizon(cy_main_hor):
"""
@ -120,94 +120,6 @@ def return_contours_of_interested_region(region_pre_p, label, min_area=0.0002, d
dilate=dilate)
return contours_imgs
def do_work_of_contours_in_image(contour, index_r_con, img, slope_first):
    """
    Map a single contour through `rotation_image_new` by `-slope_first`.

    The contour is drawn as a filled mask of the same height and width as
    `img`, the mask is transformed, and the first contour found in the
    result is returned (shifted by the size difference between the
    transformed mask and `img`), together with `index_r_con`, which is
    passed through unchanged.
    """
    # filled binary mask (values 0/1) with the height and width of img
    img_copy = np.zeros(img.shape[:2], dtype=np.uint8)
    img_copy = cv2.fillPoly(img_copy, pts=[contour], color=1)
    # presumably rotates the mask by the given angle - see rotation_image_new
    img_copy = rotation_image_new(img_copy, -slope_first)
    # every non-zero pixel becomes 255
    _, thresh = cv2.threshold(img_copy, 0, 255, 0)
    cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # offset x and y by the absolute size difference of mask and image
    # NOTE(review): only the first contour is used; indexing fails if none is found
    cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
    cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
    return cont_int[0], index_r_con
def get_textregion_contours_in_org_image_multi(cnts, img, slope_first, map=map):
    """
    Apply `do_work_of_contours_in_image` to every contour in `cnts`.

    `map` may be replaced by any compatible mapping function (e.g. one
    which distributes the work); it is called with the worker, the
    contours and their indices.

    Returns a pair of empty lists if there are no contours, otherwise a
    pair of tuples: the resulting contours and their indices.
    """
    if len(cnts) == 0:
        return [], []
    worker = partial(do_work_of_contours_in_image,
                     img=img,
                     slope_first=slope_first)
    indices = range(len(cnts))
    # transpose the sequence of (contour, index) pairs
    return tuple(zip(*map(worker, cnts, indices)))
def get_textregion_contours_in_org_image(cnts, img, slope_first):
    """
    Sequential variant: map every contour in ``cnts`` through the same
    rasterise / rotate by ``-slope_first`` / re-extract / shift steps as
    ``do_work_of_contours_in_image`` and collect the resulting contours.

    Returns a list with one contour per input contour, in input order.
    """
    return [do_work_of_contours_in_image(contour, position, img, slope_first)[0]
            for position, contour in enumerate(cnts)]
def get_textregion_confidences_old(cnts, img, slope_first):
    """
    Rotate contours by ``-slope_first`` at one third of the image resolution.

    ``img`` is shrunk by an integer factor of 3 (nearest-neighbour), each
    contour is integer-divided by the same factor, rasterised, rotated with
    ``rotation_image_new`` and re-extracted; the first contour found is
    shifted by the absolute size difference between the rotated mask and the
    shrunk image, then scaled back up by the factor.

    Despite the function name, the return value is a list of contours (one
    per input contour), not confidences.
    """
    zoom = 3
    small_w = img.shape[1] // zoom
    small_h = img.shape[0] // zoom
    img = cv2.resize(img, (small_w, small_h), interpolation=cv2.INTER_NEAREST)
    small_shape = img.shape[:2]

    rescaled_contours = []
    for cnt in cnts:
        mask = np.zeros(small_shape, dtype=np.uint8)
        mask = cv2.fillPoly(mask, pts=[cnt // zoom], color=1)
        mask = rotation_image_new(mask, -slope_first).astype(np.uint8)
        _, binary = cv2.threshold(mask, 0, 255, 0)
        found, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        first = found[0]
        first[:, 0, 0] = first[:, 0, 0] + np.abs(mask.shape[1] - small_shape[1])
        first[:, 0, 1] = first[:, 0, 1] + np.abs(mask.shape[0] - small_shape[0])
        rescaled_contours.append(first * zoom)
    return rescaled_contours
def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first, confidence_matrix):
    """
    Compute the mean confidence under ``contour_par`` and rotate the contour
    by ``-slope_first``.

    The contour is filled onto a blank mask with the height and width of
    ``img``. The confidence is the sum of ``confidence_matrix`` over the
    filled pixels divided by the number of filled pixels. The mask is then
    rotated with ``rotation_image_new`` and the first contour found in it is
    shifted by the absolute size difference between rotated mask and ``img``.

    Returns ``(contour, index_r_con, confidence)``. If no contour is found
    after rotation, the input contour is returned with a confidence of 0.
    """
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    mask = cv2.fillPoly(mask, pts=[contour_par], color=1)
    # mean of the confidence values covered by the contour
    covered_confidence = confidence_matrix * mask
    confidence = np.sum(covered_confidence) / float(np.sum(mask))

    rotated = rotation_image_new(mask, -slope_first).astype(np.uint8)
    _, binary = cv2.threshold(rotated, 0, 255, 0)
    found, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if len(found) == 0:
        # nothing left after rotation: fall back to the input contour
        return contour_par, index_r_con, 0

    first = found[0]
    first[:, 0, 0] = first[:, 0, 0] + np.abs(rotated.shape[1] - img.shape[1])
    first[:, 0, 1] = first[:, 0, 1] + np.abs(rotated.shape[0] - img.shape[0])
    return first, index_r_con, confidence
def get_region_confidences(cnts, confidence_matrix):
if not len(cnts):
return []
@ -418,7 +330,7 @@ def estimate_skew_contours(contours):
if not np.any(usable):
raise ValueError("not enough contours with consistent length")
if np.count_nonzero(usable) == 1:
return angle_in[usable]
return angle_in[usable][0]
# 4. there is no way to distinguish between +90 and -89.9 here,
# so map to [0,180] when calculating averages, then map back to [-90,90]
# (we don't want -90 and +89 to average zero, or +1 and +179 to average 90)

View file

@ -2,10 +2,6 @@ import math
import cv2
def rotation_image_new(img, thetha):
    """
    Rotate ``img`` by ``thetha`` with ``rotate_image`` and return the result
    of ``rotate_max_area_new`` applied to the original image, the rotated
    image and the angle.
    """
    return rotate_max_area_new(img, rotate_image(img, thetha), thetha)
def rotate_image(img_patch, slope):
(h, w) = img_patch.shape[:2]
center = (w // 2, h // 2)

View file

@ -3,15 +3,20 @@ import copy
import numpy as np
import cv2
import tensorflow as tf
# avoid module-level import:
# import tensorflow as tf
# (wait for tf-keras and logging setup in ModelZoo.load_model)
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
from PIL import Image, ImageDraw, ImageFont
from . import pairwise
from .resize import resize_image
def decode_batch_predictions(pred, num_to_char, max_len = 128):
import tensorflow as tf
# input_len is the product of the batch size and the
# number of time steps.
input_len = np.ones(pred.shape[0]) * pred.shape[1]
@ -37,43 +42,6 @@ def decode_batch_predictions(pred, num_to_char, max_len = 128):
output.append(d)
return output
def distortion_free_resize(image, img_size):
    """
    Resize ``image`` into ``img_size`` while keeping its aspect ratio, pad it
    to the full target size, then transpose and flip it.

    ``img_size`` is ``(width, height)``. After the aspect-preserving resize,
    the missing rows/columns are added with ``tf.pad``, split evenly between
    both sides; if the amount is odd, the extra row/column goes to the top
    and left respectively. Finally the first two axes are swapped and the
    result is passed through ``tf.image.flip_left_right``.
    """
    target_w, target_h = img_size
    image = tf.image.resize(image, size=(target_h, target_w), preserve_aspect_ratio=True)

    # amount of padding still needed to reach the target size
    missing_h = target_h - tf.shape(image)[0]
    missing_w = target_w - tf.shape(image)[1]

    # floor half goes to bottom/right, the remainder (one more if odd) to top/left
    pad_bottom = missing_h // 2
    pad_top = missing_h - pad_bottom
    pad_right = missing_w // 2
    pad_left = missing_w - pad_right

    image = tf.pad(
        image,
        paddings=[
            [pad_top, pad_bottom],
            [pad_left, pad_right],
            [0, 0],
        ],
    )

    image = tf.transpose(image, (1, 0, 2))
    return tf.image.flip_left_right(image)
def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image):
width = np.shape(textline_image)[1]
height = np.shape(textline_image)[0]
@ -256,249 +224,58 @@ def return_splitting_point_of_image(image_to_spliited):
return np.sort(peaks_sort_4)
def break_curved_line_into_small_pieces_and_then_merge(img_curved, mask_curved, img_bin_curved=None):
peaks_4 = return_splitting_point_of_image(img_curved)
if len(peaks_4)>0:
def break_curved_line_into_small_pieces_and_then_merge(img_rgb_curved, img_bin_curved, mask_curved):
peaks_4 = return_splitting_point_of_image(img_rgb_curved)
if len(peaks_4):
imgs_tot = []
for ind in range(len(peaks_4)+1):
if ind==0:
img = img_curved[:, :peaks_4[ind], :]
if img_bin_curved is not None:
img_bin = img_bin_curved[:, :peaks_4[ind], :]
mask = mask_curved[:, :peaks_4[ind], :]
elif ind==len(peaks_4):
img = img_curved[:, peaks_4[ind-1]:, :]
if img_bin_curved is not None:
img_bin = img_bin_curved[:, peaks_4[ind-1]:, :]
mask = mask_curved[:, peaks_4[ind-1]:, :]
else:
img = img_curved[:, peaks_4[ind-1]:peaks_4[ind], :]
if img_bin_curved is not None:
img_bin = img_bin_curved[:, peaks_4[ind-1]:peaks_4[ind], :]
mask = mask_curved[:, peaks_4[ind-1]:peaks_4[ind], :]
for left, right in pairwise([None] + peaks_4 + [None]):
img_rgb = img_rgb_curved[:, left: right]
img_bin = img_bin_curved[:, left: right]
mask = mask_curved[:, left: right]
or_ma = get_orientation_moments_of_mask(mask)
if img_bin_curved is not None:
imgs_tot.append([img, mask, or_ma, img_bin] )
else:
imgs_tot.append([img, mask, or_ma] )
imgs_tot.append([img_rgb, img_bin, mask, or_ma])
w_tot_des_list = []
w_tot_des = 0
imgs_deskewed_list = []
imgs_rgb_deskewed_list = []
imgs_bin_deskewed_list = []
for ind in range(len(imgs_tot)):
img_in = imgs_tot[ind][0]
mask_in = imgs_tot[ind][1]
ori_in = imgs_tot[ind][2]
if img_bin_curved is not None:
img_bin_in = imgs_tot[ind][3]
if abs(ori_in)<45:
img_in_des = rotate_image_with_padding(img_in, ori_in, border_value=(255,255,255) )
if img_bin_curved is not None:
for img_rgb_in, img_bin_in, mask_in, ori_in in imgs_tot:
if abs(ori_in) < 45:
img_rgb_in_des = rotate_image_with_padding(img_rgb_in, ori_in, border_value=(255,255,255) )
img_bin_in_des = rotate_image_with_padding(img_bin_in, ori_in, border_value=(255,255,255) )
mask_in_des = rotate_image_with_padding(mask_in, ori_in)
mask_in_des = mask_in_des.astype('uint8')
#new bounding box
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_in_des[:,:,0])
if w_n==0 or h_n==0:
img_in_des = np.copy(img_in)
if img_bin_curved is not None:
img_bin_in_des = np.copy(img_bin_in)
w_relative = int(32 * img_in_des.shape[1]/float(img_in_des.shape[0]) )
if w_relative==0:
w_relative = img_in_des.shape[1]
img_in_des = resize_image(img_in_des, 32, w_relative)
if img_bin_curved is not None:
img_bin_in_des = resize_image(img_bin_in_des, 32, w_relative)
# get new bounding box
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_in_des)
if w_n and h_n:
img_rgb_in_des = img_rgb_in_des[y_n: y_n + h_n, x_n: x_n + w_n]
img_bin_in_des = img_bin_in_des[y_n: y_n + h_n, x_n: x_n + w_n]
else:
mask_in_des = mask_in_des[y_n:y_n+h_n, x_n:x_n+w_n, :]
img_in_des = img_in_des[y_n:y_n+h_n, x_n:x_n+w_n, :]
if img_bin_curved is not None:
img_bin_in_des = img_bin_in_des[y_n:y_n+h_n, x_n:x_n+w_n, :]
w_relative = int(32 * img_in_des.shape[1]/float(img_in_des.shape[0]) )
if w_relative==0:
w_relative = img_in_des.shape[1]
img_in_des = resize_image(img_in_des, 32, w_relative)
if img_bin_curved is not None:
img_bin_in_des = resize_image(img_bin_in_des, 32, w_relative)
else:
img_in_des = np.copy(img_in)
if img_bin_curved is not None:
img_rgb_in_des = np.copy(img_rgb_in)
img_bin_in_des = np.copy(img_bin_in)
else:
img_rgb_in_des = np.copy(img_rgb_in)
img_bin_in_des = np.copy(img_bin_in)
w_relative = int(32 * img_in_des.shape[1]/float(img_in_des.shape[0]) )
if w_relative==0:
w_relative = img_in_des.shape[1]
img_in_des = resize_image(img_in_des, 32, w_relative)
if img_bin_curved is not None:
img_bin_in_des = resize_image(img_bin_in_des, 32, w_relative)
w_tot_des+=img_in_des.shape[1]
w_tot_des_list.append(img_in_des.shape[1])
imgs_deskewed_list.append(img_in_des)
if img_bin_curved is not None:
h, w = img_rgb_in_des.shape[:2]
new_h = 32
new_w = 32 * w // h
new_w = new_w or w
img_rgb_in_des = resize_image(img_rgb_in_des, new_h, new_w)
img_bin_in_des = resize_image(img_bin_in_des, new_h, new_w)
w_tot_des_list.append(new_w)
imgs_rgb_deskewed_list.append(img_rgb_in_des)
imgs_bin_deskewed_list.append(img_bin_in_des)
img_final_deskewed = np.zeros((32, w_tot_des, 3))+255
if img_bin_curved is not None:
img_bin_final_deskewed = np.zeros((32, w_tot_des, 3))+255
else:
img_bin_final_deskewed = None
img_rgb_final_deskewed = np.ones((new_h, sum(w_tot_des_list), 3)) * 255
img_bin_final_deskewed = np.ones((new_h, sum(w_tot_des_list), 3)) * 255
w_indexer = 0
for ind in range(len(w_tot_des_list)):
img_final_deskewed[:,w_indexer:w_indexer+w_tot_des_list[ind],:] = imgs_deskewed_list[ind][:,:,:]
if img_bin_curved is not None:
img_bin_final_deskewed[:,w_indexer:w_indexer+w_tot_des_list[ind],:] = imgs_bin_deskewed_list[ind][:,:,:]
w_indexer = w_indexer+w_tot_des_list[ind]
return img_final_deskewed, img_bin_final_deskewed
w_indexer2 = w_indexer + w_tot_des_list[ind]
img_rgb_final_deskewed[:, w_indexer: w_indexer2] = imgs_rgb_deskewed_list[ind]
img_bin_final_deskewed[:, w_indexer: w_indexer2] = imgs_bin_deskewed_list[ind]
w_indexer = w_indexer2
return img_rgb_final_deskewed, img_bin_final_deskewed
else:
return img_curved, img_bin_curved
def return_textline_contour_with_added_box_coordinate(textline_contour, box_ind):
    """
    Shift ``textline_contour`` in place by offsets taken from ``box_ind``.

    ``box_ind[2]`` is added to the values at last-axis index 0 and
    ``box_ind[0]`` to the values at last-axis index 1. The same (mutated)
    array object is returned.
    """
    # last-axis index 0 pairs with box_ind[2], index 1 with box_ind[0]
    for axis, box_pos in enumerate((2, 0)):
        textline_contour[:, :, axis] += box_ind[box_pos]
    return textline_contour
def return_rnn_cnn_ocr_of_given_textlines(image,
all_found_textline_polygons,
all_box_coord,
prediction_model,
b_s_ocr, num_to_char,
curved_line=False):
"""
Crop every text line polygon out of ``image``, run ``prediction_model`` on
the crops in batches of ``b_s_ocr`` and return the decoded texts grouped
per region.

``all_found_textline_polygons`` is iterated as one sequence of line
polygons per region. ``all_box_coord`` is indexed with the region position
and only read when ``curved_line`` is false. ``num_to_char`` is passed
through to ``decode_batch_predictions``.

Returns a list holding, for every distinct region index, the list of text
strings recognised for that region.
"""
# NOTE(review): max_len and padding_token are not referenced again in this function
max_len = 512
padding_token = 299
# fixed input size (height x width) every crop is brought to before prediction
image_width = 512#max_len * 4
image_height = 32
# NOTE(review): ind_tot is not referenced again in this function
ind_tot = 0
#cv2.imwrite('./img_out.png', image_page)
# re-initialised further below, before it is filled
ocr_all_textlines = []
# region index of every text line
cropped_lines_region_indexer = []
# per crop: 0 = whole line, 1 = first part of a split line, -1 = second part
cropped_lines_meging_indexing = []
cropped_lines = []
indexer_text_region = 0
for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
#ocr_textline_in_textregion = []
# region without lines: an all-ones placeholder image stands in as its single crop
if len(ind_poly_first)==0:
cropped_lines_region_indexer.append(indexer_text_region)
cropped_lines_meging_indexing.append(0)
img_fin = np.ones((image_height, image_width, 3))*1
cropped_lines.append(img_fin)
else:
for indexing2, ind_poly in enumerate(ind_poly_first):
cropped_lines_region_indexer.append(indexer_text_region)
# work on a copy and shift the polygon by the offsets of the region's box
if not curved_line:
ind_poly = copy.deepcopy(ind_poly)
box_ind = all_box_coord[indexing]
ind_poly = return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
#print(ind_poly_copy)
# clamp negative coordinates to zero
ind_poly[ind_poly<0] = 0
x, y, w, h = cv2.boundingRect(ind_poly)
# width the crop would have after scaling its height to image_height
w_scaled = w * image_height/float(h)
mask_poly = np.zeros(image.shape)
img_poly_on_img = np.copy(image)
mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))
mask_poly = mask_poly[y:y+h, x:x+w, :]
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
# everything outside the polygon is set to 255
img_crop[mask_poly==0] = 255
# crops below the width threshold are used whole; wider ones go through the splitter
if w_scaled < 640:#1.5*image_width:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(0)
else:
splited_images, splited_images_bin = return_textlines_split_if_needed(img_crop, None)
# a truthy result is used as two parts (indices 0 and 1), marked 1 and -1
if splited_images:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(splited_images[0],
image_height,
image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(1)
img_fin = preprocess_and_resize_image_for_ocrcnn_model(splited_images[1],
image_height,
image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(-1)
else:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(img_crop,
image_height,
image_width)
cropped_lines.append(img_fin)
cropped_lines_meging_indexing.append(0)
indexer_text_region+=1
extracted_texts = []
# predict in batches of b_s_ocr crops; the last batch takes whatever remains
n_iterations = math.ceil(len(cropped_lines) / b_s_ocr)
for i in range(n_iterations):
if i==(n_iterations-1):
n_start = i*b_s_ocr
imgs = cropped_lines[n_start:]
imgs = np.array(imgs)
imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
else:
n_start = i*b_s_ocr
n_end = (i+1)*b_s_ocr
imgs = cropped_lines[n_start:n_end]
imgs = np.array(imgs).reshape(b_s_ocr, image_height, image_width, 3)
preds = prediction_model.predict(imgs, verbose=0)
pred_texts = decode_batch_predictions(preds, num_to_char)
# strip the "[UNK]" marker from every decoded string
for ib in range(imgs.shape[0]):
pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
extracted_texts.append(pred_texts_ib)
# join the two parts of split lines with a space; entries marked -1 become
# None (their text was appended to the preceding entry) and are dropped next
extracted_texts_merged = [extracted_texts[ind]
if cropped_lines_meging_indexing[ind]==0
else extracted_texts[ind]+" "+extracted_texts[ind+1]
if cropped_lines_meging_indexing[ind]==1
else None
for ind in range(len(cropped_lines_meging_indexing))]
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
# group the merged texts by the region index recorded for each line
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
ocr_all_textlines = []
for ind in unique_cropped_lines_region_indexer:
ocr_textline_in_textregion = []
extracted_texts_merged_un = np.array(extracted_texts_merged)[np.array(cropped_lines_region_indexer)==ind]
for it_ind, text_textline in enumerate(extracted_texts_merged_un):
ocr_textline_in_textregion.append(text_textline)
ocr_all_textlines.append(ocr_textline_in_textregion)
return ocr_all_textlines
return img_rgb_curved, img_bin_curved

View file

@ -1,4 +1,5 @@
from typing import List
import os
import pytest
import logging
@ -31,6 +32,8 @@ def run_eynollah_ok_and_check_logs(
subcommand,
*args
]
if 'EYNOLLAH_OPTIONS' in os.environ:
args = os.environ['EYNOLLAH_OPTIONS'].split() + args
if pytestconfig.getoption('verbose') > 0:
args = ['-l', 'DEBUG'] + args
caplog.set_level(logging.INFO)

View file

@ -6,11 +6,12 @@ from ocrd_models.constants import NAMESPACES as NS
"options",
[
[], # defaults
#["--allow_scaling", "--curved-line"],
["--allow_scaling", "--curved-line", "--full-layout"],
["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based"],
#["--curved-line"],
["--curved-line", "--full-layout"],
["--curved-line", "--full-layout", "--reading_order_machine_based"],
# -ep ...
# -eoi ...
# --input_binary
# --ignore_page_extraction
# --skip_layout_and_reading_order
], ids=str)
def test_run_eynollah_layout_filename(

View file

@ -30,7 +30,7 @@ def test_run_eynollah_ocr_filename(
'-o', str(outfile.parent),
] + options,
[
# FIXME: ocr has no logging!
'output filename:'
]
)
assert outfile.exists()
@ -57,7 +57,7 @@ def test_run_eynollah_ocr_directory(
'-o', str(outdir),
],
[
# FIXME: ocr has no logging!
'output filename:'
]
)
assert len(list(outdir.iterdir())) == 2

View file

@ -1,16 +1,28 @@
from eynollah.model_zoo import EynollahModelZoo
from eynollah.predictor import Predictor
def test_trocr1(
model_dir,
):
model_zoo = EynollahModelZoo(model_dir)
try:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
model_zoo.load_models('trocr_processor')
proc = model_zoo.get('trocr_processor')
assert isinstance(proc, TrOCRProcessor)
model_zoo.load_models(['ocr', 'tr'])
model_zoo.load_models(('ocr', 'tr'))
model = model_zoo.get('ocr')
assert isinstance(model, VisionEncoderDecoderModel)
assert isinstance(model, Predictor)
shape = model.input_shape
assert len(shape) == 3
except ImportError:
pass
def test_cnnrnnocr1(model_dir):
    """
    Load the 'ocr' model through the model zoo and check that it is wrapped
    in a ``Predictor`` whose ``input_shape`` has four entries.
    """
    zoo = EynollahModelZoo(model_dir)
    try:
        zoo.load_models('ocr')
        predictor = zoo.get('ocr')
        assert isinstance(predictor, Predictor)
        assert len(predictor.input_shape) == 4
    except ImportError:
        # presumably raised when an optional backend is not installed;
        # the test then passes without checking anything
        pass