mirror of https://github.com/qurator-spk/eynollah.git (synced 2025-07-12 12:29:56 +02:00)

Merge remote-tracking branch 'origin/main' into v3-api-release-foreal
(bad-ass difficult diff diffing)

This commit is contained in commit 108ce1f5a1.
11 changed files with 1633 additions and 515 deletions
@@ -324,6 +324,12 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
help="directory of images",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--dir_in_bin",
"-dib",
help="directory of binarized images. This should be given if you want to do prediction based on both rgb and bin images. And all bin images are png files",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--out",
"-o",

@@ -337,6 +343,12 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
help="directory of xmls",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--dir_out_image_text",
"-doit",
help="directory of images with predicted text",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--model",
"-m",

@@ -362,6 +374,18 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
is_flag=True,
help="if this parameter set to true, cropped textline images will not be masked with textline contour.",
)
@click.option(
"--draw_texts_on_image",
"-dtoi/-ndtoi",
is_flag=True,
help="if this parameter set to true, the predicted texts will be displayed on an image.",
)
@click.option(
"--prediction_with_both_of_rgb_and_bin",
"-brb/-nbrb",
is_flag=True,
help="If this parameter is set to True, the prediction will be performed using both RGB and binary images. However, this does not necessarily improve results; it may be beneficial for certain document images.",
)
@click.option(
"--log_level",
"-l",

@@ -369,18 +393,22 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
help="Override log level globally to this",
)

def ocr(dir_in, out, dir_xmls, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, log_level):
def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, log_level):
initLogging()
if log_level:
getLogger('eynollah').setLevel(getLevelName(log_level))
eynollah_ocr = Eynollah_ocr(
dir_xmls=dir_xmls,
dir_out_image_text=dir_out_image_text,
dir_in=dir_in,
dir_in_bin=dir_in_bin,
dir_out=out,
dir_models=model,
tr_ocr=tr_ocr,
export_textline_images_and_text=export_textline_images_and_text,
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
draw_texts_on_image=draw_texts_on_image,
prediction_with_both_of_rgb_and_bin=prediction_with_both_of_rgb_and_bin,
)
eynollah_ocr.run()
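The hunks above map each new CLI option onto a keyword argument of Eynollah_ocr. As an illustration, a minimal sketch of the equivalent direct call; the import path assumes this repo's package layout, and all directory paths are invented placeholders:

    # hypothetical direct use of the extended OCR entry point
    from eynollah.eynollah import Eynollah_ocr

    eynollah_ocr = Eynollah_ocr(
        dir_xmls="xmls/",                # layout PAGE-XML inputs
        dir_in="images/",                # RGB input images
        dir_in_bin="images_bin/",        # -dib: matching binarized images (png)
        dir_out="out/",
        dir_out_image_text="out_vis/",   # -doit: renderings of predicted text
        dir_models="models/",            # -m: model directory
        tr_ocr=False,
        export_textline_images_and_text=False,
        do_not_mask_with_textline_contour=False,
        draw_texts_on_image=True,                  # -dtoi
        prediction_with_both_of_rgb_and_bin=True,  # -brb
    )
    eynollah_ocr.run()

Per the help text, --dir_in_bin only matters when prediction based on both RGB and binarized images is requested.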
File diff suppressed because it is too large
@@ -8,7 +8,6 @@ except ImportError:
import numpy as np
from shapely import geometry
import cv2
import imutils
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d

@@ -28,7 +27,7 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
y_sep=[]
y_diff=[]
new_main_sep_y=[]

indexer=0
for i in range(len(x_min_hor_some)):
starting=x_min_hor_some[i]-peak_points

@@ -36,34 +35,34 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
min_start=np.argmin(starting)
ending=peak_points-x_max_hor_some[i]
len_ending_neg=len(ending[ending<=0])

ending=ending[ending>0]
max_end=np.argmin(ending)+len_ending_neg

if (max_end-min_start)>=2:
if (max_end-min_start)==(len(peak_points)-1):
new_main_sep_y.append(indexer)

#print((max_end-min_start),len(peak_points),'(max_end-min_start)')
y_sep.append(cy_hor_some[i])
y_diff.append(cy_hor_diff[i])
x_end.append(max_end)

x_start.append( min_start)

len_sep.append(max_end-min_start)
if max_end==min_start+1:
kind.append(0)
else:
kind.append(1)

indexer+=1

x_start_returned = np.array(x_start, dtype=int)
x_end_returned = np.array(x_end, dtype=int)
y_sep_returned = np.array(y_sep, dtype=int)
y_diff_returned = np.array(y_diff, dtype=int)

all_args_uniq = contours_in_same_horizon(y_sep_returned)
args_to_be_unified=[]
y_unified=[]
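Here peak_points appears to hold the x-positions of the column boundaries, and the loop maps each horizontal separator onto the first and last boundary it touches. A minimal standalone sketch of that span computation, with invented coordinates (the left-edge filter is not visible in the hunk and is added here for symmetry):

    import numpy as np

    # hypothetical column boundaries (4 columns -> 5 boundary x-positions)
    peak_points = np.array([0, 500, 1000, 1500, 2000])
    x_min, x_max = 520, 1530   # extent of one horizontal separator

    starting = x_min - peak_points          # distance from each boundary to the left edge
    starting = starting[starting >= 0]      # keep boundaries left of the separator
    min_start = np.argmin(starting)         # nearest boundary on the left

    ending = peak_points - x_max            # distance from the right edge to each boundary
    len_ending_neg = len(ending[ending <= 0])
    ending = ending[ending > 0]             # keep boundaries right of the separator
    max_end = np.argmin(ending) + len_ending_neg  # nearest boundary on the right

    print(min_start, max_end)  # 1 4: the separator spans columns 1..3

With (max_end - min_start) >= 2 the separator crosses more than one column; if it equals len(peak_points) - 1 it spans the whole page and is recorded in new_main_sep_y.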
@@ -92,7 +91,7 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
y_diff_selected=np.max(y_diff_same_hor)
x_s_selected=np.min(x_s_same_hor)
x_e_selected=np.max(x_e_same_hor)

x_s_unified.append(x_s_selected)
x_e_unified.append(x_e_selected)
y_unified.append(y_selected)

@@ -106,56 +105,56 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
args_lines_not_unified=list( set(range(len(y_sep_returned)))-set(args_to_be_unified) )
#print(args_lines_not_unified,'args_lines_not_unified')

x_start_returned_not_unified=list( np.array(x_start_returned)[args_lines_not_unified] )
x_end_returned_not_unified=list( np.array(x_end_returned)[args_lines_not_unified] )
y_sep_returned_not_unified=list (np.array(y_sep_returned)[args_lines_not_unified] )
y_diff_returned_not_unified=list (np.array(y_diff_returned)[args_lines_not_unified] )

for dv in range(len(y_unified)):
y_sep_returned_not_unified.append(y_unified[dv])
y_diff_returned_not_unified.append(y_diff_unified[dv])
x_start_returned_not_unified.append(x_s_unified[dv])
x_end_returned_not_unified.append(x_e_unified[dv])

#print(y_sep_returned,'y_sep_returned')
#print(x_start_returned,'x_start_returned')
#print(x_end_returned,'x_end_returned')

x_start_returned = np.array(x_start_returned_not_unified, dtype=int)
x_end_returned = np.array(x_end_returned_not_unified, dtype=int)
y_sep_returned = np.array(y_sep_returned_not_unified, dtype=int)
y_diff_returned = np.array(y_diff_returned_not_unified, dtype=int)

#print(y_sep_returned,'y_sep_returned2')
#print(x_start_returned,'x_start_returned2')
#print(x_end_returned,'x_end_returned2')
#print(new_main_sep_y,'new_main_sep_y')

#print(x_start,'x_start')
#print(x_end,'x_end')
if len(new_main_sep_y)>0:

min_ys=np.min(y_sep)
max_ys=np.max(y_sep)

y_mains=[]
y_mains.append(min_ys)
y_mains_sep_ohne_grenzen=[]

for ii in range(len(new_main_sep_y)):
y_mains.append(y_sep[new_main_sep_y[ii]])
y_mains_sep_ohne_grenzen.append(y_sep[new_main_sep_y[ii]])

y_mains.append(max_ys)

y_mains_sorted=np.sort(y_mains)
diff=np.diff(y_mains_sorted)
argm=np.argmax(diff)

y_min_new=y_mains_sorted[argm]
y_max_new=y_mains_sorted[argm+1]

#print(y_min_new,'y_min_new')
#print(y_max_new,'y_max_new')
#print(y_sep[new_main_sep_y[0]],y_sep,'yseps')

@@ -192,7 +191,7 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
#print(x_start,'x_start')
#print(x_end,'x_end')
#print(len_sep)

deleted=[]
for i in range(len(x_start)-1):
nodes_i=set(range(x_start[i],x_end[i]+1))

@@ -200,10 +199,10 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
if nodes_i==set(range(x_start[j],x_end[j]+1)):
deleted.append(j)
#print(np.unique(deleted))

remained_sep_indexes=set(range(len(x_start)))-set(np.unique(deleted) )
#print(remained_sep_indexes,'remained_sep_indexes')
mother=[]#if it has mother
mother=[]#if it has mother
child=[]
for index_i in remained_sep_indexes:
have_mother=0

@@ -217,14 +216,14 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
have_child=1
mother.append(have_mother)
child.append(have_child)

#print(mother,'mother')
#print(len(remained_sep_indexes))
#print(len(remained_sep_indexes),len(x_start),len(x_end),len(y_sep),'lens')
y_lines_without_mother=[]
x_start_without_mother=[]
x_end_without_mother=[]

y_lines_with_child_without_mother=[]
x_start_with_child_without_mother=[]
x_end_with_child_without_mother=[]

@@ -237,7 +236,7 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
x_start = np.array(x_start)
x_end = np.array(x_end)
y_sep = np.array(y_sep)

if len(remained_sep_indexes)>1:
#print(np.array(remained_sep_indexes),'np.array(remained_sep_indexes)')
#print(np.array(mother),'mother')
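The mother/child bookkeeping above classifies separators by column-interval containment: a separator "has a mother" if its column interval sits inside a wider separator's interval, and "has a child" if it covers a narrower one. A condensed sketch of the containment idea with invented spans (the diff's loop computes the same flags with explicit counters):

    import numpy as np

    # hypothetical separator spans, as (start_column, end_column) boundary indexes
    x_start = np.array([0, 1, 0])
    x_end   = np.array([4, 3, 2])

    mother, child = [], []
    for i in range(len(x_start)):
        nodes_i = set(range(x_start[i], x_end[i] + 1))
        others = [set(range(x_start[j], x_end[j] + 1))
                  for j in range(len(x_start)) if j != i]
        # "mother": some wider separator fully covers this one
        mother.append(int(any(nodes_i < o for o in others)))
        # "child": this separator fully covers some narrower one
        child.append(int(any(o < nodes_i for o in others)))

    print(mother)  # [0, 1, 1]
    print(child)   # [1, 0, 0]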
@@ -245,7 +244,7 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
remained_sep_indexes_with_child_without_mother = remained_sep_indexes[(mother==0) & (child==1)]
#print(remained_sep_indexes_without_mother,'remained_sep_indexes_without_mother')
#print(remained_sep_indexes_without_mother,'remained_sep_indexes_without_mother')

x_end_with_child_without_mother = x_end[remained_sep_indexes_with_child_without_mother]
x_start_with_child_without_mother = x_start[remained_sep_indexes_with_child_without_mother]
y_lines_with_child_without_mother = y_sep[remained_sep_indexes_with_child_without_mother]

@@ -254,7 +253,7 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
x_end_without_mother = x_end[remained_sep_indexes_without_mother]
x_start_without_mother = x_start[remained_sep_indexes_without_mother]
y_lines_without_mother = y_sep[remained_sep_indexes_without_mother]

if len(remained_sep_indexes_without_mother)>=2:
for i in range(len(remained_sep_indexes_without_mother)-1):
nodes_i=set(range(x_start[remained_sep_indexes_without_mother[i]],

@@ -275,16 +274,16 @@ def return_x_start_end_mothers_childs_and_type_of_reading_order(
#print(y_lines_with_child_without_mother,'y_lines_with_child_without_mother')
#print(x_start_with_child_without_mother,'x_start_with_child_without_mother')
#print(x_end_with_child_without_mother,'x_end_with_hild_without_mother')

len_sep_with_child = len(child[child==1])

#print(len_sep_with_child,'len_sep_with_child')
there_is_sep_with_child = 0
if len_sep_with_child >= 1:
there_is_sep_with_child = 1
#print(all_args_uniq,'all_args_uniq')
#print(args_to_be_unified,'args_to_be_unified')

return (reading_orther_type,
x_start_returned,
x_end_returned,

@@ -433,13 +432,13 @@ def find_num_col(regions_without_separators, num_col_classifier, tables, multipl
interest_neg_fin = interest_neg[(interest_neg < grenze)]
peaks_neg_fin = peaks_neg[(interest_neg < grenze)]
# interest_neg_fin=interest_neg[(interest_neg<grenze)]

if not tables:
if ( num_col_classifier - ( (len(interest_neg_fin))+1 ) ) >= 3:
index_sort_interest_neg_fin= np.argsort(interest_neg_fin)
peaks_neg_sorted = np.array(peaks_neg)[index_sort_interest_neg_fin]
interest_neg_fin_sorted = np.array(interest_neg_fin)[index_sort_interest_neg_fin]

if len(index_sort_interest_neg_fin)>=num_col_classifier:
peaks_neg_fin = list( peaks_neg_sorted[:num_col_classifier] )
interest_neg_fin = list( interest_neg_fin_sorted[:num_col_classifier] )
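The find_num_col fallback above kicks in when far fewer gaps are found than the column classifier expects: candidate gaps are sorted by their valley score (interest_neg_fin, smaller apparently meaning a stronger gap) and only the num_col_classifier best positions are kept. A toy example with invented values:

    import numpy as np

    # hypothetical valley scores at candidate column gaps and their x-positions
    interest_neg_fin = np.array([0.9, 0.1, 0.5, 0.3])
    peaks_neg = np.array([400, 800, 1200, 1600])
    num_col_classifier = 3   # classifier expects 3 columns -> keep 3 best gaps

    order = np.argsort(interest_neg_fin)     # strongest gaps first
    peaks_neg_sorted = peaks_neg[order]
    if len(order) >= num_col_classifier:
        peaks_neg_fin = list(peaks_neg_sorted[:num_col_classifier])

    print(peaks_neg_fin)  # [800, 1600, 1200]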
@@ -846,11 +845,11 @@ def putt_bb_of_drop_capitals_of_model_in_patches_in_layout(layout_in_patch, drop
box0 = box + (0,)
mask_of_drop_cpaital_in_early_layout = np.zeros((text_regions_p.shape[0], text_regions_p.shape[1]))
mask_of_drop_cpaital_in_early_layout[box] = text_regions_p[box]

all_drop_capital_pixels_which_is_text_in_early_lo = np.sum(mask_of_drop_cpaital_in_early_layout[box]==1)
mask_of_drop_cpaital_in_early_layout[box] = 1
all_drop_capital_pixels = np.sum(mask_of_drop_cpaital_in_early_layout==1)

percent_text_to_all_in_drop = all_drop_capital_pixels_which_is_text_in_early_lo / float(all_drop_capital_pixels)
if (areas_cnt_text[jj] * float(drop_only.shape[0] * drop_only.shape[1]) / float(w * h) > 0.6 and
percent_text_to_all_in_drop >= 0.3):
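The test above measures how much of a drop-capital candidate box was already labelled as text in the early layout (percent_text_to_all_in_drop) and only accepts the candidate if at least 30% was. A standalone sketch with an invented layout array:

    import numpy as np

    # hypothetical early layout labels (1 = text) and a drop-capital bounding box
    text_regions_p = np.zeros((100, 100), dtype=int)
    text_regions_p[10:30, 10:20] = 1
    box = (slice(10, 40), slice(10, 30))   # candidate drop-capital region

    text_pixels_in_box = np.sum(text_regions_p[box] == 1)
    all_pixels_in_box = text_regions_p[box].size
    percent_text_to_all_in_drop = text_pixels_in_box / float(all_pixels_in_box)

    # accept the drop capital only if enough of its box was already text
    print(percent_text_to_all_in_drop >= 0.3)  # True here: 200/600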
@@ -868,7 +867,7 @@ def check_any_text_region_in_model_one_is_main_or_header(
contours_only_text_parent,
all_box_coord, all_found_textline_polygons,
slopes,
contours_only_text_parent_d_ordered):
contours_only_text_parent_d_ordered, conf_contours):

cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin = \
find_new_features_of_contours(contours_only_text_parent)

@@ -888,6 +887,9 @@ def check_any_text_region_in_model_one_is_main_or_header(
contours_only_text_parent_main=[]
contours_only_text_parent_head=[]

conf_contours_main=[]
conf_contours_head=[]

contours_only_text_parent_main_d=[]
contours_only_text_parent_head_d=[]

@@ -908,9 +910,11 @@ def check_any_text_region_in_model_one_is_main_or_header(
all_box_coord_head.append(all_box_coord[ii])
slopes_head.append(slopes[ii])
all_found_textline_polygons_head.append(all_found_textline_polygons[ii])
conf_contours_head.append(None)
else:
regions_model_1[:,:][(regions_model_1[:,:]==1) & (img[:,:,0]==255) ]=1
contours_only_text_parent_main.append(con)
conf_contours_main.append(conf_contours[ii])
if contours_only_text_parent_d_ordered is not None:
contours_only_text_parent_main_d.append(contours_only_text_parent_d_ordered[ii])
all_box_coord_main.append(all_box_coord[ii])

@@ -929,14 +933,17 @@ def check_any_text_region_in_model_one_is_main_or_header(
slopes_main,
slopes_head,
contours_only_text_parent_main_d,
contours_only_text_parent_head_d)
contours_only_text_parent_head_d,
conf_contours_main,
conf_contours_head)

def check_any_text_region_in_model_one_is_main_or_header_light(
regions_model_1, regions_model_full,
contours_only_text_parent,
all_box_coord, all_found_textline_polygons,
slopes,
contours_only_text_parent_d_ordered):
contours_only_text_parent_d_ordered,
conf_contours):

### to make it faster
h_o = regions_model_1.shape[0]

@@ -969,6 +976,9 @@ def check_any_text_region_in_model_one_is_main_or_header_light(
contours_only_text_parent_main=[]
contours_only_text_parent_head=[]

conf_contours_main=[]
conf_contours_head=[]

contours_only_text_parent_main_d=[]
contours_only_text_parent_head_d=[]

@@ -990,9 +1000,11 @@ def check_any_text_region_in_model_one_is_main_or_header_light(
all_box_coord_head.append(all_box_coord[ii])
slopes_head.append(slopes[ii])
all_found_textline_polygons_head.append(all_found_textline_polygons[ii])
conf_contours_head.append(None)
else:
regions_model_1[:,:][(regions_model_1[:,:]==1) & (img[:,:,0]==255) ]=1
contours_only_text_parent_main.append(con)
conf_contours_main.append(conf_contours[ii])
if contours_only_text_parent_d_ordered is not None:
contours_only_text_parent_main_d.append(contours_only_text_parent_d_ordered[ii])
all_box_coord_main.append(all_box_coord[ii])

@@ -1009,7 +1021,7 @@ def check_any_text_region_in_model_one_is_main_or_header_light(
contours_only_text_parent_head = [(i * zoom).astype(int) for i in contours_only_text_parent_head]
contours_only_text_parent_main = [(i * zoom).astype(int) for i in contours_only_text_parent_main]
###

return (regions_model_1,
contours_only_text_parent_main,
contours_only_text_parent_head,

@@ -1020,7 +1032,7 @@ def check_any_text_region_in_model_one_is_main_or_header_light(
slopes_main,
slopes_head,
contours_only_text_parent_main_d,
contours_only_text_parent_head_d)
contours_only_text_parent_head_d,
conf_contours_main,
conf_contours_head)

def small_textlines_to_parent_adherence2(textlines_con, textline_iamge, num_col):
# print(textlines_con)
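The recurring pattern in these hunks is the point of the commit: each text region's confidence travels with it when regions are partitioned into main text and headers, and header regions get None. Condensed, with invented data:

    # minimal sketch of the new conf_contours bookkeeping
    regions = ["r0", "r1", "r2"]
    is_header = [False, True, False]
    conf_contours = [0.91, 0.55, 0.87]

    conf_contours_main, conf_contours_head = [], []
    regions_main, regions_head = [], []
    for ii, region in enumerate(regions):
        if is_header[ii]:
            regions_head.append(region)
            conf_contours_head.append(None)   # no confidence reported for headers
        else:
            regions_main.append(region)
            conf_contours_main.append(conf_contours[ii])

    print(regions_main, conf_contours_main)  # ['r0', 'r2'] [0.91, 0.87]
    print(regions_head, conf_contours_head)  # ['r1'] [None]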
@@ -1317,11 +1331,11 @@ def combine_hor_lines_and_delete_cross_points_and_get_lines_features_back_new(
imgray = cv2.cvtColor(img_in_hor, cv2.COLOR_BGR2GRAY)
ret, thresh = cv2.threshold(imgray, 0, 255, 0)
contours_lines_hor,hierarchy=cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)

slope_lines_hor, dist_x_hor, x_min_main_hor, x_max_main_hor, cy_main_hor, _, _, _, _ = \
find_features_of_lines(contours_lines_hor)
x_width_smaller_than_acolumn_width=img_in_hor.shape[1]/float(num_col_classifier+1.)

len_lines_bigger_than_x_width_smaller_than_acolumn_width=len( dist_x_hor[dist_x_hor>=x_width_smaller_than_acolumn_width] )
len_lines_bigger_than_x_width_smaller_than_acolumn_width_per_column=int(len_lines_bigger_than_x_width_smaller_than_acolumn_width /
float(num_col_classifier))

@@ -1339,7 +1353,7 @@ def combine_hor_lines_and_delete_cross_points_and_get_lines_features_back_new(
some_cy=cy_main_hor[all_args_uniq[dd]]
some_x_min=x_min_main_hor[all_args_uniq[dd]]
some_x_max=x_max_main_hor[all_args_uniq[dd]]

#img_in=np.zeros(separators_closeup_n[:,:,2].shape)
#print(img_p_in_ver.shape[1],some_x_max-some_x_min,'xdiff')
diff_x_some=some_x_max-some_x_min

@@ -1352,7 +1366,7 @@ def combine_hor_lines_and_delete_cross_points_and_get_lines_features_back_new(
int(np.max(some_x_max)) ]=1
sum_dis=dist_x_hor[some_args].sum()
diff_max_min_uniques=np.max(x_max_main_hor[some_args])-np.min(x_min_main_hor[some_args])

if (diff_max_min_uniques > sum_dis and
sum_dis / float(diff_max_min_uniques) > 0.85 and
diff_max_min_uniques / float(img_p_in_ver.shape[1]) > 0.85 and

@@ -1371,7 +1385,7 @@ def combine_hor_lines_and_delete_cross_points_and_get_lines_features_back_new(
else:
img_p_in=img_in_hor
special_separators=[]

img_p_in_ver[:,:,0][img_p_in_ver[:,:,0]==255]=1
sep_ver_hor=img_p_in+img_p_in_ver
sep_ver_hor_cross=(sep_ver_hor[:,:,0]==2)*1

@@ -1402,7 +1416,7 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
separators_closeup=( (region_pre_p[:,:,:]==pixel_lines))*1
separators_closeup[0:110,:,:]=0
separators_closeup[separators_closeup.shape[0]-150:,:,:]=0

kernel = np.ones((5,5),np.uint8)
separators_closeup=separators_closeup.astype(np.uint8)
separators_closeup = cv2.dilate(separators_closeup,kernel,iterations = 1)

@@ -1420,7 +1434,7 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
gray_early=gray_early.astype(np.uint8)
imgray_e = cv2.cvtColor(gray_early, cv2.COLOR_BGR2GRAY)
ret_e, thresh_e = cv2.threshold(imgray_e, 0, 255, 0)

contours_line_e,hierarchy_e=cv2.findContours(thresh_e,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
_, dist_xe, _, _, _, _, y_min_main, y_max_main, _ = \
find_features_of_lines(contours_line_e)

@@ -1433,11 +1447,11 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
cnts_hor_e.append(contours_line_e[ce])
figs_e=np.zeros(thresh_e.shape)
figs_e=cv2.fillPoly(figs_e,pts=cnts_hor_e,color=(1,1,1))

separators_closeup_n_binary=cv2.fillPoly(separators_closeup_n_binary, pts=cnts_hor_e, color=(0,0,0))
gray = cv2.bitwise_not(separators_closeup_n_binary)
gray=gray.astype(np.uint8)

bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, \
cv2.THRESH_BINARY, 15, -2)
horizontal = np.copy(bw)

@@ -1455,7 +1469,7 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
horizontal = cv2.dilate(horizontal,kernel,iterations = 2)
horizontal = cv2.erode(horizontal,kernel,iterations = 2)
horizontal = cv2.fillPoly(horizontal, pts=cnts_hor_e, color=(255,255,255))

rows = vertical.shape[0]
verticalsize = rows // 30
# Create structure element for extracting vertical lines through morphology operations

@@ -1468,16 +1482,16 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
horizontal, special_separators = \
combine_hor_lines_and_delete_cross_points_and_get_lines_features_back_new(
vertical, horizontal, num_col_classifier)

separators_closeup_new[:,:][vertical[:,:]!=0]=1
separators_closeup_new[:,:][horizontal[:,:]!=0]=1

vertical=np.repeat(vertical[:, :, np.newaxis], 3, axis=2)
vertical=vertical.astype(np.uint8)

imgray = cv2.cvtColor(vertical, cv2.COLOR_BGR2GRAY)
ret, thresh = cv2.threshold(imgray, 0, 255, 0)

contours_line_vers,hierarchy=cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
slope_lines, dist_x, x_min_main, x_max_main, cy_main, slope_lines_org, y_min_main, y_max_main, cx_main = \
find_features_of_lines(contours_line_vers)
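The surrounding code follows the classic OpenCV recipe for extracting ruling lines: adaptive thresholding, then erosion/dilation with long, thin structuring elements (note verticalsize = rows // 30 above). A self-contained sketch on a synthetic image; the real code operates on the model's separator predictions instead:

    import cv2
    import numpy as np

    img = np.full((300, 300), 255, np.uint8)
    cv2.line(img, (20, 150), (280, 150), 0, 3)   # a horizontal rule
    cv2.line(img, (150, 20), (150, 280), 0, 3)   # a vertical rule

    gray = cv2.bitwise_not(img)
    bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                               cv2.THRESH_BINARY, 15, -2)

    # keep only long horizontal runs ...
    horizontalsize = bw.shape[1] // 30
    h_struct = cv2.getStructuringElement(cv2.MORPH_RECT, (horizontalsize, 1))
    horizontal = cv2.erode(bw, h_struct)
    horizontal = cv2.dilate(horizontal, h_struct)

    # ... and long vertical runs
    verticalsize = bw.shape[0] // 30
    v_struct = cv2.getStructuringElement(cv2.MORPH_RECT, (1, verticalsize))
    vertical = cv2.erode(bw, v_struct)
    vertical = cv2.dilate(vertical, v_struct)

    print(horizontal.max(), vertical.max())  # 255 255: both rules recovered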
@@ -1492,7 +1506,7 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
cx_main_ver=cx_main[slope_lines==1]
dist_y_ver=y_max_main_ver-y_min_main_ver
len_y=separators_closeup.shape[0]/3.0

horizontal=np.repeat(horizontal[:, :, np.newaxis], 3, axis=2)
horizontal=horizontal.astype(np.uint8)
imgray = cv2.cvtColor(horizontal, cv2.COLOR_BGR2GRAY)

@@ -1500,12 +1514,12 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
contours_line_hors,hierarchy=cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
slope_lines, dist_x, x_min_main, x_max_main, cy_main, slope_lines_org, y_min_main, y_max_main, cx_main = \
find_features_of_lines(contours_line_hors)

slope_lines_org_hor=slope_lines_org[slope_lines==0]
args=np.arange(len(slope_lines))
len_x=separators_closeup.shape[1]/5.0
dist_y=np.abs(y_max_main-y_min_main)

args_hor=args[slope_lines==0]
dist_x_hor=dist_x[slope_lines==0]
y_min_main_hor=y_min_main[slope_lines==0]

@@ -1524,7 +1538,7 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
dist_y_hor=dist_y_hor[dist_x_hor>=len_x/2.0]
slope_lines_org_hor=slope_lines_org_hor[dist_x_hor>=len_x/2.0]
dist_x_hor=dist_x_hor[dist_x_hor>=len_x/2.0]

matrix_of_lines_ch=np.zeros((len(cy_main_hor)+len(cx_main_ver),10))
matrix_of_lines_ch[:len(cy_main_hor),0]=args_hor
matrix_of_lines_ch[len(cy_main_hor):,0]=args_ver

@@ -1543,14 +1557,14 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
matrix_of_lines_ch[:len(cy_main_hor),8]=dist_y_hor
matrix_of_lines_ch[len(cy_main_hor):,8]=dist_y_ver
matrix_of_lines_ch[len(cy_main_hor):,9]=1

if contours_h is not None:
_, dist_x_head, x_min_main_head, x_max_main_head, cy_main_head, _, y_min_main_head, y_max_main_head, _ = \
find_features_of_lines(contours_h)
matrix_l_n=np.zeros((matrix_of_lines_ch.shape[0]+len(cy_main_head),matrix_of_lines_ch.shape[1]))
matrix_l_n[:matrix_of_lines_ch.shape[0],:]=np.copy(matrix_of_lines_ch[:,:])
args_head=np.arange(len(cy_main_head)) + len(cy_main_hor)

matrix_l_n[matrix_of_lines_ch.shape[0]:,0]=args_head
matrix_l_n[matrix_of_lines_ch.shape[0]:,2]=x_min_main_head+30
matrix_l_n[matrix_of_lines_ch.shape[0]:,3]=x_max_main_head-30

@@ -1560,7 +1574,7 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
matrix_l_n[matrix_of_lines_ch.shape[0]:,7]=y_max_main_head#y_min_main_head+1-8
matrix_l_n[matrix_of_lines_ch.shape[0]:,8]=4
matrix_of_lines_ch=np.copy(matrix_l_n)

cy_main_splitters=cy_main_hor[(x_min_main_hor<=.16*region_pre_p.shape[1]) &
(x_max_main_hor>=.84*region_pre_p.shape[1])]
cy_main_splitters=np.array( list(cy_main_splitters)+list(special_separators))

@@ -1573,19 +1587,19 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
pass
args_cy_splitter=np.argsort(cy_main_splitters)
cy_main_splitters_sort=cy_main_splitters[args_cy_splitter]

splitter_y_new=[]
splitter_y_new.append(0)
for i in range(len(cy_main_splitters_sort)):
splitter_y_new.append( cy_main_splitters_sort[i] )
splitter_y_new.append( cy_main_splitters_sort[i] )
splitter_y_new.append(region_pre_p.shape[0])
splitter_y_new_diff=np.diff(splitter_y_new)/float(region_pre_p.shape[0])*100

args_big_parts=np.arange(len(splitter_y_new_diff))[ splitter_y_new_diff>22 ]

regions_without_separators=return_regions_without_separators(region_pre_p)
length_y_threshold=regions_without_separators.shape[0]/4.0

num_col_fin=0
peaks_neg_fin_fin=[]
for itiles in args_big_parts:
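Column detection then runs per horizontal band: page-wide splitters are sorted by y, consecutive splitters bound the bands, and only bands taller than 22% of the page height are analyzed. With invented numbers:

    import numpy as np

    # hypothetical y-centers of page-wide separators on a 1000 px tall page
    cy_main_splitters_sort = np.array([180.0, 620.0])
    page_height = 1000

    splitter_y_new = [0]
    for cy in cy_main_splitters_sort:
        splitter_y_new.append(cy)
    splitter_y_new.append(page_height)

    # relative height of each band between consecutive splitters, in percent
    splitter_y_new_diff = np.diff(splitter_y_new) / float(page_height) * 100

    # only bands taller than 22% of the page are analyzed for columns
    args_big_parts = np.arange(len(splitter_y_new_diff))[splitter_y_new_diff > 22]
    print(splitter_y_new_diff)  # [18. 44. 38.]
    print(args_big_parts)       # [1 2]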
@@ -1600,15 +1614,15 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
if num_col>num_col_fin:
num_col_fin=num_col
peaks_neg_fin_fin=peaks_neg_fin

if len(args_big_parts)==1 and (len(peaks_neg_fin_fin)+1)<num_col_classifier:
peaks_neg_fin=find_num_col_by_vertical_lines(vertical)
peaks_neg_fin=peaks_neg_fin[peaks_neg_fin>=500]
peaks_neg_fin=peaks_neg_fin[peaks_neg_fin<=(vertical.shape[1]-500)]
peaks_neg_fin_fin=peaks_neg_fin[:]

return num_col_fin, peaks_neg_fin_fin,matrix_of_lines_ch,splitter_y_new,separators_closeup_n

def return_boxes_of_images_by_order_of_reading_new(
splitter_y_new, regions_without_separators,
matrix_of_lines_ch,

@@ -1655,7 +1669,7 @@ def return_boxes_of_images_by_order_of_reading_new(
for p_n in peaks_neg_fin:
peaks_neg_fin_early.append(p_n)
peaks_neg_fin_early.append(regions_without_separators.shape[1]-1)

#print(peaks_neg_fin_early,'burda2')
peaks_neg_fin_rev=[]
for i_n in range(len(peaks_neg_fin_early)-1):

@@ -1679,25 +1693,25 @@ def return_boxes_of_images_by_order_of_reading_new(
num_col_classifier,tables, multiplier=5.)
except:
peaks_neg_fin2=[]

if len(peaks_neg_fin1)>=len(peaks_neg_fin2):
peaks_neg_fin=list(np.copy(peaks_neg_fin1))
else:
peaks_neg_fin=list(np.copy(peaks_neg_fin2))
peaks_neg_fin=list(np.array(peaks_neg_fin)+peaks_neg_fin_early[i_n])

if i_n!=(len(peaks_neg_fin_early)-2):
peaks_neg_fin_rev.append(peaks_neg_fin_early[i_n+1])
#print(peaks_neg_fin,'peaks_neg_fin')
peaks_neg_fin_rev=peaks_neg_fin_rev+peaks_neg_fin

if len(peaks_neg_fin_rev)>=len(peaks_neg_fin_org):
if len(peaks_neg_fin_rev)>=len(peaks_neg_fin_org):
peaks_neg_fin=list(np.sort(peaks_neg_fin_rev))
num_col=len(peaks_neg_fin)
else:
peaks_neg_fin=list(np.copy(peaks_neg_fin_org))
num_col=len(peaks_neg_fin)

#print(peaks_neg_fin,'peaks_neg_fin')
except:
pass
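Within each vertical slice of a band, two candidate peak sets are computed and the richer one wins; slice-local x-positions are then shifted back by the slice's left edge (peaks_neg_fin_early[i_n] in the hunk, called slice_x_offset in this sketch with invented values):

    import numpy as np

    # hypothetical column gaps found inside one slice, with two detector settings
    peaks_neg_fin1 = [210, 430]
    peaks_neg_fin2 = [215]
    slice_x_offset = 500   # left edge of the slice on the full page

    if len(peaks_neg_fin1) >= len(peaks_neg_fin2):
        peaks_neg_fin = list(np.copy(peaks_neg_fin1))
    else:
        peaks_neg_fin = list(np.copy(peaks_neg_fin2))

    # shift slice-local x-positions back into full-page coordinates
    peaks_neg_fin = list(np.array(peaks_neg_fin) + slice_x_offset)
    print([int(p) for p in peaks_neg_fin])  # [710, 930]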
@@ -1709,16 +1723,16 @@ def return_boxes_of_images_by_order_of_reading_new(
cy_hor_some=matrix_new[:,5][ (matrix_new[:,9]==0) ]
cy_hor_diff=matrix_new[:,7][ (matrix_new[:,9]==0) ]
arg_org_hor_some=matrix_new[:,0][ (matrix_new[:,9]==0) ]

if right2left_readingorder:
x_max_hor_some_new = regions_without_separators.shape[1] - x_min_hor_some
x_min_hor_some_new = regions_without_separators.shape[1] - x_max_hor_some
x_min_hor_some =list(np.copy(x_min_hor_some_new))
x_max_hor_some =list(np.copy(x_max_hor_some_new))

peaks_neg_tot=return_points_with_boundies(peaks_neg_fin,0, regions_without_separators[:,:].shape[1])
peaks_neg_tot_tables.append(peaks_neg_tot)

reading_order_type, x_starting, x_ending, y_type_2, y_diff_type_2, \
y_lines_without_mother, x_start_without_mother, x_end_without_mother, there_is_sep_with_child, \
y_lines_with_child_without_mother, x_start_with_child_without_mother, x_end_with_child_without_mother, \

@@ -1735,7 +1749,7 @@ def return_boxes_of_images_by_order_of_reading_new(
try:
y_grenze=int(splitter_y_new[i])+300
#check if there is a big separator in this y_mains_sep_ohne_grenzen

args_early_ys=np.arange(len(y_type_2))
#print(args_early_ys,'args_early_ys')
#print(int(splitter_y_new[i]),int(splitter_y_new[i+1]))

@@ -1764,13 +1778,13 @@ def return_boxes_of_images_by_order_of_reading_new(
boxes.append([0, peaks_neg_tot[len(peaks_neg_tot)-1],
int(splitter_y_new[i]), int( np.max(y_diff_main_separator_up))])
splitter_y_new[i]=[ np.max(y_diff_main_separator_up) ][0]

#print(splitter_y_new[i],'splitter_y_new[i]')
y_type_2 = y_type_2[args_to_be_kept]
x_starting = x_starting[args_to_be_kept]
x_ending = x_ending[args_to_be_kept]
y_diff_type_2 = y_diff_type_2[args_to_be_kept]

#print('galdiha')
y_grenze=int(splitter_y_new[i])+200
args_early_ys2=np.arange(len(y_type_2))

@@ -1791,7 +1805,7 @@ def return_boxes_of_images_by_order_of_reading_new(
x_ending_up[ij]))
nodes_in = np.unique(nodes_in)
#print(nodes_in,'nodes_in')

if set(nodes_in)==set(range(len(peaks_neg_tot)-1)):
pass
elif set(nodes_in)==set(range(1, len(peaks_neg_tot)-1)):

@@ -1799,7 +1813,7 @@ def return_boxes_of_images_by_order_of_reading_new(
else:
#print('burdaydikh')
args_to_be_kept2=np.array(list( set(args_early_ys2)-set(args_up2) ))

if len(args_to_be_kept2)>0:
y_type_2 = y_type_2[args_to_be_kept2]
x_starting = x_starting[args_to_be_kept2]

@@ -1816,7 +1830,7 @@ def return_boxes_of_images_by_order_of_reading_new(
nodes_in = np.unique(nodes_in)
#print(nodes_in,'nodes_in2')
#print(np.array(range(len(peaks_neg_tot)-1)),'np.array(range(len(peaks_neg_tot)-1))')

if set(nodes_in)==set(range(len(peaks_neg_tot)-1)):
pass
elif set(nodes_in)==set(range(1,len(peaks_neg_tot)-1)):

@@ -1826,7 +1840,7 @@ def return_boxes_of_images_by_order_of_reading_new(
#print(args_early_ys,'args_early_ys')
#print(args_up,'args_up')
args_to_be_kept2=np.array(list( set(args_early_ys) - set(args_up) ))

#print(args_to_be_kept2,'args_to_be_kept2')
#print(len(y_type_2),len(x_starting),len(x_ending),len(y_diff_type_2))
if len(args_to_be_kept2)>0:

@@ -1837,7 +1851,7 @@ def return_boxes_of_images_by_order_of_reading_new(
else:
pass
#print('burdaydikh2')

#int(splitter_y_new[i])
y_lines_by_order=[]
x_start_by_order=[]

@@ -1898,7 +1912,7 @@ def return_boxes_of_images_by_order_of_reading_new(
list(range(x_start_without_mother[dj],
x_end_without_mother[dj]))
columns_covered_by_mothers = list(set(columns_covered_by_mothers))

all_columns=np.arange(len(peaks_neg_tot)-1)
columns_not_covered=list(set(all_columns) - set(columns_covered_by_mothers))
y_type_2 = np.append(y_type_2, [int(splitter_y_new[i])] * (len(columns_not_covered) + len(x_start_without_mother)))

@@ -1908,14 +1922,14 @@ def return_boxes_of_images_by_order_of_reading_new(
x_starting = np.append(x_starting, x_start_without_mother)
x_ending = np.append(x_ending, np.array(columns_not_covered) + 1)
x_ending = np.append(x_ending, x_end_without_mother)

columns_covered_by_with_child_no_mothers = []
for dj in range(len(x_end_with_child_without_mother)):
columns_covered_by_with_child_no_mothers = columns_covered_by_with_child_no_mothers + \
list(range(x_start_with_child_without_mother[dj],
x_end_with_child_without_mother[dj]))
columns_covered_by_with_child_no_mothers = list(set(columns_covered_by_with_child_no_mothers))

all_columns = np.arange(len(peaks_neg_tot)-1)
columns_not_covered_child_no_mother = list(set(all_columns) - set(columns_covered_by_with_child_no_mothers))
#indexes_to_be_spanned=[]
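The columns_covered_by_mothers / columns_not_covered step above is a plain set difference over column indexes: columns that no multi-column separator reaches get default entries appended afterwards. With invented spans:

    import numpy as np

    # hypothetical: 4 columns; two separators cover columns 0-1 and 1-2
    x_start_without_mother = np.array([0, 1])
    x_end_without_mother = np.array([2, 3])
    num_columns = 4

    columns_covered_by_mothers = []
    for dj in range(len(x_start_without_mother)):
        columns_covered_by_mothers += list(range(x_start_without_mother[dj],
                                                 x_end_without_mother[dj]))
    columns_covered_by_mothers = list(set(columns_covered_by_mothers))

    all_columns = np.arange(num_columns)
    columns_not_covered = list(set(all_columns.tolist())
                               - set(columns_covered_by_mothers))
    print(columns_not_covered)  # [3]: only column 3 still needs a default entry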
@@ -1952,17 +1966,17 @@ def return_boxes_of_images_by_order_of_reading_new(
x_diff_all_between_nm_wc = x_ending_all_between_nm_wc - x_starting_all_between_nm_wc
if len(x_diff_all_between_nm_wc)>0:
biggest=np.argmax(x_diff_all_between_nm_wc)

columns_covered_by_mothers = []
for dj in range(len(x_starting_all_between_nm_wc)):
columns_covered_by_mothers = columns_covered_by_mothers + \
list(range(x_starting_all_between_nm_wc[dj],
x_ending_all_between_nm_wc[dj]))
columns_covered_by_mothers = list(set(columns_covered_by_mothers))

all_columns=np.arange(i_s_nc, x_end_biggest_column)
columns_not_covered = list(set(all_columns) - set(columns_covered_by_mothers))

should_longest_line_be_extended=0
if (len(x_diff_all_between_nm_wc) > 0 and
set(list(range(x_starting_all_between_nm_wc[biggest],

@@ -1980,11 +1994,11 @@ def return_boxes_of_images_by_order_of_reading_new(
y_all_between_nm_wc = y_all_between_nm_wc[indexes_remained_after_deleting_closed_lines]
x_starting_all_between_nm_wc = x_starting_all_between_nm_wc[indexes_remained_after_deleting_closed_lines]
x_ending_all_between_nm_wc = x_ending_all_between_nm_wc[indexes_remained_after_deleting_closed_lines]

y_all_between_nm_wc = np.append(y_all_between_nm_wc, y_column_nc[i_c])
x_starting_all_between_nm_wc = np.append(x_starting_all_between_nm_wc, i_s_nc)
x_ending_all_between_nm_wc = np.append(x_ending_all_between_nm_wc, x_end_biggest_column)

if len(x_diff_all_between_nm_wc) > 0:
try:
y_all_between_nm_wc = np.append(y_all_between_nm_wc, y_column_nc[i_c])

@@ -1992,11 +2006,11 @@ def return_boxes_of_images_by_order_of_reading_new(
x_ending_all_between_nm_wc = np.append(x_ending_all_between_nm_wc, x_ending_all_between_nm_wc[biggest])
except:
pass

y_all_between_nm_wc = np.append(y_all_between_nm_wc, [y_column_nc[i_c]] * len(columns_not_covered))
x_starting_all_between_nm_wc = np.append(x_starting_all_between_nm_wc, columns_not_covered)
x_ending_all_between_nm_wc = np.append(x_ending_all_between_nm_wc, np.array(columns_not_covered) + 1)

ind_args_between=np.arange(len(x_ending_all_between_nm_wc))
for column in range(i_s_nc, x_end_biggest_column):
ind_args_in_col=ind_args_between[x_starting_all_between_nm_wc==column]

@@ -2038,17 +2052,17 @@ def return_boxes_of_images_by_order_of_reading_new(
y_lines_by_order.append(y_col_sort[ii])
x_start_by_order.append(x_start_column_sort[ii])
x_end_by_order.append(x_end_column_sort[ii]-1)

for il in range(len(y_lines_by_order)):
y_copy = list(y_lines_by_order)
x_start_copy = list(x_start_by_order)
x_end_copy = list(x_end_by_order)

#print(y_copy,'y_copy')
y_itself=y_copy.pop(il)
x_start_itself=x_start_copy.pop(il)
x_end_itself=x_end_copy.pop(il)

#print(y_copy,'y_copy2')
for column in range(x_start_itself, x_end_itself+1):
#print(column,'cols')

@@ -2065,7 +2079,7 @@ def return_boxes_of_images_by_order_of_reading_new(
y_down=np.min(y_in_cols)
else:
y_down=[int(splitter_y_new[i+1])][0]
#print(y_itself,'y_itself')
#print(y_itself,'y_itself')
boxes.append([peaks_neg_tot[column],
peaks_neg_tot[column+1],
y_itself,

@@ -2108,7 +2122,7 @@ def return_boxes_of_images_by_order_of_reading_new(
##x_start_by_order = np.append(x_start_by_order, [0] * len(columns_not_covered))
x_starting = np.append(x_starting, columns_not_covered)
x_ending = np.append(x_ending, np.array(columns_not_covered) + 1)

ind_args=np.array(range(len(y_type_2)))
#ind_args=np.array(ind_args)
for column in range(len(peaks_neg_tot)-1):

@@ -2130,17 +2144,17 @@ def return_boxes_of_images_by_order_of_reading_new(
y_lines_by_order.append(y_col_sort[ii])
x_start_by_order.append(x_start_column_sort[ii])
x_end_by_order.append(x_end_column_sort[ii]-1)

for il in range(len(y_lines_by_order)):
y_copy = list(y_lines_by_order)
x_start_copy = list(x_start_by_order)
x_end_copy = list(x_end_by_order)

#print(y_copy,'y_copy')
y_itself=y_copy.pop(il)
x_start_itself=x_start_copy.pop(il)
x_end_itself=x_end_copy.pop(il)

#print(y_copy,'y_copy2')
for column in range(x_start_itself, x_end_itself+1):
#print(column,'cols')

@@ -2157,22 +2171,22 @@ def return_boxes_of_images_by_order_of_reading_new(
y_down=np.min(y_in_cols)
else:
y_down=[int(splitter_y_new[i+1])][0]
#print(y_itself,'y_itself')
#print(y_itself,'y_itself')
boxes.append([peaks_neg_tot[column],
peaks_neg_tot[column+1],
y_itself,
y_down])
#else:
#boxes.append([ 0, regions_without_separators[:,:].shape[1] ,splitter_y_new[i],splitter_y_new[i+1]])

if right2left_readingorder:

if right2left_readingorder:
peaks_neg_tot_tables_new = []
if len(peaks_neg_tot_tables)>=1:
for peaks_tab_ind in peaks_neg_tot_tables:
peaks_neg_tot_tables_ind = regions_without_separators.shape[1] - np.array(peaks_tab_ind)
peaks_neg_tot_tables_ind = list(peaks_neg_tot_tables_ind[::-1])
peaks_neg_tot_tables_new.append(peaks_neg_tot_tables_ind)

for i in range(len(boxes)):
x_start_new = regions_without_separators.shape[1] - boxes[i][1]
x_end_new = regions_without_separators.shape[1] - boxes[i][0]
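The right2left_readingorder branches mirror x-coordinates about the page width, both for the separator extents earlier in the function and for the finished boxes here. A minimal sketch of the box mirroring, with invented boxes:

    # hypothetical reading-order boxes as [x_start, x_end, y_top, y_bottom]
    boxes = [[0, 500, 0, 300], [500, 1000, 0, 300]]
    page_width = 1000

    # mirror each box horizontally so column order runs right-to-left
    for i in range(len(boxes)):
        x_start_new = page_width - boxes[i][1]
        x_end_new = page_width - boxes[i][0]
        boxes[i][0] = x_start_new
        boxes[i][1] = x_end_new

    print(boxes)  # [[500, 1000, 0, 300], [0, 500, 0, 300]]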
@@ -227,9 +227,12 @@ def get_textregion_contours_in_org_image_light_old(cnts, img, slope_first):
return cnts_org

def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first):
def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first, confidence_matrix):
img_copy = np.zeros(img.shape)
img_copy = cv2.fillPoly(img_copy, pts=[contour_par], color=(1, 1, 1))

confidence_matrix_mapped_with_contour = confidence_matrix * img_copy[:,:,0]
confidence_contour = np.sum(confidence_matrix_mapped_with_contour) / float(np.sum(img_copy[:,:,0]))

img_copy = rotation_image_new(img_copy, -slope_first).astype(np.uint8)
imgray = cv2.cvtColor(img_copy, cv2.COLOR_BGR2GRAY)

@@ -239,11 +242,13 @@ def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
# print(np.shape(cont_int[0]))
return cont_int[0], index_r_con
return cont_int[0], index_r_con, confidence_contour

def get_textregion_contours_in_org_image_light(cnts, img, slope_first, map=map):
def get_textregion_contours_in_org_image_light(cnts, img, slope_first, confidence_matrix, map=map):
if not len(cnts):
return []
return [], []

confidence_matrix = cv2.resize(confidence_matrix, (int(img.shape[1]/6), int(img.shape[0]/6)), interpolation=cv2.INTER_NEAREST)
img = cv2.resize(img, (int(img.shape[1]/6), int(img.shape[0]/6)), interpolation=cv2.INTER_NEAREST)
##cnts = list( (np.array(cnts)/2).astype(np.int16) )
#cnts = cnts/2

@@ -251,10 +256,11 @@ def get_textregion_contours_in_org_image_light(cnts, img, slope_first, map=map):
results = map(partial(do_back_rotation_and_get_cnt_back,
img=img,
slope_first=slope_first,
confidence_matrix=confidence_matrix,
),
cnts, range(len(cnts)))
contours, indexes = tuple(zip(*results))
return [i*6 for i in contours]
contours, indexes, conf_contours = tuple(zip(*results))
return [i*6 for i in contours], list(conf_contours)

def return_contours_of_interested_textline(region_pre_p, pixel):
# pixels of images are identified by 5
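The new confidence_matrix parameter is where the per-region confidence actually comes from: the contour is filled into a mask, multiplied with the pixel confidence map, and normalized by the contour area. A standalone sketch with invented values:

    import cv2
    import numpy as np

    # invented pixel confidence map with one high-confidence patch
    confidence_matrix = np.full((100, 100), 0.5, dtype=float)
    confidence_matrix[20:60, 20:60] = 0.9

    # one rectangular region contour in OpenCV (N, 1, 2) layout
    contour = np.array([[[20, 20]], [[79, 20]], [[79, 79]], [[20, 79]]],
                       dtype=np.int32)

    mask = np.zeros((100, 100, 3))
    mask = cv2.fillPoly(mask, pts=[contour], color=(1, 1, 1))

    mapped = confidence_matrix * mask[:, :, 0]
    confidence_contour = np.sum(mapped) / float(np.sum(mask[:, :, 0]))
    print(round(confidence_contour, 3))  # ~0.678: mean confidence inside the region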
@@ -1,6 +1,4 @@
import math

import imutils
import cv2

def rotatedRectWithMaxArea(w, h, angle):

@@ -35,14 +33,14 @@ def rotate_max_area_new(image, rotated, angle):
return rotated[y1:y2, x1:x2]

def rotation_image_new(img, thetha):
rotated = imutils.rotate(img, thetha)
rotated = rotate_image(img, thetha)
return rotate_max_area_new(img, rotated, thetha)

def rotate_image(img_patch, slope):
(h, w) = img_patch.shape[:2]
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, slope, 1.0)
return cv2.warpAffine(img_patch, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
return cv2.warpAffine(img_patch, M, (w, h) )

def rotate_image_different( img, slope):
# img = cv2.imread('images/input.jpg')

@@ -62,17 +60,17 @@ def rotate_max_area(image, rotated, rotated_textline, rotated_layout, rotated_ta
return rotated[y1:y2, x1:x2], rotated_textline[y1:y2, x1:x2], rotated_layout[y1:y2, x1:x2], rotated_table_prediction[y1:y2, x1:x2]

def rotation_not_90_func(img, textline, text_regions_p_1, table_prediction, thetha):
rotated = imutils.rotate(img, thetha)
rotated_textline = imutils.rotate(textline, thetha)
rotated_layout = imutils.rotate(text_regions_p_1, thetha)
rotated_table_prediction = imutils.rotate(table_prediction, thetha)
rotated = rotate_image(img, thetha)
rotated_textline = rotate_image(textline, thetha)
rotated_layout = rotate_image(text_regions_p_1, thetha)
rotated_table_prediction = rotate_image(table_prediction, thetha)
return rotate_max_area(img, rotated, rotated_textline, rotated_layout, rotated_table_prediction, thetha)

def rotation_not_90_func_full_layout(img, textline, text_regions_p_1, text_regions_p_fully, thetha):
rotated = imutils.rotate(img, thetha)
rotated_textline = imutils.rotate(textline, thetha)
rotated_layout = imutils.rotate(text_regions_p_1, thetha)
rotated_layout_full = imutils.rotate(text_regions_p_fully, thetha)
rotated = rotate_image(img, thetha)
rotated_textline = rotate_image(textline, thetha)
rotated_layout = rotate_image(text_regions_p_1, thetha)
rotated_layout_full = rotate_image(text_regions_p_fully, thetha)
return rotate_max_area_full_layout(img, rotated, rotated_textline, rotated_layout, rotated_layout_full, thetha)

def rotate_max_area_full_layout(image, rotated, rotated_textline, rotated_layout, rotated_layout_full, angle):
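This file swaps every imutils.rotate call for the in-house rotate_image, i.e. cv2.warpAffine about the image center. Note that the hunk also drops the flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE arguments the old rotate_image line carried, which slightly changes interpolation and border handling. For reference, the function as it now stands, with a trivial usage:

    import cv2
    import numpy as np

    # the warpAffine-based rotation that replaces imutils.rotate here:
    # rotate about the center, keep the original canvas size; the
    # rotate_max_area* helpers then crop the valid area afterwards
    def rotate_image(img_patch, slope):
        (h, w) = img_patch.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, slope, 1.0)
        return cv2.warpAffine(img_patch, M, (w, h))

    img = np.zeros((200, 300, 3), np.uint8)
    rotated = rotate_image(img, 2.5)   # 2.5 degrees counter-clockwise
    print(rotated.shape)               # (200, 300, 3): canvas size unchanged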
@@ -139,7 +139,7 @@ class EynollahXmlWriter():
points_co += str(int((contour_textline[0][1] + region_bboxes[0]+page_coord[0])/self.scale_y))
points_co += ' '
coords.set_points(points_co[:-1])

def serialize_lines_in_dropcapital(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter, ocr_all_textlines_textregion):
self.logger.debug('enter serialize_lines_in_region')
for j in range(1):

@@ -168,7 +168,7 @@ class EynollahXmlWriter():
with open(self.output_filename, 'w') as f:
f.write(to_xml(pcgts))

def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines):
def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines, conf_contours_textregion):
self.logger.debug('enter build_pagexml_no_full_layout')

# create the file structure

@@ -184,8 +184,9 @@ class EynollahXmlWriter():
for mm in range(len(found_polygons_text_region)):
textregion = TextRegionType(id=counter.next_region_id, type_='paragraph',
Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord)),
Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord), conf=conf_contours_textregion[mm]),
)
#textregion.set_conf(conf_contours_textregion[mm])
page.add_TextRegion(textregion)
if ocr_all_textlines:
ocr_textlines = ocr_all_textlines[mm]

@@ -215,9 +216,9 @@ class EynollahXmlWriter():
points_co += ','
points_co += str(int((found_polygons_text_region_img[mm][lmm][1] + page_coord[0])/ self.scale_y ))
points_co += ' '

img_region.get_Coords().set_points(points_co[:-1])

for mm in range(len(polygons_lines_to_be_written_in_xml)):
sep_hor = SeparatorRegionType(id=counter.next_region_id, Coords=CoordsType())
page.add_SeparatorRegion(sep_hor)

@@ -241,7 +242,7 @@ class EynollahXmlWriter():
return pcgts

def build_pagexml_full_layout(self, found_polygons_text_region, found_polygons_text_region_h, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h, found_polygons_text_region_img, found_polygons_tables, found_polygons_drop_capitals, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_h, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, ocr_all_textlines):
def build_pagexml_full_layout(self, found_polygons_text_region, found_polygons_text_region_h, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h, found_polygons_text_region_img, found_polygons_tables, found_polygons_drop_capitals, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_h, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, ocr_all_textlines, conf_contours_textregion, conf_contours_textregion_h):
self.logger.debug('enter build_pagexml_full_layout')

# create the file structure

@@ -256,9 +257,9 @@ class EynollahXmlWriter():
for mm in range(len(found_polygons_text_region)):
textregion = TextRegionType(id=counter.next_region_id, type_='paragraph',
Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord)))
Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord), conf=conf_contours_textregion[mm]))
page.add_TextRegion(textregion)

if ocr_all_textlines:
ocr_textlines = ocr_all_textlines[mm]
else:

@@ -293,10 +294,10 @@ class EynollahXmlWriter():
for mm in range(len(found_polygons_text_region_img)):
page.add_ImageRegion(ImageRegionType(id=counter.next_region_id, Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region_img[mm], page_coord))))

for mm in range(len(polygons_lines_to_be_written_in_xml)):
page.add_SeparatorRegion(ImageRegionType(id=counter.next_region_id, Coords=CoordsType(points=self.calculate_polygon_coords(polygons_lines_to_be_written_in_xml[mm], [0 , 0, 0, 0]))))

for mm in range(len(found_polygons_tables)):
page.add_TableRegion(TableRegionType(id=counter.next_region_id, Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_tables[mm], page_coord))))
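On the writer side, the region confidence ends up as the conf attribute of the PAGE-XML Coords element. A minimal sketch using the same ocrd_models types the writer imports (the id and points string here are invented):

    from ocrd_models.ocrd_page import TextRegionType, CoordsType

    conf_contours_textregion = [0.87]
    mm = 0

    textregion = TextRegionType(
        id='region_0001', type_='paragraph',
        Coords=CoordsType(points='10,10 200,10 200,50 10,50',
                          conf=conf_contours_textregion[mm]))
    print(textregion.get_Coords().get_conf())  # 0.87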