commit 87938fe42b
vahidrezanezhad, 2025-08-28 11:37:36 +02:00, committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
11 changed files with 3833 additions and 660 deletions


@@ -4,4 +4,6 @@ numpy <1.24.0
scikit-learn >= 0.23.2
tensorflow < 2.13
numba <= 0.58.1
scikit-image
loky
biopython


@ -3,6 +3,8 @@ import click
from ocrd_utils import initLogging, getLevelName, getLogger
from eynollah.eynollah import Eynollah, Eynollah_ocr
from eynollah.sbb_binarize import SbbBinarizer
from eynollah.image_enhancer import Enhancer
from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout
@click.group()
def main():
@@ -12,38 +14,37 @@ def main():
@click.option(
"--dir_xml",
"-dx",
help="directory of GT page-xml files",
help="directory of page-xml files",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--dir_out_modal_image",
"-domi",
help="directory where ground truth images would be written",
"--xml_file",
"-xml",
help="xml filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--dir_out",
"-do",
help="directory for output images",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--dir_out_classes",
"-docl",
help="directory where ground truth classes would be written",
"--model",
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--input_height",
"-ih",
help="input height",
)
@click.option(
"--input_width",
"-iw",
help="input width",
)
@click.option(
"--min_area_size",
"-min",
help="min area size of regions considered for reading order training.",
)
def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size):
xml_files_ind = os.listdir(dir_xml)
def machine_based_reading_order(dir_xml, xml_file, dir_out, model):
reading_order_object = machine_based_reading_order_on_layout(model, dir_out=dir_out, logger=getLogger('enhancement'))
if dir_xml:
reading_order_object.run(dir_in=dir_xml)
else:
reading_order_object.run(xml_filename=xml_file)
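For reference, a minimal sketch of driving the new reading-order API from Python rather than the CLI; the paths are placeholders, and the constructor and run() signatures are taken from the wiring above:

from eynollah.mb_ro_on_layout import machine_based_reading_order_on_layout

# placeholder paths throughout
ro = machine_based_reading_order_on_layout("/path/to/models", dir_out="/path/to/out")
ro.run(dir_in="/path/to/page_xml_dir")      # process a whole directory of PAGE-XML files
# or a single file:
# ro.run(xml_filename="/path/to/page.xml")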
@main.command()
@click.option('--patches/--no-patches', default=True, help='by enabling this parameter you let the model see the image in patches.')
@@ -70,6 +71,81 @@ def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out)
@main.command()
@click.option(
"--image",
"-i",
help="image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--out",
"-o",
help="directory to write output xml data",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--dir_in",
"-di",
help="directory of images",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--model",
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--num_col_upper",
"-ncu",
help="lower limit of columns in document image",
)
@click.option(
"--num_col_lower",
"-ncl",
help="upper limit of columns in document image",
)
@click.option(
"--save_org_scale/--no_save_org_scale",
"-sos/-nosos",
is_flag=True,
help="if this parameter set to true, this tool will save the enhanced image in org scale.",
)
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this",
)
def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_lower, save_org_scale, log_level):
initLogging()
if log_level:
getLogger('enhancement').setLevel(getLevelName(log_level))
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
enhancer_object = Enhancer(
model,
logger=getLogger('enhancement'),
dir_out=out,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
save_org_scale=save_org_scale,
)
if dir_in:
enhancer_object.run(dir_in=dir_in, overwrite=overwrite)
else:
enhancer_object.run(image_filename=image, overwrite=overwrite)
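The same enhancement pipeline can be used as a library; a minimal sketch with placeholder paths, using only the constructor arguments and run() keywords shown above:

from eynollah.image_enhancer import Enhancer

enhancer = Enhancer(
    "/path/to/models",       # placeholder model directory
    dir_out="/path/to/out",
    save_org_scale=True,     # write the result back at the input's original scale
)
enhancer.run(dir_in="/path/to/images", overwrite=False)
# or: enhancer.run(image_filename="/path/to/page.png", overwrite=True)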
@main.command()
@click.option(
@@ -225,6 +301,17 @@ def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out)
is_flag=True,
help="if this parameter set to true, this tool will try to do ocr",
)
@click.option(
"--transformer_ocr",
"-tr/-notr",
is_flag=True,
help="if this parameter set to true, this tool will apply transformer ocr",
)
@click.option(
"--batch_size_ocr",
"-bs_ocr",
help="number of inference batch size of ocr model. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
)
@click.option(
"--num_col_upper",
"-ncu",
@@ -235,6 +322,16 @@ def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out)
"-ncl",
help="upper limit of columns in document image",
)
@click.option(
"--threshold_art_class_layout",
"-tharl",
help="threshold of artifical class in the case of layout detection. The default value is 0.1",
)
@click.option(
"--threshold_art_class_textline",
"-thart",
help="threshold of artifical class in the case of textline detection. The default value is 0.1",
)
@click.option(
"--skip_layout_and_reading_order",
"-slro/-noslro",
@@ -248,7 +345,7 @@ def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out)
help="Override log level globally to this",
)
def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, num_col_upper, num_col_lower, skip_layout_and_reading_order, ignore_page_extraction, log_level):
def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level):
initLogging()
if log_level:
getLogger('eynollah').setLevel(getLevelName(log_level))
@@ -295,9 +392,13 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
ignore_page_extraction=ignore_page_extraction,
reading_order_machine_based=reading_order_machine_based,
do_ocr=do_ocr,
transformer_ocr=transformer_ocr,
batch_size_ocr=batch_size_ocr,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
skip_layout_and_reading_order=skip_layout_and_reading_order,
threshold_art_class_textline=threshold_art_class_textline,
threshold_art_class_layout=threshold_art_class_layout,
)
if dir_in:
eynollah.run(dir_in=dir_in, overwrite=overwrite)
@@ -306,6 +407,18 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
@main.command()
@click.option(
"--image",
"-i",
help="image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--dir_in",
"-di",
@@ -342,7 +455,11 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--model_name",
help="Specific model file path to use for OCR",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--tr_ocr",
@@ -362,18 +479,27 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
)
@click.option(
"--draw_texts_on_image",
"-dtoi/-ndtoi",
is_flag=True,
help="if this parameter set to true, the predicted texts will be displayed on an image.",
)
@click.option(
"--prediction_with_both_of_rgb_and_bin",
"-brb/-nbrb",
is_flag=True,
help="If this parameter is set to True, the prediction will be performed using both RGB and binary images. However, this does not necessarily improve results; it may be beneficial for certain document images.",
)
@click.option(
"--batch_size",
"-bs",
help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
)
@click.option(
"--dataset_abbrevation",
"-ds_pref",
help="in the case of extracting textline and text from a xml GT file user can add an abbrevation of dataset name to generated dataset",
)
@click.option(
"--min_conf_value_of_textline_text",
"-min_conf",
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
)
@click.option(
"--log_level",
"-l",
@@ -381,24 +507,37 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
help="Override log level globally to this",
)
def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, log_level):
def ocr(image, overwrite, dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, model_name, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, prediction_with_both_of_rgb_and_bin, batch_size, dataset_abbrevation, min_conf_value_of_textline_text, log_level):
initLogging()
if log_level:
getLogger('eynollah').setLevel(getLevelName(log_level))
assert not model or not model_name, "model directory -m cannot be set alongside specific model name --model_name"
assert not export_textline_images_and_text or not tr_ocr, "Exporting textline and text -etit cannot be set alongside transformer ocr -tr_ocr"
assert not export_textline_images_and_text or not model, "Exporting textline and text -etit cannot be set alongside model -m"
assert not export_textline_images_and_text or not batch_size, "Exporting textline and text -etit cannot be set alongside batch size -bs"
assert not export_textline_images_and_text or not dir_in_bin, "Exporting textline and text -etit cannot be set alongside directory of bin images -dib"
assert not export_textline_images_and_text or not dir_out_image_text, "Exporting textline and text -etit cannot be set alongside directory of images with predicted text -doit"
assert not export_textline_images_and_text or not prediction_with_both_of_rgb_and_bin, "Exporting textline and text -etit cannot be set alongside prediction with both rgb and bin -brb"
assert (bool(image) ^ bool(dir_in)), "Either -i (single image) or -di (directory) must be provided, but not both."
eynollah_ocr = Eynollah_ocr(
image_filename=image,
dir_xmls=dir_xmls,
dir_out_image_text=dir_out_image_text,
dir_in=dir_in,
dir_in_bin=dir_in_bin,
dir_out=out,
dir_models=model,
model_name=model_name,
tr_ocr=tr_ocr,
export_textline_images_and_text=export_textline_images_and_text,
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
draw_texts_on_image=draw_texts_on_image,
prediction_with_both_of_rgb_and_bin=prediction_with_both_of_rgb_and_bin,
batch_size=batch_size,
pref_of_dataset=dataset_abbrevation,
min_conf_value_of_textline_text=min_conf_value_of_textline_text,
)
eynollah_ocr.run()
eynollah_ocr.run(overwrite=overwrite)
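For completeness, a minimal sketch of the same OCR invocation as a library call; placeholder paths, and only keyword arguments that appear in the construction above:

from eynollah.eynollah import Eynollah_ocr

ocr_engine = Eynollah_ocr(
    dir_models="/path/to/models",         # placeholder paths throughout
    dir_in="/path/to/images",
    dir_xmls="/path/to/page_xml",
    dir_out="/path/to/out",
    batch_size=8,                         # help text above: default 8 for cnn_rnn models
    min_conf_value_of_textline_text=0.3,  # illustrative threshold, not a documented default
)
ocr_engine.run(overwrite=False)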
if __name__ == "__main__":
main()

File diff suppressed because it is too large.


@@ -0,0 +1,735 @@
"""
Image enhancer. The output can be written at the same scale as the input or at a newly predicted scale.
"""
from logging import Logger
from difflib import SequenceMatcher as sq
from PIL import Image, ImageDraw, ImageFont
import math
import os
import sys
import time
from typing import Optional
import atexit
import warnings
from functools import partial
from pathlib import Path
from multiprocessing import cpu_count
import gc
import copy
from loky import ProcessPoolExecutor
import xml.etree.ElementTree as ET
import cv2
import numpy as np
import tensorflow as tf  # required below for tf.config GPU memory growth
from skimage.morphology import skeletonize  # used in do_prediction for the artificial class
from ocrd import OcrdPage
from ocrd_utils import getLogger, tf_disable_interactive_logs
import statistics
from tensorflow.keras.models import load_model
from .utils.resize import resize_image
from .utils.pil_cv2 import pil2cv  # used in cache_images when a PIL image is passed
from .utils import (
crop_image_inside_box
)
DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)
class Enhancer:
def __init__(
self,
dir_models : str,
dir_out : Optional[str] = None,
num_col_upper : Optional[int] = None,
num_col_lower : Optional[int] = None,
save_org_scale : bool = False,
logger : Optional[Logger] = None,
):
self.dir_out = dir_out
self.input_binary = False
self.light_version = False
self.save_org_scale = save_org_scale
if num_col_upper:
self.num_col_upper = int(num_col_upper)
else:
self.num_col_upper = num_col_upper
if num_col_lower:
self.num_col_lower = int(num_col_lower)
else:
self.num_col_lower = num_col_lower
self.logger = logger if logger else getLogger('enhancement')
# for parallelization of CPU-intensive tasks:
self.executor = ProcessPoolExecutor(max_workers=cpu_count(), timeout=1200)
atexit.register(self.executor.shutdown)
self.dir_models = dir_models
self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
self.model_page_dir = dir_models + "/eynollah-page-extraction_20210425"
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except Exception:
self.logger.warning("no GPU device available")
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
ret = {}
t_c0 = time.time()
if image_filename:
ret['img'] = cv2.imread(image_filename)
if self.light_version:
self.dpi = 100
else:
self.dpi = 0#check_dpi(image_filename)
else:
ret['img'] = pil2cv(image_pil)
if self.light_version:
self.dpi = 100
else:
self.dpi = 0#check_dpi(image_pil)
ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY)
for prefix in ('', '_grayscale'):
ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8)
self._imgs = ret
if dpi is not None:
self.dpi = dpi
def reset_file_name_dir(self, image_filename):
t_c = time.time()
self.cache_images(image_filename=image_filename)
self.output_filename = os.path.join(self.dir_out, Path(image_filename).stem +'.png')
def imread(self, grayscale=False, uint8=True):
key = 'img'
if grayscale:
key += '_grayscale'
if uint8:
key += '_uint8'
return self._imgs[key].copy()
def isNaN(self, num):
return num != num
@staticmethod
def our_load_model(model_file):
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except Exception:
# retry with the custom ViT layers in scope; PatchEncoder and Patches
# must be importable here for transformer-based models
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement")
img_height_model = self.model_enhancement.layers[-1].output_shape[1]
img_width_model = self.model_enhancement.layers[-1].output_shape[2]
# cv2.resize takes dsize as (width, height)
if img.shape[0] < img_height_model:
img = cv2.resize(img, (img.shape[1], img_height_model), interpolation=cv2.INTER_NEAREST)
if img.shape[1] < img_width_model:
img = cv2.resize(img, (img_width_model, img.shape[0]), interpolation=cv2.INTER_NEAREST)
margin = int(0.1 * img_width_model)
width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin
img = img / 255.
img_h = img.shape[0]
img_w = img.shape[1]
prediction_true = np.zeros((img_h, img_w, 3))
nxf = img_w / float(width_mid)
nyf = img_h / float(height_mid)
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
for i in range(nxf):
for j in range(nyf):
# both branches of the former if i == 0 / if j == 0 were identical
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_x_u > img_w:
index_x_u = img_w
index_x_d = img_w - img_width_model
if index_y_u > img_h:
index_y_u = img_h
index_y_d = img_h - img_height_model
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
seg = label_p_pred[0, :, :, :] * 255
if i == 0 and j == 0:
prediction_true[index_y_d + 0:index_y_u - margin,
index_x_d + 0:index_x_u - margin] = \
seg[0:-margin or None,
0:-margin or None]
elif i == nxf - 1 and j == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - 0,
index_x_d + margin:index_x_u - 0] = \
seg[margin:,
margin:]
elif i == 0 and j == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - 0,
index_x_d + 0:index_x_u - margin] = \
seg[margin:,
0:-margin or None]
elif i == nxf - 1 and j == 0:
prediction_true[index_y_d + 0:index_y_u - margin,
index_x_d + margin:index_x_u - 0] = \
seg[0:-margin or None,
margin:]
elif i == 0 and j != 0 and j != nyf - 1:
prediction_true[index_y_d + margin:index_y_u - margin,
index_x_d + 0:index_x_u - margin] = \
seg[margin:-margin or None,
0:-margin or None]
elif i == nxf - 1 and j != 0 and j != nyf - 1:
prediction_true[index_y_d + margin:index_y_u - margin,
index_x_d + margin:index_x_u - 0] = \
seg[margin:-margin or None,
margin:]
elif i != 0 and i != nxf - 1 and j == 0:
prediction_true[index_y_d + 0:index_y_u - margin,
index_x_d + margin:index_x_u - margin] = \
seg[0:-margin or None,
margin:-margin or None]
elif i != 0 and i != nxf - 1 and j == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - 0,
index_x_d + margin:index_x_u - margin] = \
seg[margin:,
margin:-margin or None]
else:
prediction_true[index_y_d + margin:index_y_u - margin,
index_x_d + margin:index_x_u - margin] = \
seg[margin:-margin or None,
margin:-margin or None]
prediction_true = prediction_true.astype(int)
return prediction_true
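predict_enhancement covers the page with model-sized tiles that overlap by a margin, then keeps only each tile's interior so seams do not show. The grid arithmetic can be isolated into a small self-contained sketch (not part of the module), with the same clamping of the last tile to the image border:

import numpy as np

def tile_spans(length, window, margin):
    # tiles step by the interior size (window - 2*margin); the last tile
    # is clamped so it ends exactly at the image border
    step = window - 2 * margin
    n_tiles = int(np.ceil(length / step))
    spans = []
    for k in range(n_tiles):
        start, end = k * step, k * step + window
        if end > length:
            start, end = length - window, length
        spans.append((start, end))
    return spans

# e.g. a 1000 px axis with a 448 px model input and 10% margin (44 px):
print(tile_spans(1000, 448, 44))   # [(0, 448), (360, 808), (552, 1000)]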
def calculate_width_height_by_columns(self, img, num_col, width_early, label_p_pred):
self.logger.debug("enter calculate_width_height_by_columns")
if num_col == 1:
img_w_new = 2000
elif num_col == 2:
img_w_new = 2400
elif num_col == 3:
img_w_new = 3000
elif num_col == 4:
img_w_new = 4000
elif num_col == 5:
img_w_new = 5000
elif num_col == 6:
img_w_new = 6500
else:
img_w_new = width_early
img_h_new = img_w_new * img.shape[0] // img.shape[1]
if img_h_new >= 8000:
img_new = np.copy(img)
num_column_is_classified = False
else:
img_new = resize_image(img, img_h_new, img_w_new)
num_column_is_classified = True
return img_new, num_column_is_classified
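As a quick worked example of the mapping above, take a 3-column page scanned at 2480 x 3500 px (width x height); the target width is the hard-coded 3000 and the height follows the aspect ratio:

img_w_new = 3000                 # hard-coded target width for num_col == 3
img_h_new = 3000 * 3500 // 2480  # = 4233, under the 8000 cap, so the page is resized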
def early_page_for_num_of_column_classification(self,img_bin):
self.logger.debug("enter early_page_for_num_of_column_classification")
if self.input_binary:
img = np.copy(img_bin).astype(np.uint8)
else:
img = self.imread()
img = cv2.GaussianBlur(img, (5, 5), 0)
img_page_prediction = self.do_prediction(False, img, self.model_page)
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
thresh = cv2.dilate(thresh, KERNEL, iterations=3)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
if len(contours)>0:
cnt_size = np.array([cv2.contourArea(contours[j])
for j in range(len(contours))])
cnt = contours[np.argmax(cnt_size)]
box = cv2.boundingRect(cnt)
else:
box = [0, 0, img.shape[1], img.shape[0]]
cropped_page, page_coord = crop_image_inside_box(box, img)
self.logger.debug("exit early_page_for_num_of_column_classification")
return cropped_page, page_coord
def calculate_width_height_by_columns_1_2(self, img, num_col, width_early, label_p_pred):
self.logger.debug("enter calculate_width_height_by_columns")
if num_col == 1:
img_w_new = 1000
else:
img_w_new = 1300
img_h_new = img_w_new * img.shape[0] // img.shape[1]
if label_p_pred[0][int(num_col - 1)] < 0.9 and img_w_new < width_early:
img_new = np.copy(img)
num_column_is_classified = False
#elif label_p_pred[0][int(num_col - 1)] < 0.8 and img_h_new >= 8000:
elif img_h_new >= 8000:
img_new = np.copy(img)
num_column_is_classified = False
else:
img_new = resize_image(img, img_h_new, img_w_new)
num_column_is_classified = True
return img_new, num_column_is_classified
def resize_and_enhance_image_with_column_classifier(self, light_version):
self.logger.debug("enter resize_and_enhance_image_with_column_classifier")
dpi = 0#self.dpi
self.logger.info("Detected %s DPI", dpi)
if self.input_binary:
img = self.imread()
prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
img= np.copy(prediction_bin)
img_bin = prediction_bin
else:
img = self.imread()
self.h_org, self.w_org = img.shape[:2]
img_bin = None
width_early = img.shape[1]
t1 = time.time()
_, page_coord = self.early_page_for_num_of_column_classification(img_bin)
self.image_page_org_size = img[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3], :]
self.page_coord = page_coord
if self.num_col_upper and not self.num_col_lower:
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]
elif self.num_col_lower and not self.num_col_upper:
num_col = self.num_col_lower
label_p_pred = [np.ones(6)]
elif not self.num_col_upper and not self.num_col_lower:
if self.input_binary:
img_in = np.copy(img)
img_in = img_in / 255.0
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = img_in.reshape(1, 448, 448, 3)
else:
img_1ch = self.imread(grayscale=True)
width_early = img_1ch.shape[1]
img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
img_in[0, :, :, 0] = img_1ch[:, :]
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
if self.input_binary:
img_in = np.copy(img)
img_in = img_in / 255.0
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = img_in.reshape(1, 448, 448, 3)
else:
img_1ch = self.imread(grayscale=True)
width_early = img_1ch.shape[1]
img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
img_in[0, :, :, 0] = img_1ch[:, :]
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
if num_col > self.num_col_upper:
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]
if num_col < self.num_col_lower:
num_col = self.num_col_lower
label_p_pred = [np.ones(6)]
else:
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]
self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5))
if dpi < DPI_THRESHOLD:
if light_version and num_col in (1,2):
img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2(
img, num_col, width_early, label_p_pred)
else:
img_new, num_column_is_classified = self.calculate_width_height_by_columns(
img, num_col, width_early, label_p_pred)
if light_version:
image_res = np.copy(img_new)
else:
image_res = self.predict_enhancement(img_new)
is_image_enhanced = True
else:
num_column_is_classified = True
image_res = np.copy(img)
is_image_enhanced = False
self.logger.debug("exit resize_and_enhance_image_with_column_classifier")
return is_image_enhanced, img, image_res, num_col, num_column_is_classified, img_bin
def do_prediction(
self, patches, img, model,
n_batch_inference=1, marginal_of_patch_percent=0.1,
thresholding_for_some_classes_in_light_version=False,
thresholding_for_artificial_class_in_light_version=False, thresholding_for_fl_light_version=False, threshold_art_class_textline=0.1):
self.logger.debug("enter do_prediction")
img_height_model = model.layers[-1].output_shape[1]
img_width_model = model.layers[-1].output_shape[2]
if not patches:
img_h_page = img.shape[0]
img_w_page = img.shape[1]
img = img / float(255.0)
img = resize_image(img, img_height_model, img_width_model)
label_p_pred = model.predict(img[np.newaxis], verbose=0)
seg = np.argmax(label_p_pred, axis=3)[0]
if thresholding_for_artificial_class_in_light_version:
seg_art = label_p_pred[0,:,:,2]
seg_art[seg_art<threshold_art_class_textline] = 0
seg_art[seg_art>0] =1
skeleton_art = skeletonize(seg_art)
skeleton_art = skeleton_art*1
seg[skeleton_art==1]=2
if thresholding_for_fl_light_version:
seg_header = label_p_pred[0,:,:,2]
seg_header[seg_header<0.2] = 0
seg_header[seg_header>0] =1
seg[seg_header==1]=2
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8)
return prediction_true
if img.shape[0] < img_height_model:
img = resize_image(img, img_height_model, img.shape[1])
if img.shape[1] < img_width_model:
img = resize_image(img, img.shape[0], img_width_model)
self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
margin = int(marginal_of_patch_percent * img_height_model)
width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin
img = img / 255.
#img = img.astype(np.float16)
img_h = img.shape[0]
img_w = img.shape[1]
prediction_true = np.zeros((img_h, img_w, 3))
mask_true = np.zeros((img_h, img_w))
nxf = img_w / float(width_mid)
nyf = img_h / float(height_mid)
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3))
for i in range(nxf):
for j in range(nyf):
# both branches of the former if i == 0 / if j == 0 were identical
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_x_u > img_w:
index_x_u = img_w
index_x_d = img_w - img_width_model
if index_y_u > img_h:
index_y_u = img_h
index_y_d = img_h - img_height_model
list_i_s.append(i)
list_j_s.append(j)
list_x_u.append(index_x_u)
list_x_d.append(index_x_d)
list_y_d.append(index_y_d)
list_y_u.append(index_y_u)
img_patch[batch_indexer,:,:,:] = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
batch_indexer += 1
if (batch_indexer == n_batch_inference or
# last batch
i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape))
label_p_pred = model.predict(img_patch, verbose=0)
seg = np.argmax(label_p_pred, axis=3)
if thresholding_for_some_classes_in_light_version:
seg_not_base = label_p_pred[:,:,:,4]
seg_not_base[seg_not_base>0.03] =1
seg_not_base[seg_not_base<1] =0
seg_line = label_p_pred[:,:,:,3]
seg_line[seg_line>0.1] =1
seg_line[seg_line<1] =0
seg_background = label_p_pred[:,:,:,0]
seg_background[seg_background>0.25] =1
seg_background[seg_background<1] =0
seg[seg_not_base==1]=4
seg[seg_background==1]=0
seg[(seg_line==1) & (seg==0)]=3
if thresholding_for_artificial_class_in_light_version:
seg_art = label_p_pred[:,:,:,2]
seg_art[seg_art<threshold_art_class_textline] = 0
seg_art[seg_art>0] =1
##seg[seg_art==1]=2
indexer_inside_batch = 0
for i_batch, j_batch in zip(list_i_s, list_j_s):
seg_in = seg[indexer_inside_batch]
if thresholding_for_artificial_class_in_light_version:
seg_in_art = seg_art[indexer_inside_batch]
index_y_u_in = list_y_u[indexer_inside_batch]
index_y_d_in = list_y_d[indexer_inside_batch]
index_x_u_in = list_x_u[indexer_inside_batch]
index_x_d_in = list_x_d[indexer_inside_batch]
if i_batch == 0 and j_batch == 0:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin] = \
seg_in[0:-margin or None,
0:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin, 1] = \
seg_in_art[0:-margin or None,
0:-margin or None]
elif i_batch == nxf - 1 and j_batch == nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - 0] = \
seg_in[margin:,
margin:,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - 0, 1] = \
seg_in_art[margin:,
margin:]
elif i_batch == 0 and j_batch == nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + 0:index_x_u_in - margin] = \
seg_in[margin:,
0:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + 0:index_x_u_in - margin, 1] = \
seg_in_art[margin:,
0:-margin or None]
elif i_batch == nxf - 1 and j_batch == 0:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0] = \
seg_in[0:-margin or None,
margin:,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0, 1] = \
seg_in_art[0:-margin or None,
margin:]
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin] = \
seg_in[margin:-margin or None,
0:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin, 1] = \
seg_in_art[margin:-margin or None,
0:-margin or None]
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0] = \
seg_in[margin:-margin or None,
margin:,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0, 1] = \
seg_in_art[margin:-margin or None,
margin:]
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin] = \
seg_in[0:-margin or None,
margin:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin, 1] = \
seg_in_art[0:-margin or None,
margin:-margin or None]
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - margin] = \
seg_in[margin:,
margin:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - margin, 1] = \
seg_in_art[margin:,
margin:-margin or None]
else:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin] = \
seg_in[margin:-margin or None,
margin:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin, 1] = \
seg_in_art[margin:-margin or None,
margin:-margin or None]
indexer_inside_batch += 1
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
img_patch[:] = 0
prediction_true = prediction_true.astype(np.uint8)
if thresholding_for_artificial_class_in_light_version:
kernel_min = np.ones((3, 3), np.uint8)
prediction_true[:,:,0][prediction_true[:,:,0]==2] = 0
skeleton_art = skeletonize(prediction_true[:,:,1])
skeleton_art = skeleton_art*1
skeleton_art = skeleton_art.astype('uint8')
skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1)
prediction_true[:,:,0][skeleton_art==1]=2
#del model
gc.collect()
return prediction_true
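The artificial-class handling at the end of do_prediction thresholds the class probability, thins the result to a one-pixel skeleton, and dilates it slightly before stamping class 2. That post-processing step, extracted as a standalone sketch (threshold per the new --thart default of 0.1):

import cv2
import numpy as np
from skimage.morphology import skeletonize

def artificial_class_overlay(prob_map, threshold=0.1):
    # binarize the probability map, thin to a skeleton, then re-thicken a little
    seg_art = (prob_map >= threshold).astype(np.uint8)
    skeleton = skeletonize(seg_art.astype(bool)).astype(np.uint8)
    return cv2.dilate(skeleton, np.ones((3, 3), np.uint8), iterations=1)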
def run_enhancement(self, light_version):
t_in = time.time()
self.logger.info("Resizing and enhancing image...")
is_image_enhanced, img_org, img_res, num_col_classifier, num_column_is_classified, img_bin = \
self.resize_and_enhance_image_with_column_classifier(light_version)
self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ')
return img_res, is_image_enhanced, num_col_classifier, num_column_is_classified
def run_single(self):
t0 = time.time()
img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(light_version=False)
return img_res
def run(self, image_filename : Optional[str] = None, dir_in : Optional[str] = None, overwrite : bool = False):
"""
Get image and scales, then extract the page of scanned image
"""
self.logger.debug("enter run")
t0_tot = time.time()
if dir_in:
self.ls_imgs = os.listdir(dir_in)
elif image_filename:
self.ls_imgs = [image_filename]
else:
raise ValueError("run requires either a single image filename or a directory")
for img_filename in self.ls_imgs:
self.logger.info(img_filename)
t0 = time.time()
self.reset_file_name_dir(os.path.join(dir_in or "", img_filename))
#print("text region early -11 in %.1fs", time.time() - t0)
if os.path.exists(self.output_filename):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", self.output_filename)
else:
self.logger.warning("will skip input for existing output file '%s'", self.output_filename)
continue
image_enhanced = self.run_single()
if self.save_org_scale:
image_enhanced = resize_image(image_enhanced, self.h_org, self.w_org)
cv2.imwrite(self.output_filename, image_enhanced)

File diff suppressed because it is too large.


@@ -992,7 +992,7 @@ def check_any_text_region_in_model_one_is_main_or_header_light(
(regions_model_full[:,:,0]==2)).sum()
pixels_main = all_pixels - pixels_header
if (pixels_header/float(pixels_main)>=0.3) and ( (length_con[ii]/float(height_con[ii]) )>=1.3 ):
if ( (pixels_header/float(pixels_main)>=0.6) and ( (length_con[ii]/float(height_con[ii]) )>=1.3 ) and ( (length_con[ii]/float(height_con[ii]) )<=3 )) or ( (pixels_header/float(pixels_main)>=0.3) and ( (length_con[ii]/float(height_con[ii]) )>=3 ) ):
regions_model_1[:,:][(regions_model_1[:,:]==1) & (img[:,:,0]==255) ]=2
contours_only_text_parent_head.append(con)
if contours_only_text_parent_d_ordered is not None:
@@ -1801,8 +1801,8 @@ def return_boxes_of_images_by_order_of_reading_new(
#print(y_type_2_up,x_starting_up,x_ending_up,'didid')
nodes_in = []
for ij in range(len(x_starting_up)):
nodes_in = nodes_in + list(range(x_starting_up[ij],
x_ending_up[ij]))
nodes_in = nodes_in + list(range(int(x_starting_up[ij]),
int(x_ending_up[ij])))
nodes_in = np.unique(nodes_in)
#print(nodes_in,'nodes_in')
@@ -1825,8 +1825,8 @@ def return_boxes_of_images_by_order_of_reading_new(
elif len(y_diff_main_separator_up)==0:
nodes_in = []
for ij in range(len(x_starting_up)):
nodes_in = nodes_in + list(range(x_starting_up[ij],
x_ending_up[ij]))
nodes_in = nodes_in + list(range(int(x_starting_up[ij]),
int(x_ending_up[ij])))
nodes_in = np.unique(nodes_in)
#print(nodes_in,'nodes_in2')
#print(np.array(range(len(peaks_neg_tot)-1)),'np.array(range(len(peaks_neg_tot)-1))')
@@ -1866,8 +1866,8 @@ def return_boxes_of_images_by_order_of_reading_new(
columns_covered_by_mothers = []
for dj in range(len(x_start_without_mother)):
columns_covered_by_mothers = columns_covered_by_mothers + \
list(range(x_start_without_mother[dj],
x_end_without_mother[dj]))
list(range(int(x_start_without_mother[dj]),
int(x_end_without_mother[dj])))
columns_covered_by_mothers = list(set(columns_covered_by_mothers))
all_columns=np.arange(len(peaks_neg_tot)-1)
@@ -1909,8 +1909,8 @@ def return_boxes_of_images_by_order_of_reading_new(
columns_covered_by_mothers = []
for dj in range(len(x_start_without_mother)):
columns_covered_by_mothers = columns_covered_by_mothers + \
list(range(x_start_without_mother[dj],
x_end_without_mother[dj]))
list(range(int(x_start_without_mother[dj]),
int(x_end_without_mother[dj])))
columns_covered_by_mothers = list(set(columns_covered_by_mothers))
all_columns=np.arange(len(peaks_neg_tot)-1)
@@ -1926,8 +1926,8 @@ def return_boxes_of_images_by_order_of_reading_new(
columns_covered_by_with_child_no_mothers = []
for dj in range(len(x_end_with_child_without_mother)):
columns_covered_by_with_child_no_mothers = columns_covered_by_with_child_no_mothers + \
list(range(x_start_with_child_without_mother[dj],
x_end_with_child_without_mother[dj]))
list(range(int(x_start_with_child_without_mother[dj]),
int(x_end_with_child_without_mother[dj])))
columns_covered_by_with_child_no_mothers = list(set(columns_covered_by_with_child_no_mothers))
all_columns = np.arange(len(peaks_neg_tot)-1)
@@ -1970,8 +1970,8 @@ def return_boxes_of_images_by_order_of_reading_new(
columns_covered_by_mothers = []
for dj in range(len(x_starting_all_between_nm_wc)):
columns_covered_by_mothers = columns_covered_by_mothers + \
list(range(x_starting_all_between_nm_wc[dj],
x_ending_all_between_nm_wc[dj]))
list(range(int(x_starting_all_between_nm_wc[dj]),
int(x_ending_all_between_nm_wc[dj])))
columns_covered_by_mothers = list(set(columns_covered_by_mothers))
all_columns=np.arange(i_s_nc, x_end_biggest_column)
@@ -1979,8 +1979,8 @@ def return_boxes_of_images_by_order_of_reading_new(
should_longest_line_be_extended=0
if (len(x_diff_all_between_nm_wc) > 0 and
set(list(range(x_starting_all_between_nm_wc[biggest],
x_ending_all_between_nm_wc[biggest])) +
set(list(range(int(x_starting_all_between_nm_wc[biggest]),
int(x_ending_all_between_nm_wc[biggest]))) +
list(columns_not_covered)) != set(all_columns)):
should_longest_line_be_extended=1
index_lines_so_close_to_top_separator = \
@@ -2012,7 +2012,7 @@ def return_boxes_of_images_by_order_of_reading_new(
x_ending_all_between_nm_wc = np.append(x_ending_all_between_nm_wc, np.array(columns_not_covered) + 1)
ind_args_between=np.arange(len(x_ending_all_between_nm_wc))
for column in range(i_s_nc, x_end_biggest_column):
for column in range(int(i_s_nc), int(x_end_biggest_column)):
ind_args_in_col=ind_args_between[x_starting_all_between_nm_wc==column]
#print('babali2')
#print(ind_args_in_col,'ind_args_in_col')
@@ -2064,7 +2064,7 @@ def return_boxes_of_images_by_order_of_reading_new(
x_end_itself=x_end_copy.pop(il)
#print(y_copy,'y_copy2')
for column in range(x_start_itself, x_end_itself+1):
for column in range(int(x_start_itself), int(x_end_itself)+1):
#print(column,'cols')
y_in_cols=[]
for yic in range(len(y_copy)):
@@ -2095,11 +2095,11 @@ def return_boxes_of_images_by_order_of_reading_new(
all_columns = np.arange(len(peaks_neg_tot)-1)
columns_covered_by_lines_covered_more_than_2col = []
for dj in range(len(x_starting)):
if set(list(range(x_starting[dj],x_ending[dj]))) == set(all_columns):
if set(list(range(int(x_starting[dj]),int(x_ending[dj]) ))) == set(all_columns):
pass
else:
columns_covered_by_lines_covered_more_than_2col = columns_covered_by_lines_covered_more_than_2col + \
list(range(x_starting[dj],x_ending[dj]))
list(range(int(x_starting[dj]),int(x_ending[dj]) ))
columns_covered_by_lines_covered_more_than_2col = list(set(columns_covered_by_lines_covered_more_than_2col))
columns_not_covered = list(set(all_columns) - set(columns_covered_by_lines_covered_more_than_2col))
@@ -2124,7 +2124,7 @@ def return_boxes_of_images_by_order_of_reading_new(
x_ending = np.append(x_ending, np.array(columns_not_covered) + 1)
ind_args=np.array(range(len(y_type_2)))
#ind_args=np.array(ind_args)
for column in range(len(peaks_neg_tot)-1):
#print(column,'column')
ind_args_in_col=ind_args[x_starting==column]
@@ -2155,8 +2155,8 @@ def return_boxes_of_images_by_order_of_reading_new(
x_start_itself=x_start_copy.pop(il)
x_end_itself=x_end_copy.pop(il)
#print(y_copy,'y_copy2')
for column in range(x_start_itself, x_end_itself+1):
for column in range(int(x_start_itself), int(x_end_itself)+1):
#print(column,'cols')
y_in_cols=[]
for yic in range(len(y_copy)):


@@ -10,7 +10,6 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
mask_marginals=np.zeros((text_with_lines.shape[0],text_with_lines.shape[1]))
mask_marginals=mask_marginals.astype(np.uint8)
text_with_lines=text_with_lines.astype(np.uint8)
##text_with_lines=cv2.erode(text_with_lines,self.kernel,iterations=3)
@ -26,8 +25,12 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
text_with_lines=resize_image(text_with_lines,int(text_with_lines.shape[0]*1.8),text_with_lines.shape[1])
text_with_lines=cv2.erode(text_with_lines,kernel,iterations=7)
text_with_lines=resize_image(text_with_lines,text_with_lines_eroded.shape[0],text_with_lines_eroded.shape[1])
if light_version:
kernel_hor = np.ones((1, 5), dtype=np.uint8)
text_with_lines = cv2.erode(text_with_lines,kernel_hor,iterations=6)
text_with_lines_y=text_with_lines.sum(axis=0)
text_with_lines_y_eroded=text_with_lines_eroded.sum(axis=0)
@ -40,8 +43,10 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
elif thickness_along_y_percent>=30 and thickness_along_y_percent<50:
min_textline_thickness=20
else:
min_textline_thickness=40
if light_version:
min_textline_thickness=45
else:
min_textline_thickness=40
if thickness_along_y_percent>=14:


@@ -5,6 +5,8 @@ import numpy as np
import cv2
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
from multiprocessing import Process, Queue, cpu_count
from multiprocessing import Pool
from .rotate import rotate_image
from .resize import resize_image
from .contour import (
@@ -1466,7 +1468,7 @@ def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100,
main_page=False, logger=None, plotter=None, map=map):
if main_page and plotter:
plotter.save_plot_of_textline_density(img_patch_org)
img_int=np.zeros((img_patch_org.shape[0],img_patch_org.shape[1]))
img_int[:,:]=img_patch_org[:,:]#img_patch_org[:,:,0]
@@ -1487,7 +1489,7 @@ def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100,
angles = np.linspace(angle - 22.5, angle + 22.5, n_tot_angles)
angle = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
elif main_page:
angles = np.linspace(-12, 12, n_tot_angles)#np.array([0 , 45 , 90 , -45])
angles = np.array (list(np.linspace(-12, -7, int(n_tot_angles/4))) + list(np.linspace(-6, 6, n_tot_angles- 2* int(n_tot_angles/4))) + list(np.linspace(7, 12, int(n_tot_angles/4))))#np.linspace(-12, 12, n_tot_angles)#np.array([0 , 45 , 90 , -45])
angle = get_smallest_skew(img_resized, sigma_des, angles, map=map, logger=logger, plotter=plotter)
early_slope_edge=11
@@ -1526,6 +1528,107 @@ def get_smallest_skew(img, sigma_des, angles, logger=None, plotter=None, map=map
angle = 0
return angle
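The new main-page angle sampling above is piecewise: half of the budget goes to the central [-6, 6] degree range and a quarter each to the coarse tails. For n_tot_angles = 100 that composes as 25 + 50 + 25 candidates:

import numpy as np

n = 100
angles = np.concatenate([
    np.linspace(-12, -7, n // 4),           # coarse left tail
    np.linspace(-6, 6, n - 2 * (n // 4)),   # dense centre, where most pages fall
    np.linspace(7, 12, n // 4),             # coarse right tail
])
assert len(angles) == n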
def return_deskew_slop_old_mp(img_patch_org, sigma_des,n_tot_angles=100,
main_page=False, logger=None, plotter=None):
if main_page and plotter:
plotter.save_plot_of_textline_density(img_patch_org)
img_int=np.zeros((img_patch_org.shape[0],img_patch_org.shape[1]))
img_int[:,:]=img_patch_org[:,:]#img_patch_org[:,:,0]
max_shape=np.max(img_int.shape)
img_resized=np.zeros((int( max_shape*(1.1) ) , int( max_shape*(1.1) ) ))
onset_x=int((img_resized.shape[1]-img_int.shape[1])/2.)
onset_y=int((img_resized.shape[0]-img_int.shape[0])/2.)
img_resized[ onset_y:onset_y+img_int.shape[0] , onset_x:onset_x+img_int.shape[1] ]=img_int[:,:]
if main_page and img_patch_org.shape[1] > img_patch_org.shape[0]:
angles = np.array([-45, 0, 45, 90,])
angle = get_smallest_skew_omp(img_resized, sigma_des, angles, plotter=plotter)
angles = np.linspace(angle - 22.5, angle + 22.5, n_tot_angles)
angle = get_smallest_skew_omp(img_resized, sigma_des, angles, plotter=plotter)
elif main_page:
angles = np.linspace(-12, 12, n_tot_angles)#np.array([0 , 45 , 90 , -45])
angle = get_smallest_skew_omp(img_resized, sigma_des, angles, plotter=plotter)
early_slope_edge=11
if abs(angle) > early_slope_edge:
if angle < 0:
angles = np.linspace(-90, -12, n_tot_angles)
else:
angles = np.linspace(90, 12, n_tot_angles)
angle = get_smallest_skew_omp(img_resized, sigma_des, angles, plotter=plotter)
else:
angles = np.linspace(-25, 25, int(0.5 * n_tot_angles) + 10)
angle = get_smallest_skew_omp(img_resized, sigma_des, angles, plotter=plotter)
early_slope_edge=22
if abs(angle) > early_slope_edge:
if angle < 0:
angles = np.linspace(-90, -25, int(0.5 * n_tot_angles) + 10)
else:
angles = np.linspace(90, 25, int(0.5 * n_tot_angles) + 10)
angle = get_smallest_skew_omp(img_resized, sigma_des, angles, plotter=plotter)
return angle
def do_image_rotation_omp(queue_of_all_params,angles_per_process, img_resized, sigma_des):
vars_per_each_subprocess = []
angles_per_each_subprocess = []
for mv in range(len(angles_per_process)):
img_rot=rotate_image(img_resized,angles_per_process[mv])
img_rot[img_rot!=0]=1
try:
var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
except:
var_spectrum=0
vars_per_each_subprocess.append(var_spectrum)
angles_per_each_subprocess.append(angles_per_process[mv])
queue_of_all_params.put([vars_per_each_subprocess, angles_per_each_subprocess])
def get_smallest_skew_omp(img_resized, sigma_des, angles, plotter=None):
num_cores = cpu_count()
queue_of_all_params = Queue()
processes = []
nh = np.linspace(0, len(angles), num_cores + 1)
for i in range(num_cores):
angles_per_process = angles[int(nh[i]) : int(nh[i + 1])]
processes.append(Process(target=do_image_rotation_omp, args=(queue_of_all_params, angles_per_process, img_resized, sigma_des)))
for i in range(num_cores):
processes[i].start()
var_res=[]
all_angles = []
for i in range(num_cores):
list_all_par = queue_of_all_params.get(True)
vars_for_subprocess = list_all_par[0]
angles_sub_process = list_all_par[1]
for j in range(len(vars_for_subprocess)):
var_res.append(vars_for_subprocess[j])
all_angles.append(angles_sub_process[j])
for i in range(num_cores):
processes[i].join()
if plotter:
plotter.save_plot_of_rotation_angle(all_angles, var_res)
try:
var_res=np.array(var_res)
ang_int=all_angles[np.argmax(var_res)]#angels_sorted[arg_final]#angels[arg_sort_early[arg_sort[arg_final]]]#angels[arg_fin]
except:
ang_int=0
return ang_int
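This legacy variant fans the angle sweep out over explicit Process/Queue plumbing. Functionally, the same search could be written with a multiprocessing Pool; a sketch only, not what the module uses, relying on rotate_image and find_num_col_deskew from this file:

import numpy as np
from multiprocessing import Pool, cpu_count

def score_angle(args):
    img_resized, sigma_des, angle = args
    img_rot = rotate_image(img_resized, angle)
    img_rot[img_rot != 0] = 1
    try:
        return find_num_col_deskew(img_rot, sigma_des, 20.3)
    except Exception:
        return 0

def get_smallest_skew_pool(img_resized, sigma_des, angles):
    with Pool(cpu_count()) as pool:
        scores = pool.map(score_angle, [(img_resized, sigma_des, a) for a in angles])
    return angles[int(np.argmax(scores))] if len(scores) else 0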
def do_work_of_slopes_new(
box_text, contour, contour_par, index_r_con,
textline_mask_tot_ea, image_page_rotated, slope_deskew,


@@ -0,0 +1,488 @@
import numpy as np
import cv2
import tensorflow as tf
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import math
import copy  # used by return_rnn_cnn_ocr_of_given_textlines
from PIL import Image, ImageDraw, ImageFont
from Bio import pairwise2
from .resize import resize_image
def decode_batch_predictions(pred, num_to_char, max_len = 128):
# input_len is the product of the batch size and the
# number of time steps.
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Decode CTC predictions using greedy search.
# decoded is a tuple with 2 elements.
decoded = tf.keras.backend.ctc_decode(pred,
input_length = input_len,
beam_width = 100)
# The outputs are in the first element of the tuple.
# Additionally, the first element is actually a list,
# therefore we take the first element of that list as well.
#print(decoded,'decoded')
decoded = decoded[0][0][:, :max_len]
#print(decoded, decoded.shape,'decoded')
output = []
for d in decoded:
# Convert the predicted indices to the corresponding chars.
d = tf.strings.reduce_join(num_to_char(d))
d = d.numpy().decode("utf-8")
output.append(d)
return output
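A minimal usage sketch for the decoder above, assuming a CTC model output of shape (batch, time, vocab) and an inverted StringLookup as num_to_char; with random inputs the decoded strings are meaningless and may contain [UNK] padding:

import numpy as np
import tensorflow as tf

vocab = list("abcdefghijklmnopqrstuvwxyz ")
num_to_char = tf.keras.layers.StringLookup(vocabulary=vocab, invert=True, mask_token=None)

pred = np.random.rand(2, 40, len(vocab) + 2).astype("float32")   # dummy CTC posteriors
print(decode_batch_predictions(pred, num_to_char))               # two decoded strings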
def distortion_free_resize(image, img_size):
w, h = img_size
image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)
# Check the amount of padding needed.
pad_height = h - tf.shape(image)[0]
pad_width = w - tf.shape(image)[1]
# Only necessary if you want to do the same amount of padding on both sides.
if pad_height % 2 != 0:
height = pad_height // 2
pad_height_top = height + 1
pad_height_bottom = height
else:
pad_height_top = pad_height_bottom = pad_height // 2
if pad_width % 2 != 0:
width = pad_width // 2
pad_width_left = width + 1
pad_width_right = width
else:
pad_width_left = pad_width_right = pad_width // 2
image = tf.pad(
image,
paddings=[
[pad_height_top, pad_height_bottom],
[pad_width_left, pad_width_right],
[0, 0],
],
)
image = tf.transpose(image, (1, 0, 2))
image = tf.image.flip_left_right(image)
return image
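The helper above resizes with the aspect ratio preserved, pads to the target (w, h), then transposes and flips so the width axis becomes the time axis for the CTC model. A quick shape check with illustrative sizes:

import tensorflow as tf

line_img = tf.zeros((48, 300, 3))                  # h x w x c input
out = distortion_free_resize(line_img, (512, 32))  # img_size is (w, h)
print(out.shape)                                   # (512, 32, 3): width-major for CTC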
def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image):
width = np.shape(textline_image)[1]
height = np.shape(textline_image)[0]
common_window = int(0.06*width)
width1 = int ( width/2. - common_window )
width2 = int ( width/2. + common_window )
img_sum = np.sum(textline_image[:,:,0], axis=0)
sum_smoothed = gaussian_filter1d(img_sum, 3)
peaks_real, _ = find_peaks(sum_smoothed, height=0)
if len(peaks_real)>70:
peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
arg_max = np.argmax(sum_smoothed[peaks_real])
peaks_final = peaks_real[arg_max]
return peaks_final
else:
return None
# Function to fit text inside the given area
def fit_text_single_line(draw, text, font_path, max_width, max_height):
initial_font_size = 50
font_size = initial_font_size
while font_size > 10: # Minimum font size
font = ImageFont.truetype(font_path, font_size)
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
if text_width <= max_width and text_height <= max_height:
return font # Return the best-fitting font
font_size -= 2 # Reduce font size and retry
return ImageFont.truetype(font_path, 10) # Smallest font fallback
def return_textlines_split_if_needed(textline_image, textline_image_bin, prediction_with_both_of_rgb_and_bin=False):
split_point = return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image)
if split_point:
image1 = textline_image[:, :split_point,:]# image.crop((0, 0, width2, height))
image2 = textline_image[:, split_point:,:]#image.crop((width1, 0, width, height))
if prediction_with_both_of_rgb_and_bin:
image1_bin = textline_image_bin[:, :split_point,:]# image.crop((0, 0, width2, height))
image2_bin = textline_image_bin[:, split_point:,:]#image.crop((width1, 0, width, height))
return [image1, image2], [image1_bin, image2_bin]
else:
return [image1, image2], None
else:
return None, None
def preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width):
if img.shape[0]==0 or img.shape[1]==0:
img_fin = np.ones((image_height, image_width, 3))
else:
ratio = image_height /float(img.shape[0])
w_ratio = int(ratio * img.shape[1])
if w_ratio <= image_width:
width_new = w_ratio
else:
width_new = image_width
if width_new == 0:
width_new = img.shape[1]
img = resize_image(img, image_height, width_new)
img_fin = np.ones((image_height, image_width, 3))*255
img_fin[:,:width_new,:] = img[:,:,:]
img_fin = img_fin / 255.
return img_fin
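In effect, each text line is height-normalized to the recognizer's input height and right-padded with white up to the fixed width; for example (illustrative sizes):

import numpy as np

line = np.random.randint(0, 255, (48, 300, 3)).astype(np.uint8)
ready = preprocess_and_resize_image_for_ocrcnn_model(line, 32, 512)
print(ready.shape)   # (32, 512, 3), scaled to [0, 1]
# the glyphs occupy the first 200 columns (300 * 32/48); the rest stays white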
def get_deskewed_contour_and_bb_and_image(contour, image, deskew_angle):
(h_in, w_in) = image.shape[:2]
center = (w_in // 2, h_in // 2)
rotation_matrix = cv2.getRotationMatrix2D(center, deskew_angle, 1.0)
cos_angle = abs(rotation_matrix[0, 0])
sin_angle = abs(rotation_matrix[0, 1])
new_w = int((h_in * sin_angle) + (w_in * cos_angle))
new_h = int((h_in * cos_angle) + (w_in * sin_angle))
rotation_matrix[0, 2] += (new_w / 2) - center[0]
rotation_matrix[1, 2] += (new_h / 2) - center[1]
deskewed_image = cv2.warpAffine(image, rotation_matrix, (new_w, new_h))
contour_points = np.array(contour, dtype=np.float32)
transformed_points = cv2.transform(np.array([contour_points]), rotation_matrix)[0]
x, y, w, h = cv2.boundingRect(np.array(transformed_points, dtype=np.int32))
cropped_textline = deskewed_image[y:y+h, x:x+w]
return cropped_textline
def rotate_image_with_padding(image, angle, border_value=(0,0,0)):
# Get image dimensions
(h, w) = image.shape[:2]
# Calculate the center of the image
center = (w // 2, h // 2)
# Get the rotation matrix
rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
# Compute the new bounding dimensions
cos = abs(rotation_matrix[0, 0])
sin = abs(rotation_matrix[0, 1])
new_w = int((h * sin) + (w * cos))
new_h = int((h * cos) + (w * sin))
# Adjust the rotation matrix to account for translation
rotation_matrix[0, 2] += (new_w / 2) - center[0]
rotation_matrix[1, 2] += (new_h / 2) - center[1]
# Perform the rotation
try:
rotated_image = cv2.warpAffine(image, rotation_matrix, (new_w, new_h), borderValue=border_value)
except:
rotated_image = np.copy(image)
return rotated_image
def get_orientation_moments(contour):
moments = cv2.moments(contour)
if moments["mu20"] - moments["mu02"] == 0: # Avoid division by zero
return 90 if moments["mu11"] > 0 else -90
else:
angle = 0.5 * np.arctan2(2 * moments["mu11"], moments["mu20"] - moments["mu02"])
return np.degrees(angle) # Convert radians to degrees
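A quick sanity check for the moment-based orientation, using a synthetic rectangle drawn at a known angle (OpenCV's rotated-rectangle angle is clockwise on screen, which matches the image-coordinate angle this function returns):

import cv2
import numpy as np

canvas = np.zeros((200, 200), np.uint8)
box = cv2.boxPoints(((100, 100), (120, 20), 30)).astype(np.int32)   # 30 degree rectangle
cv2.fillPoly(canvas, [box], 255)
contours, _ = cv2.findContours(canvas, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contour = max(contours, key=cv2.contourArea)
print(get_orientation_moments(contour))   # approximately 30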
def get_orientation_moments_of_mask(mask):
mask=mask.astype('uint8')
contours, _ = cv2.findContours(mask[:,:,0], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
largest_contour = max(contours, key=cv2.contourArea) if contours else None
moments = cv2.moments(largest_contour)
if moments["mu20"] - moments["mu02"] == 0: # Avoid division by zero
return 90 if moments["mu11"] > 0 else -90
else:
angle = 0.5 * np.arctan2(2 * moments["mu11"], moments["mu20"] - moments["mu02"])
return np.degrees(angle) # Convert radians to degrees
def get_contours_and_bounding_boxes(mask):
# Find contours in the binary mask
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
largest_contour = max(contours, key=cv2.contourArea) if contours else None
# Get the bounding rectangle for the contour
x, y, w, h = cv2.boundingRect(largest_contour)
#bounding_boxes.append((x, y, w, h))
return x, y, w, h
def return_splitting_point_of_image(image_to_split):
width = np.shape(image_to_split)[1]
height = np.shape(image_to_split)[0]
common_window = int(0.03*width)
width1 = int ( common_window)
width2 = int ( width - common_window )
img_sum = np.sum(image_to_split[:,:,0], axis=0)
sum_smoothed = gaussian_filter1d(img_sum, 1)
peaks_real, _ = find_peaks(sum_smoothed, height=0)
peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
arg_sort = np.argsort(sum_smoothed[peaks_real])
peaks_sort_4 = peaks_real[arg_sort][::-1][:3]
return np.sort(peaks_sort_4)
def break_curved_line_into_small_pieces_and_then_merge(img_curved, mask_curved, img_bin_curved=None):
    # Split a curved textline at its strongest vertical gaps, deskew each piece
    # by its moment-based orientation, and paste the pieces back into a single
    # 32px-high strip.
    peaks_4 = return_splitting_point_of_image(img_curved)
    if len(peaks_4) > 0:
        imgs_tot = []
        for ind in range(len(peaks_4) + 1):
            if ind == 0:
                img = img_curved[:, :peaks_4[ind], :]
                if img_bin_curved is not None:
                    img_bin = img_bin_curved[:, :peaks_4[ind], :]
                mask = mask_curved[:, :peaks_4[ind], :]
            elif ind == len(peaks_4):
                img = img_curved[:, peaks_4[ind - 1]:, :]
                if img_bin_curved is not None:
                    img_bin = img_bin_curved[:, peaks_4[ind - 1]:, :]
                mask = mask_curved[:, peaks_4[ind - 1]:, :]
            else:
                img = img_curved[:, peaks_4[ind - 1]:peaks_4[ind], :]
                if img_bin_curved is not None:
                    img_bin = img_bin_curved[:, peaks_4[ind - 1]:peaks_4[ind], :]
                mask = mask_curved[:, peaks_4[ind - 1]:peaks_4[ind], :]

            or_ma = get_orientation_moments_of_mask(mask)
            if img_bin_curved is not None:
                imgs_tot.append([img, mask, or_ma, img_bin])
            else:
                imgs_tot.append([img, mask, or_ma])

        w_tot_des_list = []
        w_tot_des = 0
        imgs_deskewed_list = []
        imgs_bin_deskewed_list = []

        for ind in range(len(imgs_tot)):
            img_in = imgs_tot[ind][0]
            mask_in = imgs_tot[ind][1]
            ori_in = imgs_tot[ind][2]
            if img_bin_curved is not None:
                img_bin_in = imgs_tot[ind][3]

            if abs(ori_in) < 45:
                # Deskew the piece; pad the rotation with white for the image(s)
                # and with the default border for the mask.
                img_in_des = rotate_image_with_padding(img_in, ori_in, border_value=(255, 255, 255))
                if img_bin_curved is not None:
                    img_bin_in_des = rotate_image_with_padding(img_bin_in, ori_in, border_value=(255, 255, 255))
                mask_in_des = rotate_image_with_padding(mask_in, ori_in)
                mask_in_des = mask_in_des.astype('uint8')

                # New bounding box of the rotated mask.
                x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_in_des[:, :, 0])

                if w_n == 0 or h_n == 0:
                    # Rotation lost the mask; fall back to the unrotated piece.
                    img_in_des = np.copy(img_in)
                    if img_bin_curved is not None:
                        img_bin_in_des = np.copy(img_bin_in)
                    w_relative = int(32 * img_in_des.shape[1] / float(img_in_des.shape[0]))
                    if w_relative == 0:
                        w_relative = img_in_des.shape[1]
                    img_in_des = resize_image(img_in_des, 32, w_relative)
                    if img_bin_curved is not None:
                        img_bin_in_des = resize_image(img_bin_in_des, 32, w_relative)
                else:
                    mask_in_des = mask_in_des[y_n:y_n + h_n, x_n:x_n + w_n, :]
                    img_in_des = img_in_des[y_n:y_n + h_n, x_n:x_n + w_n, :]
                    if img_bin_curved is not None:
                        img_bin_in_des = img_bin_in_des[y_n:y_n + h_n, x_n:x_n + w_n, :]
                    w_relative = int(32 * img_in_des.shape[1] / float(img_in_des.shape[0]))
                    if w_relative == 0:
                        w_relative = img_in_des.shape[1]
                    img_in_des = resize_image(img_in_des, 32, w_relative)
                    if img_bin_curved is not None:
                        img_bin_in_des = resize_image(img_bin_in_des, 32, w_relative)
            else:
                # Orientation too steep to be a reliable deskew angle; keep the piece as is.
                img_in_des = np.copy(img_in)
                if img_bin_curved is not None:
                    img_bin_in_des = np.copy(img_bin_in)
                w_relative = int(32 * img_in_des.shape[1] / float(img_in_des.shape[0]))
                if w_relative == 0:
                    w_relative = img_in_des.shape[1]
                img_in_des = resize_image(img_in_des, 32, w_relative)
                if img_bin_curved is not None:
                    img_bin_in_des = resize_image(img_bin_in_des, 32, w_relative)

            w_tot_des += img_in_des.shape[1]
            w_tot_des_list.append(img_in_des.shape[1])
            imgs_deskewed_list.append(img_in_des)
            if img_bin_curved is not None:
                imgs_bin_deskewed_list.append(img_bin_in_des)

        # Paste the deskewed pieces side by side on a white canvas.
        img_final_deskewed = np.zeros((32, w_tot_des, 3)) + 255
        if img_bin_curved is not None:
            img_bin_final_deskewed = np.zeros((32, w_tot_des, 3)) + 255
        else:
            img_bin_final_deskewed = None

        w_indexer = 0
        for ind in range(len(w_tot_des_list)):
            img_final_deskewed[:, w_indexer:w_indexer + w_tot_des_list[ind], :] = imgs_deskewed_list[ind][:, :, :]
            if img_bin_curved is not None:
                img_bin_final_deskewed[:, w_indexer:w_indexer + w_tot_des_list[ind], :] = imgs_bin_deskewed_list[ind][:, :, :]
            w_indexer = w_indexer + w_tot_des_list[ind]
        return img_final_deskewed, img_bin_final_deskewed
    else:
        return img_curved, img_bin_curved
def return_textline_contour_with_added_box_coordinate(textline_contour, box_ind):
    # Shift a box-local contour into page coordinates; box_ind is ordered
    # (y0, y1, x0, x1), so index 2 is the x offset and index 0 the y offset.
    textline_contour[:, 0] = textline_contour[:, 0] + box_ind[2]
    textline_contour[:, 1] = textline_contour[:, 1] + box_ind[0]
    return textline_contour
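# Illustration only (hypothetical helper, not part of the module):
def _demo_box_coordinate_shift():
    contour = np.array([[5, 7], [10, 20]])
    box_ind = [100, 150, 40, 90]  # assumed (y0, y1, x0, x1)
    print(return_textline_contour_with_added_box_coordinate(contour, box_ind))
    # x += 40, y += 100 -> [[ 45 107]
    #                       [ 50 120]]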
def return_rnn_cnn_ocr_of_given_textlines(image, all_found_textline_polygons, all_box_coord, prediction_model, b_s_ocr, num_to_char, textline_light=False, curved_line=False):
    # Note: all_box_coord has been added to the signature; the body references
    # it whenever neither textline_light nor curved_line is set, so the
    # original signature raised a NameError on that path.
    max_len = 512
    padding_token = 299
    image_width = 512  # max_len * 4
    image_height = 32
    ind_tot = 0
    ocr_all_textlines = []
    cropped_lines_region_indexer = []
    cropped_lines_merging_indexing = []
    cropped_lines = []
    indexer_text_region = 0

    for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
        if len(ind_poly_first) == 0:
            # Region without textlines: keep a blank (all-ones) placeholder so
            # the region indexing stays aligned with the predictions.
            cropped_lines_region_indexer.append(indexer_text_region)
            cropped_lines_merging_indexing.append(0)
            img_fin = np.ones((image_height, image_width, 3))
            cropped_lines.append(img_fin)
        else:
            for indexing2, ind_poly in enumerate(ind_poly_first):
                cropped_lines_region_indexer.append(indexer_text_region)
                if not (textline_light or curved_line):
                    ind_poly = copy.deepcopy(ind_poly)
                    box_ind = all_box_coord[indexing]
                    ind_poly = return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
                ind_poly[ind_poly < 0] = 0
                x, y, w, h = cv2.boundingRect(ind_poly)
                w_scaled = w * image_height / float(h)

                # Crop the line and blank out everything outside its polygon.
                mask_poly = np.zeros(image.shape)
                img_poly_on_img = np.copy(image)
                mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))
                mask_poly = mask_poly[y:y + h, x:x + w, :]
                img_crop = img_poly_on_img[y:y + h, x:x + w, :]
                img_crop[mask_poly == 0] = 255

                if w_scaled < 640:  # 1.5 * image_width
                    img_fin = preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width)
                    cropped_lines.append(img_fin)
                    cropped_lines_merging_indexing.append(0)
                else:
                    # Too wide after scaling: split into two halves that are
                    # re-joined after prediction (1 marks the left half, -1 the right half).
                    splited_images, splited_images_bin = return_textlines_split_if_needed(img_crop, None)
                    if splited_images:
                        img_fin = preprocess_and_resize_image_for_ocrcnn_model(splited_images[0], image_height, image_width)
                        cropped_lines.append(img_fin)
                        cropped_lines_merging_indexing.append(1)
                        img_fin = preprocess_and_resize_image_for_ocrcnn_model(splited_images[1], image_height, image_width)
                        cropped_lines.append(img_fin)
                        cropped_lines_merging_indexing.append(-1)
                    else:
                        img_fin = preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width)
                        cropped_lines.append(img_fin)
                        cropped_lines_merging_indexing.append(0)
        indexer_text_region += 1

    extracted_texts = []
    n_iterations = math.ceil(len(cropped_lines) / b_s_ocr)

    for i in range(n_iterations):
        if i == (n_iterations - 1):
            # Last batch may be smaller than b_s_ocr.
            n_start = i * b_s_ocr
            imgs = cropped_lines[n_start:]
            imgs = np.array(imgs)
            imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
        else:
            n_start = i * b_s_ocr
            n_end = (i + 1) * b_s_ocr
            imgs = cropped_lines[n_start:n_end]
            imgs = np.array(imgs).reshape(b_s_ocr, image_height, image_width, 3)

        preds = prediction_model.predict(imgs, verbose=0)
        pred_texts = decode_batch_predictions(preds, num_to_char)

        for ib in range(imgs.shape[0]):
            pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
            extracted_texts.append(pred_texts_ib)

    # Re-join the halves of split lines and drop the now-redundant right halves.
    extracted_texts_merged = [extracted_texts[ind] if cropped_lines_merging_indexing[ind] == 0
                              else extracted_texts[ind] + " " + extracted_texts[ind + 1] if cropped_lines_merging_indexing[ind] == 1
                              else None
                              for ind in range(len(cropped_lines_merging_indexing))]
    extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]

    unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
    ocr_all_textlines = []
    for ind in unique_cropped_lines_region_indexer:
        ocr_textline_in_textregion = []
        extracted_texts_merged_un = np.array(extracted_texts_merged)[np.array(cropped_lines_region_indexer) == ind]
        for it_ind, text_textline in enumerate(extracted_texts_merged_un):
            ocr_textline_in_textregion.append(text_textline)
        ocr_all_textlines.append(ocr_textline_in_textregion)
    return ocr_all_textlines
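# Illustration only (hypothetical helper, not part of the module): how the
# merging indexer above recombines split halves after prediction.
def _demo_merge_indexing():
    extracted_texts = ["foo", "bar left-half", "right-half baz"]
    merging = [0, 1, -1]
    merged = [extracted_texts[ind] if merging[ind] == 0
              else extracted_texts[ind] + " " + extracted_texts[ind + 1] if merging[ind] == 1
              else None
              for ind in range(len(merging))]
    print([t for t in merged if t is not None])
    # -> ['foo', 'bar left-half right-half baz']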
def biopython_align(str1, str2):
    # Global alignment via Bio.pairwise2 (match=2, mismatch=-1, gap open/extend=-2);
    # pairwise2 is imported from biopython, which requirements.txt now includes.
    alignments = pairwise2.align.globalms(str1, str2, 2, -1, -2, -2)
    best_alignment = alignments[0]  # highest-scoring alignment comes first
    return best_alignment.seqA, best_alignment.seqB
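# Illustration only (hypothetical helper, not part of the module): aligning
# two OCR variants of one word inserts '-' gap characters where they differ.
def _demo_biopython_align():
    seq_a, seq_b = biopython_align("recognition", "recogniton")
    print(seq_a)  # e.g. 'recognition'
    print(seq_b)  # e.g. 'recognit-on'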

View file

@@ -46,16 +46,22 @@ def create_page_xml(imageFilename, height, width):
     ))
     return pcgts

-def xml_reading_order(page, order_of_texts, id_of_marginalia):
+def xml_reading_order(page, order_of_texts, id_of_marginalia_left, id_of_marginalia_right):
     region_order = ReadingOrderType()
     og = OrderedGroupType(id="ro357564684568544579089")
     page.set_ReadingOrder(region_order)
     region_order.set_OrderedGroup(og)
     region_counter = EynollahIdCounter()
+    for id_marginal in id_of_marginalia_left:
+        og.add_RegionRefIndexed(RegionRefIndexedType(index=str(region_counter.get('region')), regionRef=id_marginal))
+        region_counter.inc('region')
     for idx_textregion, _ in enumerate(order_of_texts):
         og.add_RegionRefIndexed(RegionRefIndexedType(index=str(region_counter.get('region')), regionRef=region_counter.region_id(order_of_texts[idx_textregion] + 1)))
         region_counter.inc('region')
-    for id_marginal in id_of_marginalia:
+    for id_marginal in id_of_marginalia_right:
         og.add_RegionRefIndexed(RegionRefIndexedType(index=str(region_counter.get('region')), regionRef=id_marginal))
         region_counter.inc('region')

View file

@@ -56,10 +56,12 @@ class EynollahXmlWriter():
             points_page_print = points_page_print + ' '
         return points_page_print[:-1]

-    def serialize_lines_in_marginal(self, marginal_region, all_found_textline_polygons_marginals, marginal_idx, page_coord, all_box_coord_marginals, slopes_marginals, counter):
+    def serialize_lines_in_marginal(self, marginal_region, all_found_textline_polygons_marginals, marginal_idx, page_coord, all_box_coord_marginals, slopes_marginals, counter, ocr_all_textlines_textregion):
         for j in range(len(all_found_textline_polygons_marginals[marginal_idx])):
             coords = CoordsType()
             textline = TextLineType(id=counter.next_line_id, Coords=coords)
+            if ocr_all_textlines_textregion:
+                textline.set_TextEquiv( [ TextEquivType(Unicode=ocr_all_textlines_textregion[j]) ] )
             marginal_region.add_TextLine(textline)
             marginal_region.set_orientation(-slopes_marginals[marginal_idx])
             points_co = ''
@@ -119,7 +121,7 @@ class EynollahXmlWriter():
                 points_co += ','
                 points_co += str(textline_y_coord)

-            if (self.curved_line or self.textline_light) and np.abs(slopes[region_idx]) <= 45:
+            if self.textline_light or (self.curved_line and np.abs(slopes[region_idx]) <= 45):
                 if len(contour_textline) == 2:
                     points_co += str(int((contour_textline[0] + page_coord[2]) / self.scale_x))
                     points_co += ','
@@ -128,7 +130,7 @@ class EynollahXmlWriter():
                     points_co += str(int((contour_textline[0][0] + page_coord[2]) / self.scale_x))
                     points_co += ','
                     points_co += str(int((contour_textline[0][1] + page_coord[0])/self.scale_y))
-            elif (self.curved_line or self.textline_light) and np.abs(slopes[region_idx]) > 45:
+            elif self.curved_line and np.abs(slopes[region_idx]) > 45:
                 if len(contour_textline)==2:
                     points_co += str(int((contour_textline[0] + region_bboxes[2] + page_coord[2])/self.scale_x))
                     points_co += ','
@@ -168,7 +170,7 @@ class EynollahXmlWriter():
         with open(self.output_filename, 'w') as f:
             f.write(to_xml(pcgts))

-    def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines, conf_contours_textregion):
+    def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals_left, found_polygons_marginals_right, all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, all_box_coord_marginals_left, all_box_coord_marginals_right, slopes, slopes_marginals_left, slopes_marginals_right, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines=None, ocr_all_textlines_marginals_left=None, ocr_all_textlines_marginals_right=None, conf_contours_textregion=None, skip_layout_reading_order=False):
         self.logger.debug('enter build_pagexml_no_full_layout')

         # create the file structure
@@ -179,12 +181,13 @@ class EynollahXmlWriter():
         counter = EynollahIdCounter()
         if len(found_polygons_text_region) > 0:
             _counter_marginals = EynollahIdCounter(region_idx=len(order_of_texts))
-            id_of_marginalia = [_counter_marginals.next_region_id for _ in found_polygons_marginals]
-            xml_reading_order(page, order_of_texts, id_of_marginalia)
+            id_of_marginalia_left = [_counter_marginals.next_region_id for _ in found_polygons_marginals_left]
+            id_of_marginalia_right = [_counter_marginals.next_region_id for _ in found_polygons_marginals_right]
+            xml_reading_order(page, order_of_texts, id_of_marginalia_left, id_of_marginalia_right)

         for mm in range(len(found_polygons_text_region)):
             textregion = TextRegionType(id=counter.next_region_id, type_='paragraph',
-                Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord), conf=conf_contours_textregion[mm]),
+                Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord, skip_layout_reading_order), conf=conf_contours_textregion[mm]),
             )
             #textregion.set_conf(conf_contours_textregion[mm])
             page.add_TextRegion(textregion)
@@ -193,12 +196,29 @@ class EynollahXmlWriter():
             else:
                 ocr_textlines = None
             self.serialize_lines_in_region(textregion, all_found_textline_polygons, mm, page_coord, all_box_coord, slopes, counter, ocr_textlines)

-        for mm in range(len(found_polygons_marginals)):
+        for mm in range(len(found_polygons_marginals_left)):
             marginal = TextRegionType(id=counter.next_region_id, type_='marginalia',
-                Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_marginals[mm], page_coord)))
+                Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_marginals_left[mm], page_coord)))
             page.add_TextRegion(marginal)
-            self.serialize_lines_in_marginal(marginal, all_found_textline_polygons_marginals, mm, page_coord, all_box_coord_marginals, slopes_marginals, counter)
+            if ocr_all_textlines_marginals_left:
+                ocr_textlines = ocr_all_textlines_marginals_left[mm]
+            else:
+                ocr_textlines = None
+            #print(ocr_textlines, mm, len(all_found_textline_polygons_marginals_left[mm]) )
+            self.serialize_lines_in_marginal(marginal, all_found_textline_polygons_marginals_left, mm, page_coord, all_box_coord_marginals_left, slopes_marginals_left, counter, ocr_textlines)
+        for mm in range(len(found_polygons_marginals_right)):
+            marginal = TextRegionType(id=counter.next_region_id, type_='marginalia',
+                Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_marginals_right[mm], page_coord)))
+            page.add_TextRegion(marginal)
+            if ocr_all_textlines_marginals_right:
+                ocr_textlines = ocr_all_textlines_marginals_right[mm]
+            else:
+                ocr_textlines = None
+            self.serialize_lines_in_marginal(marginal, all_found_textline_polygons_marginals_right, mm, page_coord, all_box_coord_marginals_right, slopes_marginals_right, counter, ocr_textlines)

         for mm in range(len(found_polygons_text_region_img)):
             img_region = ImageRegionType(id=counter.next_region_id, Coords=CoordsType())
@@ -242,7 +262,7 @@ class EynollahXmlWriter():

         return pcgts

-    def build_pagexml_full_layout(self, found_polygons_text_region, found_polygons_text_region_h, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h, found_polygons_text_region_img, found_polygons_tables, found_polygons_drop_capitals, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_h, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, ocr_all_textlines, conf_contours_textregion, conf_contours_textregion_h):
+    def build_pagexml_full_layout(self, found_polygons_text_region, found_polygons_text_region_h, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h, found_polygons_text_region_img, found_polygons_tables, found_polygons_drop_capitals, found_polygons_marginals_left, found_polygons_marginals_right, all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, all_box_coord_marginals_left, all_box_coord_marginals_right, slopes, slopes_h, slopes_marginals_left, slopes_marginals_right, cont_page, polygons_lines_to_be_written_in_xml, ocr_all_textlines=None, ocr_all_textlines_h=None, ocr_all_textlines_marginals_left=None, ocr_all_textlines_marginals_right=None, ocr_all_textlines_drop=None, conf_contours_textregion=None, conf_contours_textregion_h=None):
         self.logger.debug('enter build_pagexml_full_layout')

         # create the file structure
@@ -252,8 +272,9 @@ class EynollahXmlWriter():
         counter = EynollahIdCounter()

         _counter_marginals = EynollahIdCounter(region_idx=len(order_of_texts))
-        id_of_marginalia = [_counter_marginals.next_region_id for _ in found_polygons_marginals]
-        xml_reading_order(page, order_of_texts, id_of_marginalia)
+        id_of_marginalia_left = [_counter_marginals.next_region_id for _ in found_polygons_marginals_left]
+        id_of_marginalia_right = [_counter_marginals.next_region_id for _ in found_polygons_marginals_right]
+        xml_reading_order(page, order_of_texts, id_of_marginalia_left, id_of_marginalia_right)

         for mm in range(len(found_polygons_text_region)):
             textregion = TextRegionType(id=counter.next_region_id, type_='paragraph',
@@ -272,25 +293,43 @@ class EynollahXmlWriter():
                 Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region_h[mm], page_coord)))
             page.add_TextRegion(textregion)

-            if ocr_all_textlines:
-                ocr_textlines = ocr_all_textlines[mm]
+            if ocr_all_textlines_h:
+                ocr_textlines = ocr_all_textlines_h[mm]
             else:
                 ocr_textlines = None
             self.serialize_lines_in_region(textregion, all_found_textline_polygons_h, mm, page_coord, all_box_coord_h, slopes_h, counter, ocr_textlines)

-        for mm in range(len(found_polygons_marginals)):
+        for mm in range(len(found_polygons_marginals_left)):
             marginal = TextRegionType(id=counter.next_region_id, type_='marginalia',
-                Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_marginals[mm], page_coord)))
+                Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_marginals_left[mm], page_coord)))
             page.add_TextRegion(marginal)
-            self.serialize_lines_in_marginal(marginal, all_found_textline_polygons_marginals, mm, page_coord, all_box_coord_marginals, slopes_marginals, counter)
+            if ocr_all_textlines_marginals_left:
+                ocr_textlines = ocr_all_textlines_marginals_left[mm]
+            else:
+                ocr_textlines = None
+            self.serialize_lines_in_marginal(marginal, all_found_textline_polygons_marginals_left, mm, page_coord, all_box_coord_marginals_left, slopes_marginals_left, counter, ocr_textlines)
+        for mm in range(len(found_polygons_marginals_right)):
+            marginal = TextRegionType(id=counter.next_region_id, type_='marginalia',
+                Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_marginals_right[mm], page_coord)))
+            page.add_TextRegion(marginal)
+            if ocr_all_textlines_marginals_right:
+                ocr_textlines = ocr_all_textlines_marginals_right[mm]
+            else:
+                ocr_textlines = None
+            self.serialize_lines_in_marginal(marginal, all_found_textline_polygons_marginals_right, mm, page_coord, all_box_coord_marginals_right, slopes_marginals_right, counter, ocr_textlines)

         for mm in range(len(found_polygons_drop_capitals)):
             dropcapital = TextRegionType(id=counter.next_region_id, type_='drop-capital',
                 Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_drop_capitals[mm], page_coord)))
             page.add_TextRegion(dropcapital)
-            ###all_box_coord_drop = None
-            ###slopes_drop = None
-            ###self.serialize_lines_in_dropcapital(dropcapital, [found_polygons_drop_capitals[mm]], mm, page_coord, all_box_coord_drop, slopes_drop, counter, ocr_all_textlines_textregion=None)
+            all_box_coord_drop = None
+            slopes_drop = None
+            if ocr_all_textlines_drop:
+                ocr_textlines = ocr_all_textlines_drop[mm]
+            else:
+                ocr_textlines = None
+            self.serialize_lines_in_dropcapital(dropcapital, [found_polygons_drop_capitals[mm]], mm, page_coord, all_box_coord_drop, slopes_drop, counter, ocr_all_textlines_textregion=ocr_textlines)

         for mm in range(len(found_polygons_text_region_img)):
             page.add_ImageRegion(ImageRegionType(id=counter.next_region_id, Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region_img[mm], page_coord))))
@@ -303,18 +342,28 @@ class EynollahXmlWriter():

         return pcgts

-    def calculate_polygon_coords(self, contour, page_coord):
+    def calculate_polygon_coords(self, contour, page_coord, skip_layout_reading_order=False):
         self.logger.debug('enter calculate_polygon_coords')
         coords = ''
         for value_bbox in contour:
-            if len(value_bbox) == 2:
-                coords += str(int((value_bbox[0] + page_coord[2]) / self.scale_x))
-                coords += ','
-                coords += str(int((value_bbox[1] + page_coord[0]) / self.scale_y))
+            if skip_layout_reading_order:
+                if len(value_bbox) == 2:
+                    coords += str(int((value_bbox[0]) / self.scale_x))
+                    coords += ','
+                    coords += str(int((value_bbox[1]) / self.scale_y))
+                else:
+                    coords += str(int((value_bbox[0][0]) / self.scale_x))
+                    coords += ','
+                    coords += str(int((value_bbox[0][1]) / self.scale_y))
             else:
-                coords += str(int((value_bbox[0][0] + page_coord[2]) / self.scale_x))
-                coords += ','
-                coords += str(int((value_bbox[0][1] + page_coord[0]) / self.scale_y))
+                if len(value_bbox) == 2:
+                    coords += str(int((value_bbox[0] + page_coord[2]) / self.scale_x))
+                    coords += ','
+                    coords += str(int((value_bbox[1] + page_coord[0]) / self.scale_y))
+                else:
+                    coords += str(int((value_bbox[0][0] + page_coord[2]) / self.scale_x))
+                    coords += ','
+                    coords += str(int((value_bbox[0][1] + page_coord[0]) / self.scale_y))
         coords=coords + ' '
         return coords[:-1]