🔥 drop light_version/textline_light (now default and implied)

kba 2025-11-26 20:29:29 +01:00
parent ca83cf934d
commit 83e8b289da
16 changed files with 183 additions and 732 deletions

View file

@@ -103,8 +103,6 @@ The following options can be used to further configure the processing:
 | option            | description                                                                                   |
 |-------------------|:--------------------------------------------------------------------------------------------|
 | `-fl`             | full layout analysis including all steps and segmentation classes (recommended)              |
-| `-light`          | lighter and faster but simpler method for main region detection and deskewing (recommended)  |
-| `-tll`            | this indicates the light textline and should be passed with light version (recommended)      |
 | `-tab`            | apply table detection                                                                        |
 | `-ae`             | apply enhancement (the resulting image is saved to the output directory)                     |
 | `-as`             | apply scaling                                                                                |
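
With these two options gone, the light pipeline no longer has to be requested explicitly. A minimal usage sketch of the CLI change, assuming the usual `eynollah layout` invocation with `-i`/`-o`/`-m` for input image, output and model directory (paths are placeholders, not from this diff):

```sh
# before this commit: light mode and light textlines had to be enabled together
eynollah layout -i page.png -o out/ -m models/ -light -tll -fl

# after this commit: the same behaviour is the default and implied
eynollah layout -i page.png -o out/ -m models/ -fl
```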

View file

@@ -81,12 +81,6 @@ import click
     is_flag=True,
     help="if this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline. This should be taken into account that with this option the tool need more time to do process.",
 )
-@click.option(
-    "--textline_light/--no-textline_light",
-    "-tll/-notll",
-    is_flag=True,
-    help="if this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline with a faster method.",
-)
 @click.option(
     "--full-layout/--no-full-layout",
     "-fl/-nofl",
@@ -123,12 +117,6 @@ import click
     is_flag=True,
     help="if this parameter set to true, this tool would ignore headers role in reading order",
 )
-@click.option(
-    "--light_version/--original",
-    "-light/-org",
-    is_flag=True,
-    help="if this parameter set to true, this tool would use lighter version",
-)
 @click.option(
     "--ignore_page_extraction/--extract_page_included",
     "-ipe/-epi",
@@ -183,14 +171,12 @@ def layout_cli(
     enable_plotting,
     allow_enhancement,
     curved_line,
-    textline_light,
     full_layout,
     tables,
     right2left,
     input_binary,
     allow_scaling,
     headers_off,
-    light_version,
     reading_order_machine_based,
     num_col_upper,
     num_col_lower,
@@ -211,12 +197,9 @@ def layout_cli(
     assert enable_plotting or not allow_enhancement, "Plotting with -ae also requires -ep"
     assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \
         "Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae"
-    assert textline_light == light_version, "Both light textline detection -tll and light version -light must be set or unset equally"
     assert not extract_only_images or not allow_enhancement, "Image extraction -eoi can not be set alongside allow_enhancement -ae"
     assert not extract_only_images or not allow_scaling, "Image extraction -eoi can not be set alongside allow_scaling -as"
-    assert not extract_only_images or not light_version, "Image extraction -eoi can not be set alongside light_version -light"
     assert not extract_only_images or not curved_line, "Image extraction -eoi can not be set alongside curved_line -cl"
-    assert not extract_only_images or not textline_light, "Image extraction -eoi can not be set alongside textline_light -tll"
     assert not extract_only_images or not full_layout, "Image extraction -eoi can not be set alongside full_layout -fl"
     assert not extract_only_images or not tables, "Image extraction -eoi can not be set alongside tables -tab"
     assert not extract_only_images or not right2left, "Image extraction -eoi can not be set alongside right2left -r2l"
@@ -228,14 +211,12 @@ def layout_cli(
         enable_plotting=enable_plotting,
         allow_enhancement=allow_enhancement,
         curved_line=curved_line,
-        textline_light=textline_light,
         full_layout=full_layout,
         tables=tables,
         right2left=right2left,
         input_binary=input_binary,
         allow_scaling=allow_scaling,
         headers_off=headers_off,
-        light_version=light_version,
         ignore_page_extraction=ignore_page_extraction,
         reading_order_machine_based=reading_order_machine_based,
         num_col_upper=num_col_upper,
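
Since the asserts that used to guard `-tll`/`-light` combinations are gone along with the options themselves, a quick way to confirm an installed version includes this change is to look for the flags in the help text. A hypothetical check, assuming the `eynollah` entry point is on PATH:

```sh
# should print nothing once the -tll and -light options have been removed
eynollah layout --help | grep -E -- '-tll|-light'
```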

View file

@@ -36,7 +36,6 @@ from functools import partial
 from pathlib import Path
 from multiprocessing import cpu_count
 import gc
-import copy
 from concurrent.futures import ProcessPoolExecutor
 import cv2
@@ -51,13 +50,6 @@ import statistics
 tf_disable_interactive_logs()
 import tensorflow as tf
-# warnings.filterwarnings("ignore")
-from tensorflow.python.keras import backend as K
-from tensorflow.keras.models import load_model
-# use tf1 compatibility for keras backend
-from tensorflow.compat.v1.keras.backend import set_session
-from tensorflow.keras import layers
-from tensorflow.keras.layers import StringLookup
 try:
     import torch
 except ImportError:
@@ -71,16 +63,13 @@ from .model_zoo import EynollahModelZoo
 from .utils.contour import (
     filter_contours_area_of_image,
     filter_contours_area_of_image_tables,
-    find_contours_mean_y_diff,
     find_center_of_contours,
     find_new_features_of_contours,
     find_features_of_contours,
     get_text_region_boxes_by_given_contours,
-    get_textregion_contours_in_org_image,
     get_textregion_contours_in_org_image_light,
     return_contours_of_image,
     return_contours_of_interested_region,
-    return_contours_of_interested_textline,
     return_parent_contours,
     dilate_textregion_contours,
     dilate_textline_contours,
@@ -93,40 +82,30 @@ from .utils.rotate import (
     rotate_image,
     rotation_not_90_func,
     rotation_not_90_func_full_layout,
-    rotation_image_new
 )
 from .utils.separate_lines import (
-    separate_lines_new2,
     return_deskew_slop,
     do_work_of_slopes_new,
     do_work_of_slopes_new_curved,
     do_work_of_slopes_new_light,
 )
-from .utils.drop_capitals import (
-    adhere_drop_capital_region_into_corresponding_textline,
-    filter_small_drop_capitals_from_no_patch_layout
-)
 from .utils.marginals import get_marginals
 from .utils.resize import resize_image
 from .utils.shm import share_ndarray
 from .utils import (
     is_image_filename,
-    boosting_headers_by_longshot_region_segmentation,
     crop_image_inside_box,
     box2rect,
-    box2slice,
     find_num_col,
     otsu_copy_binary,
-    put_drop_out_from_only_drop_model,
     putt_bb_of_drop_capitals_of_model_in_patches_in_layout,
-    check_any_text_region_in_model_one_is_main_or_header,
     check_any_text_region_in_model_one_is_main_or_header_light,
     small_textlines_to_parent_adherence2,
     order_of_regions,
     find_number_of_columns_in_document,
     return_boxes_of_images_by_order_of_reading_new
 )
-from .utils.pil_cv2 import check_dpi, pil2cv
+from .utils.pil_cv2 import pil2cv
 from .utils.xml import order_and_id_of_texts
 from .plot import EynollahPlotter
 from .writer import EynollahXmlWriter
@@ -153,14 +132,12 @@ class Eynollah:
         enable_plotting : bool = False,
         allow_enhancement : bool = False,
         curved_line : bool = False,
-        textline_light : bool = False,
         full_layout : bool = False,
         tables : bool = False,
         right2left : bool = False,
         input_binary : bool = False,
         allow_scaling : bool = False,
         headers_off : bool = False,
-        light_version : bool = False,
         ignore_page_extraction : bool = False,
         reading_order_machine_based : bool = False,
         num_col_upper : Optional[int] = None,
@@ -174,14 +151,10 @@ class Eynollah:
         self.model_zoo = model_zoo
         self.plotter = None
-        if skip_layout_and_reading_order:
-            textline_light = True
-        self.light_version = light_version
         self.reading_order_machine_based = reading_order_machine_based
         self.enable_plotting = enable_plotting
         self.allow_enhancement = allow_enhancement
         self.curved_line = curved_line
-        self.textline_light = textline_light
         self.full_layout = full_layout
         self.tables = tables
         self.right2left = right2left
@@ -189,7 +162,6 @@ class Eynollah:
         self.input_binary = input_binary
         self.allow_scaling = allow_scaling
         self.headers_off = headers_off
-        self.light_version = light_version
         self.extract_only_images = extract_only_images
         self.ignore_page_extraction = ignore_page_extraction
         self.skip_layout_and_reading_order = skip_layout_and_reading_order
@@ -244,23 +216,18 @@ class Eynollah:
             "col_classifier",
             "binarization",
             "page",
-            ("region", 'extract_only_images' if self.extract_only_images else 'light' if self.light_version else '')
+            ("region", 'extract_only_images' if self.extract_only_images else '')
         ]
         if not self.extract_only_images:
-            loadable.append(("textline", 'light' if self.light_version else ''))
-            if self.light_version:
-                loadable.append("region_1_2")
-            else:
-                loadable.append("region_p2")
-                # if self.allow_enhancement:?
-                loadable.append("enhancement")
+            loadable.append(("textline"))
+            loadable.append("region_1_2")
             if self.full_layout:
                 loadable.append("region_fl_np")
                 #loadable.append("region_fl")
             if self.reading_order_machine_based:
                 loadable.append("reading_order")
             if self.tables:
-                loadable.append(("table", 'light' if self.light_version else ''))
+                loadable.append(("table"))
         self.model_zoo.load_models(*loadable)
@@ -286,16 +253,10 @@ class Eynollah:
         t_c0 = time.time()
         if image_filename:
             ret['img'] = cv2.imread(image_filename)
-            if self.light_version:
-                self.dpi = 100
-            else:
-                self.dpi = check_dpi(image_filename)
+            self.dpi = 100
         else:
             ret['img'] = pil2cv(image_pil)
-            if self.light_version:
-                self.dpi = 100
-            else:
-                self.dpi = check_dpi(image_pil)
+            self.dpi = 100
         ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY)
         for prefix in ('', '_grayscale'):
             ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8)
@@ -309,8 +270,7 @@ class Eynollah:
         self.writer = EynollahXmlWriter(
             dir_out=dir_out,
             image_filename=image_filename,
-            curved_line=self.curved_line,
-            textline_light = self.textline_light)
+            curved_line=self.curved_line)

     def imread(self, grayscale=False, uint8=True):
         key = 'img'
@@ -555,7 +515,7 @@ class Eynollah:
         return img, img_new, is_image_enhanced

-    def resize_and_enhance_image_with_column_classifier(self, light_version):
+    def resize_and_enhance_image_with_column_classifier(self):
         self.logger.debug("enter resize_and_enhance_image_with_column_classifier")
         dpi = self.dpi
         self.logger.info("Detected %s DPI", dpi)
@@ -638,19 +598,16 @@ class Eynollah:
         self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5))
         if not self.extract_only_images:
             if dpi < DPI_THRESHOLD:
-                if light_version and num_col in (1,2):
+                if num_col in (1,2):
                     img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2(
                         img, num_col, width_early, label_p_pred)
                 else:
                     img_new, num_column_is_classified = self.calculate_width_height_by_columns(
                         img, num_col, width_early, label_p_pred)
-                if light_version:
-                    image_res = np.copy(img_new)
-                else:
-                    image_res = self.predict_enhancement(img_new)
+                image_res = np.copy(img_new)
                 is_image_enhanced = True
             else:
-                if light_version and num_col in (1,2):
+                if num_col in (1,2):
                     img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2(
                         img, num_col, width_early, label_p_pred)
                     image_res = np.copy(img_new)
@@ -1550,9 +1507,8 @@ class Eynollah:
         img_width_h = img.shape[1]
         model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np")
-        if self.light_version:
-            thresholding_for_fl_light_version = True
-        elif not patches:
+        thresholding_for_fl_light_version = True
+        if not patches:
             img = otsu_copy_binary(img).astype(np.uint8)
             prediction_regions = None
             thresholding_for_fl_light_version = False
@@ -1747,7 +1703,6 @@ class Eynollah:
             results = self.executor.map(partial(do_work_of_slopes_new_light,
                                                 textline_mask_tot_ea=textline_mask_tot_shared,
                                                 slope_deskew=slope_deskew,
-                                                textline_light=self.textline_light,
                                                 logger=self.logger,),
                                         boxes, contours, contours_par)
             results = list(results) # exhaust prior to release
@@ -1810,78 +1765,17 @@ class Eynollah:
         prediction_textline = self.do_prediction(use_patches, img, self.model_zoo.get("textline"),
                                                  marginal_of_patch_percent=0.15,
                                                  n_batch_inference=3,
-                                                 thresholding_for_artificial_class_in_light_version=self.textline_light,
                                                  threshold_art_class_textline=self.threshold_art_class_textline)
-        #if not self.textline_light:
-        #if num_col_classifier==1:
-        #prediction_textline_nopatch = self.do_prediction(False, img, self.model_zoo.get_model("textline"))
-        #prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0
         prediction_textline = resize_image(prediction_textline, img_h, img_w)
         textline_mask_tot_ea_art = (prediction_textline[:,:]==2)*1
         old_art = np.copy(textline_mask_tot_ea_art)
-        if not self.textline_light:
-            textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8')
-            #textline_mask_tot_ea_art = cv2.dilate(textline_mask_tot_ea_art, KERNEL, iterations=1)
-            prediction_textline[:,:][textline_mask_tot_ea_art[:,:]==1]=2
-        """
-        else:
-            textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8')
-            hor_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (8, 1))
-            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
-            ##cv2.imwrite('textline_mask_tot_ea_art.png', textline_mask_tot_ea_art)
-            textline_mask_tot_ea_art = cv2.dilate(textline_mask_tot_ea_art, hor_kernel, iterations=1)
-            ###cv2.imwrite('dil_textline_mask_tot_ea_art.png', dil_textline_mask_tot_ea_art)
-            textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8')
-            #print(np.shape(dil_textline_mask_tot_ea_art), np.unique(dil_textline_mask_tot_ea_art), 'dil_textline_mask_tot_ea_art')
-            tsk = time.time()
-            skeleton_art_textline = skeletonize(textline_mask_tot_ea_art[:,:,0])
-            skeleton_art_textline = skeleton_art_textline*1
-            skeleton_art_textline = skeleton_art_textline.astype('uint8')
-            skeleton_art_textline = cv2.dilate(skeleton_art_textline, kernel, iterations=1)
-            #print(np.unique(skeleton_art_textline), np.shape(skeleton_art_textline))
-            #print(skeleton_art_textline, np.unique(skeleton_art_textline))
-            #cv2.imwrite('skeleton_art_textline.png', skeleton_art_textline)
-            prediction_textline[:,:,0][skeleton_art_textline[:,:]==1]=2
-            #cv2.imwrite('prediction_textline1.png', prediction_textline[:,:,0])
-            ##hor_kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (4, 1))
-            ##ver_kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 3))
-            ##textline_mask_tot_ea_main = (prediction_textline[:,:]==1)*1
-            ##textline_mask_tot_ea_main = textline_mask_tot_ea_main.astype('uint8')
-            ##dil_textline_mask_tot_ea_main = cv2.erode(textline_mask_tot_ea_main, ver_kernel2, iterations=1)
-            ##dil_textline_mask_tot_ea_main = cv2.dilate(textline_mask_tot_ea_main, hor_kernel2, iterations=1)
-            ##dil_textline_mask_tot_ea_main = cv2.dilate(textline_mask_tot_ea_main, ver_kernel2, iterations=1)
-            ##prediction_textline[:,:][dil_textline_mask_tot_ea_main[:,:]==1]=1
-        """
         textline_mask_tot_ea_lines = (prediction_textline[:,:]==1)*1
         textline_mask_tot_ea_lines = textline_mask_tot_ea_lines.astype('uint8')
-        if not self.textline_light:
-            textline_mask_tot_ea_lines = cv2.dilate(textline_mask_tot_ea_lines, KERNEL, iterations=1)
         prediction_textline[:,:][textline_mask_tot_ea_lines[:,:]==1]=1
-        if not self.textline_light:
-            prediction_textline[:,:][old_art[:,:]==1]=2
         #cv2.imwrite('prediction_textline2.png', prediction_textline[:,:,0])
@@ -2649,92 +2543,9 @@ class Eynollah:
         img_height_h = img_org.shape[0]
         img_width_h = img_org.shape[1]
         patches = False
-        if self.light_version:
-            prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_zoo.get("table"))
-            prediction_table = prediction_table.astype(np.int16)
-            return prediction_table[:,:,0]
-        else:
-            if num_col_classifier < 4 and num_col_classifier > 2:
-                prediction_table = self.do_prediction(patches, img, self.model_zoo.get("table"))
-                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_zoo.get("table"))
-                pre_updown = cv2.flip(pre_updown, -1)
-                prediction_table[:,:,0][pre_updown[:,:,0]==1]=1
-                prediction_table = prediction_table.astype(np.int16)
-            elif num_col_classifier ==2:
-                height_ext = 0 # img.shape[0] // 4
-                h_start = height_ext // 2
-                width_ext = img.shape[1] // 8
-                w_start = width_ext // 2
-                img_new = np.zeros((img.shape[0] + height_ext,
-                                    img.shape[1] + width_ext,
-                                    img.shape[2])).astype(float)
-                ys = slice(h_start, h_start + img.shape[0])
-                xs = slice(w_start, w_start + img.shape[1])
-                img_new[ys, xs] = img
-                prediction_ext = self.do_prediction(patches, img_new, self.model_zoo.get("table"))
-                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_zoo.get("table"))
-                pre_updown = cv2.flip(pre_updown, -1)
-                prediction_table = prediction_ext[ys, xs]
-                prediction_table_updown = pre_updown[ys, xs]
-                prediction_table[:,:,0][prediction_table_updown[:,:,0]==1]=1
-                prediction_table = prediction_table.astype(np.int16)
-            elif num_col_classifier ==1:
-                height_ext = 0 # img.shape[0] // 4
-                h_start = height_ext // 2
-                width_ext = img.shape[1] // 4
-                w_start = width_ext // 2
-                img_new =np.zeros((img.shape[0] + height_ext,
-                                   img.shape[1] + width_ext,
-                                   img.shape[2])).astype(float)
-                ys = slice(h_start, h_start + img.shape[0])
-                xs = slice(w_start, w_start + img.shape[1])
-                img_new[ys, xs] = img
-                prediction_ext = self.do_prediction(patches, img_new, self.model_zoo.get("table"))
-                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_zoo.get("table"))
-                pre_updown = cv2.flip(pre_updown, -1)
-                prediction_table = prediction_ext[ys, xs]
-                prediction_table_updown = pre_updown[ys, xs]
-                prediction_table[:,:,0][prediction_table_updown[:,:,0]==1]=1
-                prediction_table = prediction_table.astype(np.int16)
-            else:
-                prediction_table = np.zeros(img.shape)
-                img_w_half = img.shape[1] // 2
-                pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.model_zoo.get("table"))
-                pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.model_zoo.get("table"))
-                pre_full = self.do_prediction(patches, img[:,:,:], self.model_zoo.get("table"))
-                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_zoo.get("table"))
-                pre_updown = cv2.flip(pre_updown, -1)
-                prediction_table_full_erode = cv2.erode(pre_full[:,:,0], KERNEL, iterations=4)
-                prediction_table_full_erode = cv2.dilate(prediction_table_full_erode, KERNEL, iterations=4)
-                prediction_table_full_updown_erode = cv2.erode(pre_updown[:,:,0], KERNEL, iterations=4)
-                prediction_table_full_updown_erode = cv2.dilate(prediction_table_full_updown_erode, KERNEL, iterations=4)
-                prediction_table[:,0:img_w_half,:] = pre1[:,:,:]
-                prediction_table[:,img_w_half:,:] = pre2[:,:,:]
-                prediction_table[:,:,0][prediction_table_full_erode[:,:]==1]=1
-                prediction_table[:,:,0][prediction_table_full_updown_erode[:,:]==1]=1
-                prediction_table = prediction_table.astype(np.int16)
-            #prediction_table_erode = cv2.erode(prediction_table[:,:,0], self.kernel, iterations=6)
-            #prediction_table_erode = cv2.dilate(prediction_table_erode, self.kernel, iterations=6)
-            prediction_table_erode = cv2.erode(prediction_table[:,:,0], KERNEL, iterations=20)
-            prediction_table_erode = cv2.dilate(prediction_table_erode, KERNEL, iterations=20)
-            return prediction_table_erode.astype(np.int16)
+        prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_zoo.get("table"))
+        prediction_table = prediction_table.astype(np.int16)
+        return prediction_table[:,:,0]

     def run_graphics_and_columns_light(
             self, text_regions_p_1, textline_mask_tot_ea,
@@ -2876,11 +2687,11 @@ class Eynollah:
         return (num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_lines,
                 text_regions_p_1, cont_page, table_prediction)

-    def run_enhancement(self, light_version):
+    def run_enhancement(self):
         t_in = time.time()
         self.logger.info("Resizing and enhancing image...")
         is_image_enhanced, img_org, img_res, num_col_classifier, num_column_is_classified, img_bin = \
-            self.resize_and_enhance_image_with_column_classifier(light_version)
+            self.resize_and_enhance_image_with_column_classifier()
         self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ')
         scale = 1
         if is_image_enhanced:
@@ -2911,8 +2722,7 @@ class Eynollah:
                                              scaler_h_textline,
                                              scaler_w_textline,
                                              num_col_classifier)
-        if self.textline_light:
-            textline_mask_tot_ea = textline_mask_tot_ea.astype(np.int16)
+        textline_mask_tot_ea = textline_mask_tot_ea.astype(np.int16)
         if self.plotter:
             self.plotter.save_plot_of_textlines(textline_mask_tot_ea, image_page)
@@ -2945,7 +2755,7 @@ class Eynollah:
                 regions_without_separators = regions_without_separators.astype(np.uint8)
                 text_regions_p = get_marginals(
                     rotate_image(regions_without_separators, slope_deskew), text_regions_p,
-                    num_col_classifier, slope_deskew, light_version=self.light_version, kernel=KERNEL)
+                    num_col_classifier, slope_deskew, kernel=KERNEL)
             except Exception as e:
                 self.logger.error("exception %s", e)
@@ -3004,20 +2814,6 @@ class Eynollah:
                 self.logger.debug("len(boxes): %s", len(boxes))
                 #print(time.time()-t_0_box,'time box in 3.1')
-                if self.tables:
-                    if self.light_version:
-                        pass
-                    else:
-                        text_regions_p_tables = np.copy(text_regions_p)
-                        text_regions_p_tables[(table_prediction == 1)] = 10
-                        pixel_line = 3
-                        img_revised_tab2 = self.add_tables_heuristic_to_layout(
-                            text_regions_p_tables, boxes, 0, splitter_y_new, peaks_neg_tot_tables, text_regions_p_tables,
-                            num_col_classifier , 0.000005, pixel_line)
-                        #print(time.time()-t_0_box,'time box in 3.2')
-                        img_revised_tab2, contoures_tables = self.check_iou_of_bounding_box_and_contour_for_tables(
-                            img_revised_tab2, table_prediction, 10, num_col_classifier)
-                        #print(time.time()-t_0_box,'time box in 3.3')
             else:
                 boxes_d, peaks_neg_tot_tables_d = return_boxes_of_images_by_order_of_reading_new(
                     splitter_y_new_d, regions_without_separators_d, matrix_of_lines_ch_d,
@@ -3025,63 +2821,24 @@ class Eynollah:
                 boxes = None
                 self.logger.debug("len(boxes): %s", len(boxes_d))
-                if self.tables:
-                    if self.light_version:
-                        pass
-                    else:
-                        text_regions_p_tables = np.copy(text_regions_p_1_n)
-                        text_regions_p_tables = np.round(text_regions_p_tables)
-                        text_regions_p_tables[(text_regions_p_tables != 3) & (table_prediction_n == 1)] = 10
-                        pixel_line = 3
-                        img_revised_tab2 = self.add_tables_heuristic_to_layout(
-                            text_regions_p_tables, boxes_d, 0, splitter_y_new_d,
-                            peaks_neg_tot_tables_d, text_regions_p_tables,
-                            num_col_classifier, 0.000005, pixel_line)
-                        img_revised_tab2_d,_ = self.check_iou_of_bounding_box_and_contour_for_tables(
-                            img_revised_tab2, table_prediction_n, 10, num_col_classifier)
-                        img_revised_tab2_d_rotated = rotate_image(img_revised_tab2_d, -slope_deskew)
-                        img_revised_tab2_d_rotated = np.round(img_revised_tab2_d_rotated)
-                        img_revised_tab2_d_rotated = img_revised_tab2_d_rotated.astype(np.int8)
-                        img_revised_tab2_d_rotated = resize_image(img_revised_tab2_d_rotated,
-                                                                  text_regions_p.shape[0], text_regions_p.shape[1])
             #print(time.time()-t_0_box,'time box in 4')
             self.logger.info("detecting boxes took %.1fs", time.time() - t1)
             if self.tables:
-                if self.light_version:
-                    text_regions_p[table_prediction == 1] = 10
-                    img_revised_tab = text_regions_p[:,:]
-                else:
-                    if np.abs(slope_deskew) < SLOPE_THRESHOLD:
-                        img_revised_tab = np.copy(img_revised_tab2)
-                        img_revised_tab[(text_regions_p == 1) & (img_revised_tab != 10)] = 1
-                    else:
-                        img_revised_tab = np.copy(text_regions_p)
-                        img_revised_tab[img_revised_tab == 10] = 0
-                        img_revised_tab[img_revised_tab2_d_rotated == 10] = 10
-                    text_regions_p[text_regions_p == 10] = 0
-                    text_regions_p[img_revised_tab == 10] = 10
+                text_regions_p[table_prediction == 1] = 10
+                img_revised_tab = text_regions_p[:,:]
             else:
                 img_revised_tab = text_regions_p[:,:]
             #img_revised_tab = text_regions_p[:, :]
-            if self.light_version:
-                polygons_of_images = return_contours_of_interested_region(text_regions_p, 2)
-            else:
-                polygons_of_images = return_contours_of_interested_region(img_revised_tab, 2)
+            polygons_of_images = return_contours_of_interested_region(text_regions_p, 2)
             pixel_img = 4
             min_area_mar = 0.00001
-            if self.light_version:
-                marginal_mask = (text_regions_p[:,:]==pixel_img)*1
-                marginal_mask = marginal_mask.astype('uint8')
-                marginal_mask = cv2.dilate(marginal_mask, KERNEL, iterations=2)
-                polygons_of_marginals = return_contours_of_interested_region(marginal_mask, 1, min_area_mar)
-            else:
-                polygons_of_marginals = return_contours_of_interested_region(text_regions_p, pixel_img, min_area_mar)
+            marginal_mask = (text_regions_p[:,:]==pixel_img)*1
+            marginal_mask = marginal_mask.astype('uint8')
+            marginal_mask = cv2.dilate(marginal_mask, KERNEL, iterations=2)
+            polygons_of_marginals = return_contours_of_interested_region(marginal_mask, 1, min_area_mar)
             pixel_img = 10
             contours_tables = return_contours_of_interested_region(text_regions_p, pixel_img, min_area_mar)
@@ -3099,144 +2856,43 @@ class Eynollah:
         self.logger.debug('enter run_boxes_full_layout')
         t_full0 = time.time()
         if self.tables:
-            if self.light_version:
-                text_regions_p[:,:][table_prediction[:,:]==1] = 10
-                img_revised_tab = text_regions_p[:,:]
-                if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
-                    _, textline_mask_tot_d, text_regions_p_1_n, table_prediction_n = \
-                        rotation_not_90_func(image_page, textline_mask_tot, text_regions_p,
-                                             table_prediction, slope_deskew)
-                    text_regions_p_1_n = resize_image(text_regions_p_1_n,
-                                                      text_regions_p.shape[0],
-                                                      text_regions_p.shape[1])
-                    textline_mask_tot_d = resize_image(textline_mask_tot_d,
-                                                       text_regions_p.shape[0],
-                                                       text_regions_p.shape[1])
-                    table_prediction_n = resize_image(table_prediction_n,
-                                                      text_regions_p.shape[0],
-                                                      text_regions_p.shape[1])
-                    regions_without_separators_d = (text_regions_p_1_n[:,:] == 1)*1
-                    regions_without_separators_d[table_prediction_n[:,:] == 1] = 1
-                else:
-                    text_regions_p_1_n = None
-                    textline_mask_tot_d = None
-                    regions_without_separators_d = None
-                # regions_without_separators = ( text_regions_p[:,:]==1 | text_regions_p[:,:]==2 )*1
-                #self.return_regions_without_separators_new(text_regions_p[:,:,0],img_only_regions)
-                regions_without_separators = (text_regions_p[:,:] == 1)*1
-                regions_without_separators[table_prediction == 1] = 1
-            else:
-                if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
-                    _, textline_mask_tot_d, text_regions_p_1_n, table_prediction_n = \
-                        rotation_not_90_func(image_page, textline_mask_tot, text_regions_p,
-                                             table_prediction, slope_deskew)
-                    text_regions_p_1_n = resize_image(text_regions_p_1_n,
-                                                      text_regions_p.shape[0],
-                                                      text_regions_p.shape[1])
-                    textline_mask_tot_d = resize_image(textline_mask_tot_d,
-                                                       text_regions_p.shape[0],
-                                                       text_regions_p.shape[1])
-                    table_prediction_n = resize_image(table_prediction_n,
-                                                      text_regions_p.shape[0],
-                                                      text_regions_p.shape[1])
-                    regions_without_separators_d = (text_regions_p_1_n[:,:] == 1)*1
-                    regions_without_separators_d[table_prediction_n[:,:] == 1] = 1
-                else:
-                    text_regions_p_1_n = None
-                    textline_mask_tot_d = None
-                    regions_without_separators_d = None
-                # regions_without_separators = ( text_regions_p[:,:]==1 | text_regions_p[:,:]==2 )*1
-                #self.return_regions_without_separators_new(text_regions_p[:,:,0],img_only_regions)
-                regions_without_separators = (text_regions_p[:,:] == 1)*1
-                regions_without_separators[table_prediction == 1] = 1
-                pixel_lines=3
-                if np.abs(slope_deskew) < SLOPE_THRESHOLD:
-                    num_col, _, matrix_of_lines_ch, splitter_y_new, _ = find_number_of_columns_in_document(
-                        text_regions_p, num_col_classifier, self.tables, pixel_lines)
-                if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
-                    num_col_d, _, matrix_of_lines_ch_d, splitter_y_new_d, _ = find_number_of_columns_in_document(
-                        text_regions_p_1_n, num_col_classifier, self.tables, pixel_lines)
-                if num_col_classifier>=3:
-                    if np.abs(slope_deskew) < SLOPE_THRESHOLD:
-                        regions_without_separators = regions_without_separators.astype(np.uint8)
-                        regions_without_separators = cv2.erode(regions_without_separators[:,:], KERNEL, iterations=6)
-                    if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
-                        regions_without_separators_d = regions_without_separators_d.astype(np.uint8)
-                        regions_without_separators_d = cv2.erode(regions_without_separators_d[:,:], KERNEL, iterations=6)
-                else:
-                    pass
-                if np.abs(slope_deskew) < SLOPE_THRESHOLD:
-                    boxes, peaks_neg_tot_tables = return_boxes_of_images_by_order_of_reading_new(
-                        splitter_y_new, regions_without_separators, matrix_of_lines_ch,
-                        num_col_classifier, erosion_hurts, self.tables, self.right2left)
-                    text_regions_p_tables = np.copy(text_regions_p)
-                    text_regions_p_tables[:,:][(table_prediction[:,:]==1)] = 10
-                    pixel_line = 3
-                    img_revised_tab2 = self.add_tables_heuristic_to_layout(
-                        text_regions_p_tables, boxes, 0, splitter_y_new, peaks_neg_tot_tables, text_regions_p_tables,
-                        num_col_classifier , 0.000005, pixel_line)
-                    img_revised_tab2,contoures_tables = self.check_iou_of_bounding_box_and_contour_for_tables(
-                        img_revised_tab2, table_prediction, 10, num_col_classifier)
-                else:
-                    boxes_d, peaks_neg_tot_tables_d = return_boxes_of_images_by_order_of_reading_new(
-                        splitter_y_new_d, regions_without_separators_d, matrix_of_lines_ch_d,
-                        num_col_classifier, erosion_hurts, self.tables, self.right2left)
-                    text_regions_p_tables = np.copy(text_regions_p_1_n)
-                    text_regions_p_tables = np.round(text_regions_p_tables)
-                    text_regions_p_tables[(text_regions_p_tables != 3) & (table_prediction_n == 1)] = 10
-                    pixel_line = 3
-                    img_revised_tab2 = self.add_tables_heuristic_to_layout(
-                        text_regions_p_tables, boxes_d, 0, splitter_y_new_d,
-                        peaks_neg_tot_tables_d, text_regions_p_tables,
-                        num_col_classifier, 0.000005, pixel_line)
-                    img_revised_tab2_d,_ = self.check_iou_of_bounding_box_and_contour_for_tables(
-                        img_revised_tab2, table_prediction_n, 10, num_col_classifier)
-                    img_revised_tab2_d_rotated = rotate_image(img_revised_tab2_d, -slope_deskew)
-                    img_revised_tab2_d_rotated = np.round(img_revised_tab2_d_rotated)
-                    img_revised_tab2_d_rotated = img_revised_tab2_d_rotated.astype(np.int8)
-                    img_revised_tab2_d_rotated = resize_image(img_revised_tab2_d_rotated,
-                                                              text_regions_p.shape[0],
-                                                              text_regions_p.shape[1])
-                if np.abs(slope_deskew) < 0.13:
-                    img_revised_tab = np.copy(img_revised_tab2)
-                else:
-                    img_revised_tab = np.copy(text_regions_p)
-                    img_revised_tab[img_revised_tab == 10] = 0
-                    img_revised_tab[img_revised_tab2_d_rotated == 10] = 10
-                ##img_revised_tab = img_revised_tab2[:,:]
-                #img_revised_tab = text_regions_p[:,:]
-                text_regions_p[text_regions_p == 10] = 0
-                text_regions_p[img_revised_tab == 10] = 10
-                #img_revised_tab[img_revised_tab2 == 10] = 10
+            text_regions_p[:,:][table_prediction[:,:]==1] = 10
+            img_revised_tab = text_regions_p[:,:]
+            if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
+                _, textline_mask_tot_d, text_regions_p_1_n, table_prediction_n = \
+                    rotation_not_90_func(image_page, textline_mask_tot, text_regions_p,
+                                         table_prediction, slope_deskew)
+                text_regions_p_1_n = resize_image(text_regions_p_1_n,
+                                                  text_regions_p.shape[0],
+                                                  text_regions_p.shape[1])
+                textline_mask_tot_d = resize_image(textline_mask_tot_d,
+                                                   text_regions_p.shape[0],
+                                                   text_regions_p.shape[1])
+                table_prediction_n = resize_image(table_prediction_n,
+                                                  text_regions_p.shape[0],
+                                                  text_regions_p.shape[1])
+                regions_without_separators_d = (text_regions_p_1_n[:,:] == 1)*1
+                regions_without_separators_d[table_prediction_n[:,:] == 1] = 1
+            else:
+                text_regions_p_1_n = None
+                textline_mask_tot_d = None
+                regions_without_separators_d = None
+            # regions_without_separators = ( text_regions_p[:,:]==1 | text_regions_p[:,:]==2 )*1
+            #self.return_regions_without_separators_new(text_regions_p[:,:,0],img_only_regions)
+            regions_without_separators = (text_regions_p[:,:] == 1)*1
+            regions_without_separators[table_prediction == 1] = 1
         pixel_img = 4
         min_area_mar = 0.00001
-        if self.light_version:
-            marginal_mask = (text_regions_p[:,:]==pixel_img)*1
-            marginal_mask = marginal_mask.astype('uint8')
-            marginal_mask = cv2.dilate(marginal_mask, KERNEL, iterations=2)
-            polygons_of_marginals = return_contours_of_interested_region(marginal_mask, 1, min_area_mar)
-        else:
-            polygons_of_marginals = return_contours_of_interested_region(text_regions_p, pixel_img, min_area_mar)
+        marginal_mask = (text_regions_p[:,:]==pixel_img)*1
+        marginal_mask = marginal_mask.astype('uint8')
+        marginal_mask = cv2.dilate(marginal_mask, KERNEL, iterations=2)
+        polygons_of_marginals = return_contours_of_interested_region(marginal_mask, 1, min_area_mar)
         pixel_img = 10
         contours_tables = return_contours_of_interested_region(text_regions_p, pixel_img, min_area_mar)
@@ -3249,7 +2905,7 @@ class Eynollah:
             image_page = image_page.astype(np.uint8)
             #print("full inside 1", time.time()- t_full0)
             regions_fully, regions_fully_only_drop = self.extract_text_regions_new(
-                img_bin_light if self.light_version else image_page,
+                img_bin_light,
                 False, cols=num_col_classifier)
             #print("full inside 2", time.time()- t_full0)
             # 6 is the separators lable in old full layout model
@@ -3333,7 +2989,7 @@ class Eynollah:
         min_cont_size_to_be_dilated = 10
-        if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
+        if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
             (cx_conts, cy_conts,
              x_min_conts, x_max_conts,
              y_min_conts, y_max_conts,
@@ -3447,13 +3103,13 @@ class Eynollah:
                 img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,
                                    int(x_min_main[j]):int(x_max_main[j])] = 1
             co_text_all_org = contours_only_text_parent + contours_only_text_parent_h
-            if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
+            if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
                 co_text_all = contours_only_dilated + contours_only_text_parent_h
             else:
                 co_text_all = contours_only_text_parent + contours_only_text_parent_h
         else:
             co_text_all_org = contours_only_text_parent
-            if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
+            if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
                 co_text_all = contours_only_dilated
             else:
                 co_text_all = contours_only_text_parent
@@ -3528,7 +3184,7 @@ class Eynollah:
         ordered = [i[0] for i in ordered]
-        if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
+        if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
             org_contours_indexes = []
             for ind in range(len(ordered)):
                 region_with_curr_order = ordered[ind]
@@ -3788,10 +3444,6 @@ class Eynollah:
         # Log enabled features directly
         enabled_modes = []
-        if self.light_version:
-            enabled_modes.append("Light version")
-        if self.textline_light:
-            enabled_modes.append("Light textline detection")
         if self.full_layout:
             enabled_modes.append("Full layout analysis")
         if self.tables:
@@ -3851,7 +3503,7 @@ class Eynollah:
         self.logger.info("Step 1/5: Image Enhancement")

         img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = \
-            self.run_enhancement(self.light_version)
+            self.run_enhancement()

         self.logger.info(f"Image: {self.image.shape[1]}x{self.image.shape[0]}, "
                          f"{self.dpi} DPI, {num_col_classifier} columns")
@@ -3928,49 +3580,34 @@ class Eynollah:
         t1 = time.time()
         self.logger.info("Step 2/5: Layout Analysis")
-        if self.light_version:
-            self.logger.info("Using light version processing")
-            text_regions_p_1 ,erosion_hurts, polygons_seplines, polygons_text_early, \
-                textline_mask_tot_ea, img_bin_light, confidence_matrix = \
-                self.get_regions_light_v(img_res, is_image_enhanced, num_col_classifier)
-            #print("text region early -2 in %.1fs", time.time() - t0)
-            if num_col_classifier == 1 or num_col_classifier ==2:
-                if num_col_classifier == 1:
-                    img_w_new = 1000
-                else:
-                    img_w_new = 1300
-                img_h_new = img_w_new * textline_mask_tot_ea.shape[0] // textline_mask_tot_ea.shape[1]
-                textline_mask_tot_ea_deskew = resize_image(textline_mask_tot_ea,img_h_new, img_w_new )
-                slope_deskew = self.run_deskew(textline_mask_tot_ea_deskew)
-            else:
-                slope_deskew = self.run_deskew(textline_mask_tot_ea)
-            #print("text region early -2,5 in %.1fs", time.time() - t0)
-            #self.logger.info("Textregion detection took %.1fs ", time.time() - t1t)
-            num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_lines, \
-                text_regions_p_1, cont_page, table_prediction, textline_mask_tot_ea, img_bin_light = \
-                self.run_graphics_and_columns_light(text_regions_p_1, textline_mask_tot_ea,
-                                                    num_col_classifier, num_column_is_classified,
-                                                    erosion_hurts, img_bin_light)
-            #self.logger.info("run graphics %.1fs ", time.time() - t1t)
-            #print("text region early -3 in %.1fs", time.time() - t0)
-            textline_mask_tot_ea_org = np.copy(textline_mask_tot_ea)
-        else:
-            text_regions_p_1, erosion_hurts, polygons_seplines, polygons_text_early = \
-                self.get_regions_from_xy_2models(img_res, is_image_enhanced,
-                                                 num_col_classifier)
-            self.logger.info(f"Textregion detection took {time.time() - t1:.1f}s")
-            confidence_matrix = np.zeros((text_regions_p_1.shape[:2]))
-            t1 = time.time()
-            num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_lines, \
-                text_regions_p_1, cont_page, table_prediction = \
-                self.run_graphics_and_columns(text_regions_p_1, num_col_classifier, num_column_is_classified,
-                                              erosion_hurts)
-            self.logger.info(f"Graphics detection took {time.time() - t1:.1f}s")
-            #self.logger.info('cont_page %s', cont_page)
+        self.logger.info("Using light version processing")
+        text_regions_p_1 ,erosion_hurts, polygons_seplines, polygons_text_early, \
+            textline_mask_tot_ea, img_bin_light, confidence_matrix = \
+            self.get_regions_light_v(img_res, is_image_enhanced, num_col_classifier)
+        #print("text region early -2 in %.1fs", time.time() - t0)
+        if num_col_classifier == 1 or num_col_classifier ==2:
+            if num_col_classifier == 1:
+                img_w_new = 1000
+            else:
+                img_w_new = 1300
+            img_h_new = img_w_new * textline_mask_tot_ea.shape[0] // textline_mask_tot_ea.shape[1]
+            textline_mask_tot_ea_deskew = resize_image(textline_mask_tot_ea,img_h_new, img_w_new )
+            slope_deskew = self.run_deskew(textline_mask_tot_ea_deskew)
+        else:
+            slope_deskew = self.run_deskew(textline_mask_tot_ea)
+        #print("text region early -2,5 in %.1fs", time.time() - t0)
+        #self.logger.info("Textregion detection took %.1fs ", time.time() - t1t)
+        num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_lines, \
+            text_regions_p_1, cont_page, table_prediction, textline_mask_tot_ea, img_bin_light = \
+            self.run_graphics_and_columns_light(text_regions_p_1, textline_mask_tot_ea,
+                                                num_col_classifier, num_column_is_classified,
+                                                erosion_hurts, img_bin_light)
+        #self.logger.info("run graphics %.1fs ", time.time() - t1t)
+        #print("text region early -3 in %.1fs", time.time() - t0)
+        textline_mask_tot_ea_org = np.copy(textline_mask_tot_ea)
         #plt.imshow(table_prediction)
         #plt.show()
         self.logger.info(f"Layout analysis complete ({time.time() - t1:.1f}s)")
@@ -3985,13 +3622,7 @@ class Eynollah:
         #print("text region early in %.1fs", time.time() - t0)
         t1 = time.time()
-        if not self.light_version:
-            textline_mask_tot_ea = self.run_textline(image_page)
-            self.logger.info(f"Textline detection took {time.time() - t1:.1f}s")
-            t1 = time.time()
-            slope_deskew = self.run_deskew(textline_mask_tot_ea)
-            self.logger.info(f"Deskewing took {time.time() - t1:.1f}s")
-        elif num_col_classifier in (1,2):
+        if num_col_classifier in (1,2):
             org_h_l_m = textline_mask_tot_ea.shape[0]
             org_w_l_m = textline_mask_tot_ea.shape[1]
             if num_col_classifier == 1:
@@ -4030,10 +3661,8 @@ class Eynollah:
         if self.curved_line:
             self.logger.info("Mode: Curved line detection")
-        elif self.textline_light:
-            self.logger.info("Mode: Light detection")
-        if self.light_version and num_col_classifier in (1,2):
+        if num_col_classifier in (1,2):
             image_page = resize_image(image_page,org_h_l_m, org_w_l_m )
             textline_mask_tot_ea = resize_image(textline_mask_tot_ea,org_h_l_m, org_w_l_m )
             text_regions_p = resize_image(text_regions_p,org_h_l_m, org_w_l_m )
@@ -4057,11 +3686,10 @@ class Eynollah:
             regions_fully, regions_without_separators, polygons_of_marginals, contours_tables = \
                 self.run_boxes_full_layout(image_page, textline_mask_tot, text_regions_p, slope_deskew,
                                            num_col_classifier, img_only_regions, table_prediction, erosion_hurts,
-                                           img_bin_light if self.light_version else None)
+                                           img_bin_light)
             ###polygons_of_marginals = dilate_textregion_contours(polygons_of_marginals)
-            if self.light_version:
-                drop_label_in_full_layout = 4
-                textline_mask_tot_ea_org[img_revised_tab==drop_label_in_full_layout] = 0
+            drop_label_in_full_layout = 4
+            textline_mask_tot_ea_org[img_revised_tab==drop_label_in_full_layout] = 0

         text_only = (img_revised_tab[:, :] == 1) * 1
@@ -4222,68 +3850,40 @@ class Eynollah:
         #print("text region early 3 in %.1fs", time.time() - t0)
-        if self.light_version:
-            contours_only_text_parent = dilate_textregion_contours(contours_only_text_parent)
-            contours_only_text_parent , contours_only_text_parent_d_ordered = self.filter_contours_inside_a_bigger_one(
-                contours_only_text_parent, contours_only_text_parent_d_ordered, text_only,
-                marginal_cnts=polygons_of_marginals)
-            #print("text region early 3.5 in %.1fs", time.time() - t0)
-            conf_contours_textregions = get_textregion_contours_in_org_image_light(
-                contours_only_text_parent, self.image, confidence_matrix)
-            #contours_only_text_parent = dilate_textregion_contours(contours_only_text_parent)
-        else:
-            conf_contours_textregions = get_textregion_contours_in_org_image_light(
-                contours_only_text_parent, self.image, confidence_matrix)
+        contours_only_text_parent = dilate_textregion_contours(contours_only_text_parent)
+        contours_only_text_parent , contours_only_text_parent_d_ordered = self.filter_contours_inside_a_bigger_one(
+            contours_only_text_parent, contours_only_text_parent_d_ordered, text_only,
+            marginal_cnts=polygons_of_marginals)
+        #print("text region early 3.5 in %.1fs", time.time() - t0)
+        conf_contours_textregions = get_textregion_contours_in_org_image_light(
+            contours_only_text_parent, self.image, confidence_matrix)
+        #contours_only_text_parent = dilate_textregion_contours(contours_only_text_parent)
         #print("text region early 4 in %.1fs", time.time() - t0)
         boxes_text = get_text_region_boxes_by_given_contours(contours_only_text_parent)
         boxes_marginals = get_text_region_boxes_by_given_contours(polygons_of_marginals)
         #print("text region early 5 in %.1fs", time.time() - t0)
         ## birdan sora chock chakir
         if not self.curved_line:
-            if self.light_version:
-                if self.textline_light:
-                    all_found_textline_polygons, \
-                        all_box_coord, slopes = self.get_slopes_and_deskew_new_light2(
-                        contours_only_text_parent, textline_mask_tot_ea_org,
-                        boxes_text, slope_deskew)
-                    all_found_textline_polygons_marginals, \
-                        all_box_coord_marginals, slopes_marginals = self.get_slopes_and_deskew_new_light2(
-                        polygons_of_marginals, textline_mask_tot_ea_org,
-                        boxes_marginals, slope_deskew)
-                    all_found_textline_polygons = dilate_textline_contours(
-                        all_found_textline_polygons)
-                    all_found_textline_polygons = self.filter_contours_inside_a_bigger_one(
-                        all_found_textline_polygons, None, textline_mask_tot_ea_org, type_contour="textline")
-                    all_found_textline_polygons_marginals = dilate_textline_contours(
-                        all_found_textline_polygons_marginals)
-                    contours_only_text_parent, all_found_textline_polygons, \
-                        contours_only_text_parent_d_ordered, conf_contours_textregions = \
-                        self.filter_contours_without_textline_inside(
-                            contours_only_text_parent, all_found_textline_polygons,
-                            contours_only_text_parent_d_ordered, conf_contours_textregions)
-                else:
-                    textline_mask_tot_ea = cv2.erode(textline_mask_tot_ea, kernel=KERNEL, iterations=1)
-                    all_found_textline_polygons, \
-                        all_box_coord, slopes = self.get_slopes_and_deskew_new_light(
-                        contours_only_text_parent, contours_only_text_parent, textline_mask_tot_ea,
-                        boxes_text, slope_deskew)
-                    all_found_textline_polygons_marginals, \
-                        all_box_coord_marginals, slopes_marginals = self.get_slopes_and_deskew_new_light(
-                        polygons_of_marginals, polygons_of_marginals, textline_mask_tot_ea,
-                        boxes_marginals, slope_deskew)
-                    #all_found_textline_polygons = self.filter_contours_inside_a_bigger_one(
-                    #    all_found_textline_polygons, textline_mask_tot_ea_org, type_contour="textline")
-            else:
-                textline_mask_tot_ea = cv2.erode(textline_mask_tot_ea, kernel=KERNEL, iterations=1)
-                all_found_textline_polygons, \
-                    all_box_coord, slopes = self.get_slopes_and_deskew_new(
-                    contours_only_text_parent, contours_only_text_parent, textline_mask_tot_ea,
-                    boxes_text, slope_deskew)
-                all_found_textline_polygons_marginals, \
-                    all_box_coord_marginals, slopes_marginals = self.get_slopes_and_deskew_new(
-                    polygons_of_marginals, polygons_of_marginals, textline_mask_tot_ea,
-                    boxes_marginals, slope_deskew)
+            all_found_textline_polygons, \
+                all_box_coord, slopes = self.get_slopes_and_deskew_new_light2(
+                contours_only_text_parent, textline_mask_tot_ea_org,
+                boxes_text, slope_deskew)
+            all_found_textline_polygons_marginals, \
+                all_box_coord_marginals, slopes_marginals = self.get_slopes_and_deskew_new_light2(
+                polygons_of_marginals, textline_mask_tot_ea_org,
+                boxes_marginals, slope_deskew)
+            all_found_textline_polygons = dilate_textline_contours(
+                all_found_textline_polygons)
+            all_found_textline_polygons = self.filter_contours_inside_a_bigger_one(
+                all_found_textline_polygons, None, textline_mask_tot_ea_org, type_contour="textline")
+            all_found_textline_polygons_marginals = dilate_textline_contours(
+                all_found_textline_polygons_marginals)
+            contours_only_text_parent, all_found_textline_polygons, \
+                contours_only_text_parent_d_ordered, conf_contours_textregions = \
+                self.filter_contours_without_textline_inside(
+                    contours_only_text_parent, all_found_textline_polygons,
+                    contours_only_text_parent_d_ordered, conf_contours_textregions)
         else:
             scale_param = 1
             textline_mask_tot_ea_erode = cv2.erode(textline_mask_tot_ea, kernel=KERNEL, iterations=2)
@@ -4314,10 +3914,7 @@ class Eynollah:
         #print(len(polygons_of_marginals), len(ordered_left_marginals), len(ordered_right_marginals), 'marginals ordred')
         if self.full_layout:
-            if self.light_version:
-                fun = check_any_text_region_in_model_one_is_main_or_header_light
-            else:
-                fun = check_any_text_region_in_model_one_is_main_or_header
+            fun = check_any_text_region_in_model_one_is_main_or_header_light
             text_regions_p, contours_only_text_parent, contours_only_text_parent_h, all_box_coord, all_box_coord_h, \
                 all_found_textline_polygons, all_found_textline_polygons_h, slopes, slopes_h, \
                 contours_only_text_parent_d_ordered, contours_only_text_parent_h_d_ordered, \
@ -4336,7 +3933,7 @@ class Eynollah:
##all_found_textline_polygons = adhere_drop_capital_region_into_corresponding_textline( ##all_found_textline_polygons = adhere_drop_capital_region_into_corresponding_textline(
##text_regions_p, polygons_of_drop_capitals, contours_only_text_parent, contours_only_text_parent_h, ##text_regions_p, polygons_of_drop_capitals, contours_only_text_parent, contours_only_text_parent_h,
##all_box_coord, all_box_coord_h, all_found_textline_polygons, all_found_textline_polygons_h, ##all_box_coord, all_box_coord_h, all_found_textline_polygons, all_found_textline_polygons_h,
##kernel=KERNEL, curved_line=self.curved_line, textline_light=self.textline_light) ##kernel=KERNEL, curved_line=self.curved_line)
if not self.reading_order_machine_based: if not self.reading_order_machine_based:
label_seps = 6 label_seps = 6

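With the non-light branches removed above, textline post-processing always runs the light pipeline: deskew via get_slopes_and_deskew_new_light2, then dilate the textline polygons and drop contours that are nested inside a bigger one or contain no textline. A minimal sketch of what such contour dilation amounts to, assuming OpenCV rasterization (dilate_polygons is a hypothetical stand-in for the project's dilate_textline_contours, which may work differently):

    import cv2
    import numpy as np

    def dilate_polygons(polygons, shape, kernel_size=3, iterations=2):
        # Rasterize the polygons, grow them morphologically, then re-vectorize.
        mask = np.zeros(shape, dtype=np.uint8)
        cv2.fillPoly(mask, [np.asarray(p, dtype=np.int32) for p in polygons], 1)
        kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=iterations)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return contours

    # usage: one 20x10 box on a 100x100 canvas grows but stays a single contour
    polys = [[(10, 10), (30, 10), (30, 20), (10, 20)]]
    print(len(dilate_polygons(polys, (100, 100))))  # -> 1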

@@ -43,7 +43,6 @@ class Enhancer:
             save_org_scale : bool = False,
     ):
         self.input_binary = False
-        self.light_version = False
         self.save_org_scale = save_org_scale
         if num_col_upper:
             self.num_col_upper = int(num_col_upper)
@@ -69,16 +68,10 @@ class Enhancer:
         ret = {}
         if image_filename:
             ret['img'] = cv2.imread(image_filename)
-            if self.light_version:
-                self.dpi = 100
-            else:
-                self.dpi = 0#check_dpi(image_filename)
+            self.dpi = 100
         else:
             ret['img'] = pil2cv(image_pil)
-            if self.light_version:
-                self.dpi = 100
-            else:
-                self.dpi = 0#check_dpi(image_pil)
+            self.dpi = 100
         ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY)
         for prefix in ('', '_grayscale'):
             ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8)
@@ -271,7 +264,7 @@ class Enhancer:
         return img_new, num_column_is_classified

-    def resize_and_enhance_image_with_column_classifier(self, light_version):
+    def resize_and_enhance_image_with_column_classifier(self):
         self.logger.debug("enter resize_and_enhance_image_with_column_classifier")
         dpi = 0#self.dpi
         self.logger.info("Detected %s DPI", dpi)
@@ -354,16 +347,13 @@ class Enhancer:
             self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5))
         if dpi < DPI_THRESHOLD:
-            if light_version and num_col in (1,2):
+            if num_col in (1,2):
                 img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2(
                     img, num_col, width_early, label_p_pred)
             else:
                 img_new, num_column_is_classified = self.calculate_width_height_by_columns(
                     img, num_col, width_early, label_p_pred)
-            if light_version:
-                image_res = np.copy(img_new)
-            else:
-                image_res = self.predict_enhancement(img_new)
+            image_res = np.copy(img_new)
             is_image_enhanced = True
         else:
@@ -657,11 +647,11 @@ class Enhancer:
         gc.collect()
         return prediction_true

-    def run_enhancement(self, light_version):
+    def run_enhancement(self):
         t_in = time.time()
         self.logger.info("Resizing and enhancing image...")
         is_image_enhanced, img_org, img_res, num_col_classifier, num_column_is_classified, img_bin = \
-            self.resize_and_enhance_image_with_column_classifier(light_version)
+            self.resize_and_enhance_image_with_column_classifier()
         self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ')
         return img_res, is_image_enhanced, num_col_classifier, num_column_is_classified
@@ -669,7 +659,7 @@ class Enhancer:
     def run_single(self):
         t0 = time.time()
-        img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(light_version=False)
+        img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement()
         return img_res, is_image_enhanced

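The Enhancer now hard-codes the light behaviour: DPI is pinned to 100, and below the DPI threshold the image is only resized according to the classified column count, without running the enhancement model. A schematic sketch of that decision with hypothetical names and an illustrative threshold value (the real code uses DPI_THRESHOLD and the calculate_width_height_by_columns* methods):

    import numpy as np

    def choose_resized_image(img, num_col, dpi, resize_1_2col, resize_multicol, dpi_threshold=230):
        # Below the threshold, pick the column-count-specific resize and return
        # a plain copy; the removed non-light path ran the enhancement model here.
        if dpi >= dpi_threshold:
            return img, False
        if num_col in (1, 2):
            img_new = resize_1_2col(img, num_col)
        else:
            img_new = resize_multicol(img, num_col)
        return np.copy(img_new), True

    img = np.zeros((10, 10), dtype=np.uint8)
    out, resized = choose_resized_image(img, 2, 100, lambda i, n: i, lambda i, n: i)
    print(out.shape, resized)  # -> (10, 10) True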

@@ -49,8 +49,6 @@ class machine_based_reading_order_on_layout:
             self.logger.warning("no GPU device available")
         self.model_zoo.load_model('reading_order')
-        # FIXME: light_version is always true, no need for checks in the code
-        self.light_version = True

     def read_xml(self, xml_file):
         tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
@@ -517,7 +515,7 @@ class machine_based_reading_order_on_layout:
         min_cont_size_to_be_dilated = 10
-        if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
+        if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
             cx_conts, cy_conts, x_min_conts, x_max_conts, y_min_conts, y_max_conts, _ = find_new_features_of_contours(contours_only_text_parent)
             args_cont_located = np.array(range(len(contours_only_text_parent)))
@@ -617,13 +615,13 @@ class machine_based_reading_order_on_layout:
                 img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,
                                    int(x_min_main[j]):int(x_max_main[j])] = 1
             co_text_all_org = contours_only_text_parent + contours_only_text_parent_h
-            if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
+            if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
                 co_text_all = contours_only_dilated + contours_only_text_parent_h
             else:
                 co_text_all = contours_only_text_parent + contours_only_text_parent_h
         else:
             co_text_all_org = contours_only_text_parent
-            if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
+            if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
                 co_text_all = contours_only_dilated
             else:
                 co_text_all = contours_only_text_parent
@@ -702,7 +700,7 @@ class machine_based_reading_order_on_layout:
         ##id_all_text = np.array(id_all_text)[index_sort]
-        if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
+        if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
            org_contours_indexes = []
            for ind in range(len(ordered)):
                region_with_curr_order = ordered[ind]

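In the reading-order model, the choice between original and dilated text-region contours now depends only on how many regions there are. The gate in isolation (pick_contour_set is a hypothetical condensation; the constant value comes from the hunk above):

    MIN_CONT_SIZE_TO_BE_DILATED = 10  # as in the code above

    def pick_contour_set(contours_only_text_parent, contours_only_dilated):
        # Many small regions: order on the dilated, merged contours;
        # few regions: the originals are reliable enough.
        if len(contours_only_text_parent) > MIN_CONT_SIZE_TO_BE_DILATED:
            return contours_only_dilated
        return contours_only_text_parent

    print(pick_contour_set(list(range(11)), ['dilated']))  # -> ['dilated']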

@@ -13,7 +13,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         category="enhancement",
         variant='',
         filename="models_eynollah/eynollah-enhancement_20210425",
-        dists=['enhancement', 'layout', 'ci'],
         dist_url=dist_url(),
         type='Keras',
     ),
@@ -22,7 +21,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         category="binarization",
         variant='hybrid',
         filename="models_eynollah/eynollah-binarization-hybrid_20230504/model_bin_hybrid_trans_cnn_sbb_ens",
-        dists=['layout', 'binarization', ],
         dist_url=dist_url(),
         type='Keras',
     ),
@@ -31,7 +29,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         category="binarization",
         variant='20210309',
         filename="models_eynollah/eynollah-binarization_20210309",
-        dists=['binarization'],
         dist_url=dist_url("extra"),
         type='Keras',
     ),
@@ -40,7 +37,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         category="binarization",
         variant='',
         filename="models_eynollah/eynollah-binarization_20210425",
-        dists=['binarization'],
         dist_url=dist_url("extra"),
         type='Keras',
     ),
@@ -50,7 +46,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         variant='',
         filename="models_eynollah/eynollah-column-classifier_20210425",
         dist_url=dist_url(),
-        dists=['layout'],
         type='Keras',
     ),
@@ -59,7 +54,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         variant='',
         filename="models_eynollah/model_eynollah_page_extraction_20250915",
         dist_url=dist_url(),
-        dists=['layout'],
         type='Keras',
     ),
@@ -68,7 +62,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         variant='',
         filename="models_eynollah/eynollah-main-regions-ensembled_20210425",
         dist_url=dist_url(),
-        dists=['layout'],
         type='Keras',
     ),
@@ -77,27 +70,24 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         variant='extract_only_images',
         filename="models_eynollah/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18",
         dist_url=dist_url(),
-        dists=['layout'],
         type='Keras',
     ),
     EynollahModelSpec(
         category="region",
-        variant='light',
+        variant='',
         filename="models_eynollah/eynollah-main-regions_20220314",
         dist_url=dist_url(),
         help="early layout",
-        dists=['layout'],
         type='Keras',
     ),
     EynollahModelSpec(
         category="region_p2",
-        variant='',
+        variant='non-light',
         filename="models_eynollah/eynollah-main-regions-aug-rotation_20210425",
-        dist_url=dist_url(),
+        dist_url=dist_url('extra'),
         help="early layout, non-light, 2nd part",
-        dists=['layout'],
         type='Keras',
     ),
@@ -110,8 +100,7 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         #filename="models_eynollah/modelens_1_2_4_5_early_lay_1_2_spaltige",
         #filename="models_eynollah/model_3_eraly_layout_no_patches_1_2_spaltige",
         filename="models_eynollah/modelens_e_l_all_sp_0_1_2_3_4_171024",
-        dist_url=dist_url("all"),
-        dists=['layout'],
+        dist_url=dist_url("layout"),
         help="early layout, light, 1-or-2-column",
         type='Keras',
     ),
@@ -128,7 +117,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         filename="models_eynollah/modelens_full_lay_1__4_3_091124",
         dist_url=dist_url(),
         help="full layout / no patches",
-        dists=['layout'],
         type='Keras',
     ),
@@ -148,7 +136,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         filename="models_eynollah/modelens_full_lay_1__4_3_091124",
         dist_url=dist_url(),
         help="full layout / with patches",
-        dists=['layout'],
         type='Keras',
     ),
@@ -162,13 +149,12 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         #filename="models_eynollah/model_ens_reading_order_machine_based",
         filename="models_eynollah/model_eynollah_reading_order_20250824",
         dist_url=dist_url(),
-        dists=['layout', 'reading_order'],
         type='Keras',
     ),
     EynollahModelSpec(
         category="textline",
-        variant='',
+        variant='non-light',
         #filename="models_eynollah/modelens_textline_1_4_16092024",
         #filename="models_eynollah/model_textline_ens_3_4_5_6_artificial",
         #filename="models_eynollah/modelens_textline_1_3_4_20240915",
@@ -176,36 +162,32 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         #filename="models_eynollah/modelens_textline_9_12_13_14_15",
         #filename="models_eynollah/eynollah-textline_20210425",
         filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
-        dist_url=dist_url(),
-        dists=['layout'],
+        dist_url=dist_url('extra'),
         type='Keras',
     ),
     EynollahModelSpec(
         category="textline",
-        variant='light',
+        variant='',
         #filename="models_eynollah/eynollah-textline_light_20210425",
         filename="models_eynollah/modelens_textline_0_1__2_4_16092024",
         dist_url=dist_url(),
-        dists=['layout'],
+        type='Keras',
+    ),
+    EynollahModelSpec(
+        category="table",
+        variant='non-light',
+        filename="models_eynollah/eynollah-tables_20210319",
+        dist_url=dist_url('extra'),
         type='Keras',
     ),
     EynollahModelSpec(
         category="table",
         variant='',
-        filename="models_eynollah/eynollah-tables_20210319",
-        dist_url=dist_url(),
-        dists=['layout'],
-        type='Keras',
-    ),
-    EynollahModelSpec(
-        category="table",
-        variant='light',
         filename="models_eynollah/modelens_table_0t4_201124",
         dist_url=dist_url(),
-        dists=['layout'],
         type='Keras',
     ),
@@ -214,7 +196,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         variant='',
         filename="models_eynollah/model_eynollah_ocr_cnnrnn_20250930",
         dist_url=dist_url("ocr"),
-        dists=['layout', 'ocr'],
         type='Keras',
     ),
@@ -224,7 +205,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         filename="models_eynollah/model_eynollah_ocr_cnnrnn__degraded_20250805/",
         help="slightly better at degraded Fraktur",
         dist_url=dist_url("ocr"),
-        dists=['ocr'],
         type='Keras',
     ),
@@ -233,7 +213,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         variant='',
         filename="characters_org.txt",
         dist_url=dist_url("ocr"),
-        dists=['ocr'],
         type='decoder',
     ),
@@ -242,7 +221,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         variant='',
         filename="characters_org.txt",
         dist_url=dist_url("ocr"),
-        dists=['ocr'],
         type='List[str]',
     ),
@@ -252,7 +230,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
         dist_url=dist_url("ocr"),
         help='much slower transformer-based',
-        dists=['trocr'],
         type='Keras',
     ),
@@ -261,7 +238,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         variant='',
         filename="models_eynollah/model_eynollah_ocr_trocr_20250919",
         dist_url=dist_url("ocr"),
-        dists=['trocr'],
         type='TrOCRProcessor',
     ),
@@ -270,7 +246,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
         variant='htr',
         filename="models_eynollah/microsoft/trocr-base-handwritten",
         dist_url=dist_url("extra"),
-        dists=['trocr'],
         type='TrOCRProcessor',
     ),

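Every spec loses its dists list; the single dist_url is now the only pointer to a distribution archive, and the light/non-light split collapses so that the former light models become the unnamed default variant. A slimmed-down sketch of the spec shape, assuming a plain dataclass (field names follow the diff; the URL is a placeholder):

    from dataclasses import dataclass

    @dataclass
    class ModelSpec:
        category: str
        filename: str
        dist_url: str   # sole pointer to the distribution containing this model
        type: str
        variant: str = ''
        help: str = ''

    spec = ModelSpec(category='table',
                     filename='models_eynollah/modelens_table_0t4_201124',
                     dist_url='https://zenodo.org/...', type='Keras')
    print(spec.variant or '(default)')  # -> (default)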

@@ -176,13 +176,12 @@ class EynollahModelZoo:
                     spec.category,
                     spec.variant,
                     spec.help,
-                    ', '.join(spec.dists),
                     f'Yes, at {self.model_path(spec.category, spec.variant)}'
                     if self.model_path(spec.category, spec.variant).exists()
                     else f'No, download {spec.dist_url}',
                     # self.model_path(spec.category, spec.variant),
                 ]
-                for spec in self.specs.specs
+                for spec in sorted(self.specs.specs, key=lambda x: x.dist_url)
             ],
             headers=[
                 'Type',

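Without the dists column, the overview table is instead grouped by sorting the specs on dist_url, so models from the same distribution archive end up adjacent. The idea in isolation, with made-up (category, dist_url) pairs:

    specs = [('reading_order', 'https://zenodo.org/b'), ('textline', 'https://zenodo.org/a')]
    # sorting on the URL groups rows from the same distribution together
    for category, dist_url in sorted(specs, key=lambda s: s[1]):
        print(f'{category:15} {dist_url}')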

@@ -10,8 +10,6 @@ class EynollahModelSpec():
     category: str
     # Relative filename to the models_eynollah directory in the dists
     filename: str
-    # basename of the ZIP files that should contain this model
-    dists: List[str]
     # URL to the smallest model distribution containing this model (link to Zenodo)
     dist_url: str
     type: str


@@ -29,16 +29,6 @@
         "type": "boolean",
         "default": true,
         "description": "Try to detect all element subtypes, including drop-caps and headings"
-      },
-      "light_version": {
-        "type": "boolean",
-        "default": true,
-        "description": "Try to detect all element subtypes in light version (faster+simpler method for main region detection and deskewing)"
-      },
-      "textline_light": {
-        "type": "boolean",
-        "default": true,
-        "description": "Light version need textline light. If this parameter set to true, this tool will try to return contoure of textlines instead of rectangle bounding box of textline with a faster method."
       },
       "tables": {
         "type": "boolean",

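With the two booleans gone from the tool description, OCR-D parameter files that still set them will presumably be rejected by parameter validation. A hedged migration sketch for existing configurations (migrate_params is hypothetical, not part of the project):

    def migrate_params(params: dict) -> dict:
        # Drop the retired options; everything they enabled is now the default.
        retired = ('light_version', 'textline_light')
        return {k: v for k, v in params.items() if k not in retired}

    print(migrate_params({'full_layout': True, 'light_version': True}))
    # -> {'full_layout': True}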

@@ -18,9 +18,6 @@ class EynollahProcessor(Processor):
     def setup(self) -> None:
         assert self.parameter
-        if self.parameter['textline_light'] != self.parameter['light_version']:
-            raise ValueError("Error: You must set or unset both parameter 'textline_light' (to enable light textline detection), "
-                             "and parameter 'light_version' (faster+simpler method for main region detection and deskewing)")
         model_zoo = EynollahModelZoo(basedir=self.parameter['models'])
         self.eynollah = Eynollah(
             model_zoo=model_zoo,
@@ -29,8 +26,6 @@ class EynollahProcessor(Processor):
             right2left=self.parameter['right_to_left'],
             reading_order_machine_based=self.parameter['reading_order_machine_based'],
             ignore_page_extraction=self.parameter['ignore_page_extraction'],
-            light_version=self.parameter['light_version'],
-            textline_light=self.parameter['textline_light'],
             full_layout=self.parameter['full_layout'],
             allow_scaling=self.parameter['allow_scaling'],
             headers_off=self.parameter['headers_off'],
@@ -93,7 +88,6 @@ class EynollahProcessor(Processor):
             dir_out=None,
             image_filename=image_filename,
             curved_line=self.eynollah.curved_line,
-            textline_light=self.eynollah.textline_light,
             pcgts=pcgts)
         self.eynollah.run_single()
         return result


@@ -19,7 +19,6 @@ def adhere_drop_capital_region_into_corresponding_textline(
         all_found_textline_polygons_h,
         kernel=None,
         curved_line=False,
-        textline_light=False,
 ):
     # print(np.shape(all_found_textline_polygons),np.shape(all_found_textline_polygons[3]),'all_found_textline_polygonsshape')
     # print(all_found_textline_polygons[3])
@@ -79,7 +78,7 @@ def adhere_drop_capital_region_into_corresponding_textline(
                     # region_with_intersected_drop=region_with_intersected_drop/3
                     region_with_intersected_drop = region_with_intersected_drop.astype(np.uint8)
                     # print(np.unique(img_con_all_copy[:,:,0]))
-                    if curved_line or textline_light:
+                    if curved_line:
                         if len(region_with_intersected_drop) > 1:
                             sum_pixels_of_intersection = []


@@ -6,7 +6,7 @@ from .contour import find_new_features_of_contours, return_contours_of_intereste
 from .resize import resize_image
 from .rotate import rotate_image

-def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_version=False, kernel=None):
+def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=None):
     mask_marginals=np.zeros((text_with_lines.shape[0],text_with_lines.shape[1]))
     mask_marginals=mask_marginals.astype(np.uint8)
@@ -27,9 +27,8 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
         text_with_lines=resize_image(text_with_lines,text_with_lines_eroded.shape[0],text_with_lines_eroded.shape[1])

-    if light_version:
-        kernel_hor = np.ones((1, 5), dtype=np.uint8)
-        text_with_lines = cv2.erode(text_with_lines,kernel_hor,iterations=6)
+    kernel_hor = np.ones((1, 5), dtype=np.uint8)
+    text_with_lines = cv2.erode(text_with_lines,kernel_hor,iterations=6)

     text_with_lines_y=text_with_lines.sum(axis=0)
     text_with_lines_y_eroded=text_with_lines_eroded.sum(axis=0)
@@ -43,10 +42,7 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
     elif thickness_along_y_percent>=30 and thickness_along_y_percent<50:
         min_textline_thickness=20
     else:
-        if light_version:
-            min_textline_thickness=45
-        else:
-            min_textline_thickness=40
+        min_textline_thickness=45

     if thickness_along_y_percent>=14:
@@ -128,92 +124,39 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_ve
         if max_point_of_right_marginal>=text_regions.shape[1]:
             max_point_of_right_marginal=text_regions.shape[1]-1

-        if light_version:
-            text_regions_org = np.copy(text_regions)
-            text_regions[text_regions[:,:]==1]=4
+        text_regions_org = np.copy(text_regions)
+        text_regions[text_regions[:,:]==1]=4

         pixel_img=4
         min_area_text=0.00001

         polygon_mask_marginals_rotated = return_contours_of_interested_region(mask_marginals,1,min_area_text)
         polygon_mask_marginals_rotated = polygon_mask_marginals_rotated[0]

         polygons_of_marginals=return_contours_of_interested_region(text_regions,pixel_img,min_area_text)

         cx_text_only,cy_text_only ,x_min_text_only,x_max_text_only, y_min_text_only ,y_max_text_only,y_cor_x_min_main=find_new_features_of_contours(polygons_of_marginals)

         text_regions[(text_regions[:,:]==4)]=1

         marginlas_should_be_main_text=[]

         x_min_marginals_left=[]
         x_min_marginals_right=[]

         for i in range(len(cx_text_only)):
             results = cv2.pointPolygonTest(polygon_mask_marginals_rotated, (cx_text_only[i], cy_text_only[i]), False)
             if results == -1:
                 marginlas_should_be_main_text.append(polygons_of_marginals[i])

         text_regions_org=cv2.fillPoly(text_regions_org, pts =marginlas_should_be_main_text, color=(4,4))
         text_regions = np.copy(text_regions_org)
-        else:
-            text_regions[(mask_marginals_rotated[:,:]!=1) & (text_regions[:,:]==1)]=4
-            pixel_img=4
-            min_area_text=0.00001
-            polygons_of_marginals=return_contours_of_interested_region(text_regions,pixel_img,min_area_text)
-            cx_text_only,cy_text_only ,x_min_text_only,x_max_text_only, y_min_text_only ,y_max_text_only,y_cor_x_min_main=find_new_features_of_contours(polygons_of_marginals)
-            text_regions[(text_regions[:,:]==4)]=1
-            marginlas_should_be_main_text=[]
-            x_min_marginals_left=[]
-            x_min_marginals_right=[]
-            for i in range(len(cx_text_only)):
-                x_width_mar=abs(x_min_text_only[i]-x_max_text_only[i])
-                y_height_mar=abs(y_min_text_only[i]-y_max_text_only[i])
-                if x_width_mar>16 and y_height_mar/x_width_mar<18:
-                    marginlas_should_be_main_text.append(polygons_of_marginals[i])
-                    if x_min_text_only[i]<(mid_point-one_third_left):
-                        x_min_marginals_left_new=x_min_text_only[i]
-                        if len(x_min_marginals_left)==0:
-                            x_min_marginals_left.append(x_min_marginals_left_new)
-                        else:
-                            x_min_marginals_left[0]=min(x_min_marginals_left[0],x_min_marginals_left_new)
-                    else:
-                        x_min_marginals_right_new=x_min_text_only[i]
-                        if len(x_min_marginals_right)==0:
-                            x_min_marginals_right.append(x_min_marginals_right_new)
-                        else:
-                            x_min_marginals_right[0]=min(x_min_marginals_right[0],x_min_marginals_right_new)
-            if len(x_min_marginals_left)==0:
-                x_min_marginals_left=[0]
-            if len(x_min_marginals_right)==0:
-                x_min_marginals_right=[text_regions.shape[1]-1]
-            text_regions=cv2.fillPoly(text_regions, pts =marginlas_should_be_main_text, color=(4,4))
-        #text_regions[:,:int(x_min_marginals_left[0])][text_regions[:,:int(x_min_marginals_left[0])]==1]=0
-        #text_regions[:,int(x_min_marginals_right[0]):][text_regions[:,int(x_min_marginals_right[0]):]==1]=0
-        text_regions[:,:int(min_point_of_left_marginal)][text_regions[:,:int(min_point_of_left_marginal)]==1]=0
-        text_regions[:,int(max_point_of_right_marginal):][text_regions[:,int(max_point_of_right_marginal):]==1]=0

         ###text_regions[:,0:point_left][text_regions[:,0:point_left]==1]=4

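The surviving marginal-detection path relabels candidate text regions and keeps only those whose centroid lies inside the rotated marginal mask; the rest are painted back as main text via cv2.fillPoly. The centroid test in isolation (coordinates are made up):

    import cv2
    import numpy as np

    # pointPolygonTest with measureDist=False returns +1 inside, 0 on the
    # edge, -1 outside, matching the `results == -1` check above.
    mask_polygon = np.array([[0, 0], [50, 0], [50, 50], [0, 50]], dtype=np.int32)
    for cx, cy in [(25, 25), (80, 10)]:
        inside = cv2.pointPolygonTest(mask_polygon, (float(cx), float(cy)), False)
        print((cx, cy), 'marginal' if inside >= 0 else 'back to main text')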

@@ -1748,7 +1748,7 @@ def do_work_of_slopes_new_curved(
 @wrap_ndarray_shared(kw='textline_mask_tot_ea')
 def do_work_of_slopes_new_light(
         box_text, contour, contour_par,
-        textline_mask_tot_ea=None, slope_deskew=0, textline_light=True,
+        textline_mask_tot_ea=None, slope_deskew=0,
         logger=None
 ):
     if logger is None:
@@ -1765,16 +1765,10 @@ def do_work_of_slopes_new_light(
     mask_only_con_region = np.zeros(textline_mask_tot_ea.shape)
     mask_only_con_region = cv2.fillPoly(mask_only_con_region, pts=[contour_par], color=(1, 1, 1))

-    if textline_light:
-        all_text_region_raw = np.copy(textline_mask_tot_ea)
-        all_text_region_raw[mask_only_con_region == 0] = 0
-        cnt_clean_rot_raw, hir_on_cnt_clean_rot = return_contours_of_image(all_text_region_raw)
-        cnt_clean_rot = filter_contours_area_of_image(all_text_region_raw, cnt_clean_rot_raw, hir_on_cnt_clean_rot,
-                                                      max_area=1, min_area=0.00001)
-    else:
-        all_text_region_raw = np.copy(textline_mask_tot_ea[y: y + h, x: x + w])
-        mask_only_con_region = mask_only_con_region[y: y + h, x: x + w]
-        all_text_region_raw[mask_only_con_region == 0] = 0
-        cnt_clean_rot = textline_contours_postprocessing(all_text_region_raw, slope_deskew, contour_par, box_text)
+    all_text_region_raw = np.copy(textline_mask_tot_ea)
+    all_text_region_raw[mask_only_con_region == 0] = 0
+    cnt_clean_rot_raw, hir_on_cnt_clean_rot = return_contours_of_image(all_text_region_raw)
+    cnt_clean_rot = filter_contours_area_of_image(all_text_region_raw, cnt_clean_rot_raw, hir_on_cnt_clean_rot,
+                                                  max_area=1, min_area=0.00001)

     return cnt_clean_rot, crop_coor, slope_deskew

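do_work_of_slopes_new_light keeps only the former textline_light body: mask the textline prediction to the region polygon, re-extract contours, and area-filter them. A self-contained sketch of the masking step (textlines_in_region is a hypothetical condensation; the real code uses return_contours_of_image and filter_contours_area_of_image):

    import cv2
    import numpy as np

    def textlines_in_region(textline_mask, region_polygon):
        # Zero everything outside the region, then re-extract contours.
        region_mask = np.zeros_like(textline_mask)
        cv2.fillPoly(region_mask, [region_polygon], 1)
        masked = np.where(region_mask == 0, 0, textline_mask).astype(np.uint8)
        contours, _ = cv2.findContours(masked, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return contours

    mask = np.zeros((60, 60), dtype=np.uint8)
    mask[10:20, 10:40] = 1  # one synthetic textline
    region = np.array([[5, 5], [55, 5], [55, 55], [5, 55]], dtype=np.int32)
    print(len(textlines_in_region(mask, region)))  # -> 1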

@@ -379,7 +379,6 @@ def return_rnn_cnn_ocr_of_given_textlines(image,
                                           all_box_coord,
                                           prediction_model,
                                           b_s_ocr, num_to_char,
-                                          textline_light=False,
                                           curved_line=False):
     max_len = 512
     padding_token = 299
@@ -404,7 +403,7 @@ def return_rnn_cnn_ocr_of_given_textlines(image,
             else:
                 for indexing2, ind_poly in enumerate(ind_poly_first):
                     cropped_lines_region_indexer.append(indexer_text_region)
-                    if not (textline_light or curved_line):
+                    if not curved_line:
                         ind_poly = copy.deepcopy(ind_poly)
                         box_ind = all_box_coord[indexing]

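For straight lines, the OCR helper shifts each polygon from textline-box-local coordinates into page coordinates before cropping; curved lines already carry page coordinates, hence the `if not curved_line` guard. A sketch of the shift, assuming boxes are laid out as (x, y, w, h) (to_page_coords is hypothetical):

    def to_page_coords(poly, box):
        # box assumed (x, y, w, h): offset box-local points by the box origin
        x, y = box[0], box[1]
        return [(px + x, py + y) for px, py in poly]

    print(to_page_coords([(0, 0), (10, 2)], (100, 200, 50, 20)))
    # -> [(100, 200), (110, 202)]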

@@ -23,14 +23,13 @@ import numpy as np

 class EynollahXmlWriter:

-    def __init__(self, *, dir_out, image_filename, curved_line,textline_light, pcgts=None):
+    def __init__(self, *, dir_out, image_filename, curved_line, pcgts=None):
         self.logger = logging.getLogger('eynollah.writer')
         self.counter = EynollahIdCounter()
         self.dir_out = dir_out
         self.image_filename = image_filename
         self.output_filename = os.path.join(self.dir_out or "", self.image_filename_stem) + ".xml"
         self.curved_line = curved_line
-        self.textline_light = textline_light
         self.pcgts = pcgts
         self.scale_x: Optional[float] = None  # XXX set outside __init__
         self.scale_y: Optional[float] = None  # XXX set outside __init__
@@ -73,8 +72,8 @@ class EynollahXmlWriter:
                 point = point[0]
             point_x = point[0] + page_coord[2]
             point_y = point[1] + page_coord[0]
-            # FIXME: or actually... not self.textline_light and not self.curved_line or np.abs(slopes[region_idx]) > 45?
-            if not self.textline_light and not (self.curved_line and np.abs(slopes[region_idx]) <= 45):
+            # FIXME: or actually... not self.curved_line or np.abs(slopes[region_idx]) > 45?
+            if not (self.curved_line and np.abs(slopes[region_idx]) <= 45):
                 point_x += region_bboxes[2]
                 point_y += region_bboxes[0]
             point_x = max(0, int(point_x / self.scale_x))

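Serialization in the writer now offsets by the region bounding box whenever the line is not a low-slope curved line, then maps into the output scale. The arithmetic of the final two steps, mirrored from the hunk above (serialize_point is a hypothetical wrapper; values are illustrative):

    def serialize_point(point, page_coord, scale_x, scale_y):
        # offset into page coordinates, then scale and clamp at zero
        point_x = point[0] + page_coord[2]
        point_y = point[1] + page_coord[0]
        return max(0, int(point_x / scale_x)), max(0, int(point_y / scale_y))

    print(serialize_point((10, 20), (5, 0, 7, 0), 2.0, 2.0))  # -> (8, 12)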

@@ -9,8 +9,6 @@ from ocrd_models.constants import NAMESPACES as NS
     #["--allow_scaling", "--curved-line"],
     ["--allow_scaling", "--curved-line", "--full-layout"],
     ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based"],
-    ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based",
-     "--textline_light", "--light_version"],
     # -ep ...
     # -eoi ...
     # --skip_layout_and_reading_order
@@ -47,7 +45,6 @@ def test_run_eynollah_layout_filename(
     [
         ["--tables"],
         ["--tables", "--full-layout"],
-        ["--tables", "--full-layout", "--textline_light", "--light_version"],
     ], ids=str)
 def test_run_eynollah_layout_filename2(
     tmp_path,
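The remaining CLI test matrices no longer mention the light flags. A minimal standalone illustration of the trimmed parametrization style (hypothetical test, same pytest idiom as the file above):

    import pytest

    @pytest.mark.parametrize('options', [
        ['--tables'],
        ['--tables', '--full-layout'],
    ], ids=str)
    def test_options_no_light_flags(options):
        assert '--textline_light' not in options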