decorating eynollah with textregion confidence score #135

main
vahidrezanezhad 1 day ago
parent 91b2201b07
commit 6b52da227c

@ -1214,7 +1214,7 @@ class Eynollah:
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8)
return prediction_true
return prediction_true , resize_image(label_p_pred[0, :, :, 1] , img_h_page, img_w_page)
if img.shape[0] < img_height_model:
img = resize_image(img, img_height_model, img.shape[1])
@ -1230,6 +1230,7 @@ class Eynollah:
img_h = img.shape[0]
img_w = img.shape[1]
prediction_true = np.zeros((img_h, img_w, 3))
confidence_matrix = np.zeros((img_h, img_w))
mask_true = np.zeros((img_h, img_w))
nxf = img_w / float(width_mid)
nyf = img_h / float(height_mid)
@ -1318,54 +1319,99 @@ class Eynollah:
seg_in[0:-margin or None,
0:-margin or None,
np.newaxis]
confidence_matrix[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin] = \
label_p_pred[0, 0:-margin or None,
0:-margin or None,
1]
elif i_batch == nxf - 1 and j_batch == nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - 0] = \
seg_in[margin:,
margin:,
np.newaxis]
confidence_matrix[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - 0] = \
label_p_pred[0, margin:,
margin:,
1]
elif i_batch == 0 and j_batch == nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + 0:index_x_u_in - margin] = \
seg_in[margin:,
0:-margin or None,
np.newaxis]
confidence_matrix[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + 0:index_x_u_in - margin] = \
label_p_pred[0, margin:,
0:-margin or None,
1]
elif i_batch == nxf - 1 and j_batch == 0:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0] = \
seg_in[0:-margin or None,
margin:,
np.newaxis]
confidence_matrix[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0] = \
label_p_pred[0, 0:-margin or None,
margin:,
1]
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin] = \
seg_in[margin:-margin or None,
0:-margin or None,
np.newaxis]
confidence_matrix[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin] = \
label_p_pred[0, margin:-margin or None,
0:-margin or None,
1]
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0] = \
seg_in[margin:-margin or None,
margin:,
np.newaxis]
confidence_matrix[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0] = \
label_p_pred[0, margin:-margin or None,
margin:,
1]
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin] = \
seg_in[0:-margin or None,
margin:-margin or None,
np.newaxis]
confidence_matrix[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin] = \
label_p_pred[0, 0:-margin or None,
margin:-margin or None,
1]
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - margin] = \
seg_in[margin:,
margin:-margin or None,
np.newaxis]
confidence_matrix[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - margin] = \
label_p_pred[0, margin:,
margin:-margin or None,
1]
else:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin] = \
seg_in[margin:-margin or None,
margin:-margin or None,
np.newaxis]
confidence_matrix[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin] = \
label_p_pred[0, margin:-margin or None,
margin:-margin or None,
1]
indexer_inside_batch += 1
list_i_s = []
@ -1380,7 +1426,7 @@ class Eynollah:
prediction_true = prediction_true.astype(np.uint8)
gc.collect()
return prediction_true
return prediction_true, confidence_matrix
def extract_page(self):
self.logger.debug("enter extract_page")
@ -1742,7 +1788,7 @@ class Eynollah:
if not self.dir_in:
self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens_light_only_images_extraction)
prediction_regions_org = self.do_prediction_new_concept(True, img_resized, self.model_region)
prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.model_region)
prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h )
image_page, page_coord, cont_page = self.extract_page()
@ -1903,24 +1949,26 @@ class Eynollah:
if self.image_org.shape[0]/self.image_org.shape[1] > 2.5:
self.logger.debug("resized to %dx%d for %d cols",
img_resized.shape[1], img_resized.shape[0], num_col_classifier)
prediction_regions_org = self.do_prediction_new_concept(
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
True, img_resized, self.model_region_1_2, n_batch_inference=1,
thresholding_for_some_classes_in_light_version=True)
else:
prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3))
prediction_regions_page = self.do_prediction_new_concept(
confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept(
False, self.image_page_org_size, self.model_region_1_2, n_batch_inference=1,
thresholding_for_artificial_class_in_light_version=True)
ys = slice(*self.page_coord[0:2])
xs = slice(*self.page_coord[2:4])
prediction_regions_org[ys, xs] = prediction_regions_page
confidence_matrix[ys, xs] = confidence_matrix_page
else:
new_h = (900+ (num_col_classifier-3)*100)
img_resized = resize_image(img_bin, int(new_h * img_bin.shape[0] /img_bin.shape[1]), new_h)
self.logger.debug("resized to %dx%d (new_h=%d) for %d cols",
img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
prediction_regions_org = self.do_prediction_new_concept(
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
True, img_resized, self.model_region_1_2, n_batch_inference=2,
thresholding_for_some_classes_in_light_version=True)
###prediction_regions_org = self.do_prediction(True, img_bin, self.model_region, n_batch_inference=3, thresholding_for_some_classes_in_light_version=True)
@ -1928,8 +1976,9 @@ class Eynollah:
#plt.imshow(prediction_regions_org[:,:,0])
#plt.show()
prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h )
img_bin = resize_image(img_bin,img_height_h, img_width_h )
prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
confidence_matrix = resize_image(confidence_matrix, img_height_h, img_width_h )
img_bin = resize_image(img_bin, img_height_h, img_width_h )
prediction_regions_org=prediction_regions_org[:,:,0]
mask_lines_only = (prediction_regions_org[:,:] ==3)*1
@ -1985,11 +2034,11 @@ class Eynollah:
#plt.show()
#print("inside 4 ", time.time()-t_in)
self.logger.debug("exit get_regions_light_v")
return text_regions_p_true, erosion_hurts, polygons_lines_xml, textline_mask_tot_ea, img_bin
return text_regions_p_true, erosion_hurts, polygons_lines_xml, textline_mask_tot_ea, img_bin, confidence_matrix
else:
img_bin = resize_image(img_bin,img_height_h, img_width_h )
self.logger.debug("exit get_regions_light_v")
return None, erosion_hurts, None, textline_mask_tot_ea, img_bin
return None, erosion_hurts, None, textline_mask_tot_ea, img_bin, None
def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):
self.logger.debug("enter get_regions_from_xy_2models")
@ -2742,7 +2791,7 @@ class Eynollah:
patches = False
if self.light_version:
prediction_table = self.do_prediction_new_concept(patches, img, self.model_table)
prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_table)
prediction_table = prediction_table.astype(np.int16)
return prediction_table[:,:,0]
else:
@ -4127,8 +4176,7 @@ class Eynollah:
return contours
def filter_contours_without_textline_inside(
self, contours,text_con_org, contours_textline, contours_only_text_parent_d_ordered):
self, contours,text_con_org, contours_textline, contours_only_text_parent_d_ordered, conf_contours_textregions):
###contours_txtline_of_all_textregions = []
###for jj in range(len(contours_textline)):
###contours_txtline_of_all_textregions = contours_txtline_of_all_textregions + contours_textline[jj]
@ -4161,13 +4209,14 @@ class Eynollah:
uniqe_args_trs_sorted = np.sort(uniqe_args_trs)[::-1]
for ind_u_a_trs in uniqe_args_trs_sorted:
conf_contours_textregions.pop(ind_u_a_trs)
contours.pop(ind_u_a_trs)
contours_textline.pop(ind_u_a_trs)
text_con_org.pop(ind_u_a_trs)
if len(contours_only_text_parent_d_ordered) > 0:
contours_only_text_parent_d_ordered.pop(ind_u_a_trs)
return contours, text_con_org, contours_textline, contours_only_text_parent_d_ordered, np.array(range(len(contours)))
return contours, text_con_org, conf_contours_textregions, contours_textline, contours_only_text_parent_d_ordered, np.array(range(len(contours)))
def dilate_textlines(self, all_found_textline_polygons):
for j in range(len(all_found_textline_polygons)):
@ -4347,7 +4396,7 @@ class Eynollah:
pcgts = self.writer.build_pagexml_no_full_layout(
[], page_coord, [], [], [], [],
polygons_of_images, [], [], [], [], [],
cont_page, [], [], ocr_all_textlines)
cont_page, [], [], ocr_all_textlines, [])
if self.plotter:
self.plotter.write_images_into_directory(polygons_of_images, image_page)
@ -4358,7 +4407,7 @@ class Eynollah:
return pcgts
if self.skip_layout_and_reading_order:
_ ,_, _, textline_mask_tot_ea, img_bin_light = \
_ ,_, _, textline_mask_tot_ea, img_bin_light,_ = \
self.get_regions_light_v(img_res, is_image_enhanced, num_col_classifier,
skip_layout_and_reading_order=self.skip_layout_and_reading_order)
@ -4392,11 +4441,12 @@ class Eynollah:
polygons_lines_xml = []
contours_tables = []
ocr_all_textlines = None
conf_contours_textregions =None
pcgts = self.writer.build_pagexml_no_full_layout(
cont_page, page_coord, order_text_new, id_of_texts_tot,
all_found_textline_polygons, page_coord, polygons_of_images, polygons_of_marginals,
all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals,
cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines)
cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines, conf_contours_textregions)
if self.dir_in:
self.writer.write_pagexml(pcgts)
continue
@ -4406,7 +4456,7 @@ class Eynollah:
#print("text region early -1 in %.1fs", time.time() - t0)
t1 = time.time()
if self.light_version:
text_regions_p_1 ,erosion_hurts, polygons_lines_xml, textline_mask_tot_ea, img_bin_light = \
text_regions_p_1 ,erosion_hurts, polygons_lines_xml, textline_mask_tot_ea, img_bin_light, confidence_matrix = \
self.get_regions_light_v(img_res, is_image_enhanced, num_col_classifier)
#print("text region early -2 in %.1fs", time.time() - t0)
if num_col_classifier == 1 or num_col_classifier ==2:
@ -4417,9 +4467,9 @@ class Eynollah:
img_h_new = img_w_new * textline_mask_tot_ea.shape[0] // textline_mask_tot_ea.shape[1]
textline_mask_tot_ea_deskew = resize_image(textline_mask_tot_ea,img_h_new, img_w_new )
slope_deskew, slope_first = self.run_deskew(textline_mask_tot_ea_deskew)
else:
ttest = time.time()
slope_deskew, slope_first = self.run_deskew(textline_mask_tot_ea)
#print("text region early -2,5 in %.1fs", time.time() - t0)
#self.logger.info("Textregion detection took %.1fs ", time.time() - t1t)
@ -4451,7 +4501,7 @@ class Eynollah:
ocr_all_textlines = None
pcgts = self.writer.build_pagexml_no_full_layout(
[], page_coord, [], [], [], [], [], [], [], [], [], [],
cont_page, [], [], ocr_all_textlines)
cont_page, [], [], ocr_all_textlines, [])
self.logger.info("Job done in %.1fs", time.time() - t1)
if self.dir_in:
self.writer.write_pagexml(pcgts)
@ -4636,13 +4686,13 @@ class Eynollah:
[], [], page_coord, [], [], [], [], [], [],
polygons_of_images, contours_tables, [],
polygons_of_marginals, empty_marginals, empty_marginals, [], [], [],
cont_page, polygons_lines_xml, [])
cont_page, polygons_lines_xml, [], [], [])
else:
pcgts = self.writer.build_pagexml_no_full_layout(
[], page_coord, [], [], [], [],
polygons_of_images,
polygons_of_marginals, empty_marginals, empty_marginals, [], [],
cont_page, polygons_lines_xml, contours_tables, [])
cont_page, polygons_lines_xml, contours_tables, [], [])
self.logger.info("Job done in %.1fs", time.time() - t0)
if self.dir_in:
self.writer.write_pagexml(pcgts)
@ -4663,10 +4713,11 @@ class Eynollah:
contours_only_text_parent , contours_only_text_parent_d_ordered = self.filter_contours_inside_a_bigger_one(
contours_only_text_parent, contours_only_text_parent_d_ordered, text_only, marginal_cnts=polygons_of_marginals)
#print("text region early 3.5 in %.1fs", time.time() - t0)
txt_con_org = get_textregion_contours_in_org_image_light(
contours_only_text_parent, self.image, slope_first, map=self.executor.map)
txt_con_org , conf_contours_textregions = get_textregion_contours_in_org_image_light(
contours_only_text_parent, self.image, slope_first, confidence_matrix, map=self.executor.map)
#txt_con_org = self.dilate_textregions_contours(txt_con_org)
#contours_only_text_parent = self.dilate_textregions_contours(contours_only_text_parent)
else:
txt_con_org = get_textregion_contours_in_org_image(
contours_only_text_parent, self.image, slope_first)
@ -4701,9 +4752,9 @@ class Eynollah:
all_found_textline_polygons, None, textline_mask_tot_ea_org, type_contour="textline")
all_found_textline_polygons_marginals = self.dilate_textregions_contours_textline_version(
all_found_textline_polygons_marginals)
contours_only_text_parent, txt_con_org, all_found_textline_polygons, contours_only_text_parent_d_ordered, \
contours_only_text_parent, txt_con_org, conf_contours_textregions, all_found_textline_polygons, contours_only_text_parent_d_ordered, \
index_by_text_par_con = self.filter_contours_without_textline_inside(
contours_only_text_parent, txt_con_org, all_found_textline_polygons, contours_only_text_parent_d_ordered)
contours_only_text_parent, txt_con_org, all_found_textline_polygons, contours_only_text_parent_d_ordered, conf_contours_textregions)
else:
textline_mask_tot_ea = cv2.erode(textline_mask_tot_ea, kernel=KERNEL, iterations=1)
all_found_textline_polygons, boxes_text, txt_con_org, contours_only_text_parent, all_box_coord, \
@ -4761,12 +4812,14 @@ class Eynollah:
if self.light_version:
fun = check_any_text_region_in_model_one_is_main_or_header_light
else:
conf_contours_textregions = None
fun = check_any_text_region_in_model_one_is_main_or_header
text_regions_p, contours_only_text_parent, contours_only_text_parent_h, all_box_coord, all_box_coord_h, \
all_found_textline_polygons, all_found_textline_polygons_h, slopes, slopes_h, \
contours_only_text_parent_d_ordered, contours_only_text_parent_h_d_ordered = fun(
contours_only_text_parent_d_ordered, contours_only_text_parent_h_d_ordered, \
conf_contours_textregions, conf_contours_textregions_h = fun(
text_regions_p, regions_fully, contours_only_text_parent,
all_box_coord, all_found_textline_polygons, slopes, contours_only_text_parent_d_ordered)
all_box_coord, all_found_textline_polygons, slopes, contours_only_text_parent_d_ordered, conf_contours_textregions)
if self.plotter:
self.plotter.save_plot_of_layout(text_regions_p, image_page)
@ -4843,7 +4896,7 @@ class Eynollah:
all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h,
polygons_of_images, contours_tables, polygons_of_drop_capitals, polygons_of_marginals,
all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_h, slopes_marginals,
cont_page, polygons_lines_xml, ocr_all_textlines)
cont_page, polygons_lines_xml, ocr_all_textlines, conf_contours_textregions, conf_contours_textregions_h)
self.logger.info("Job done in %.1fs", time.time() - t0)
#print("Job done in %.1fs", time.time() - t0)
if self.dir_in:
@ -4929,7 +4982,7 @@ class Eynollah:
txt_con_org, page_coord, order_text_new, id_of_texts_tot,
all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals,
all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals,
cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines)
cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines, conf_contours_textregions)
#print("Job done in %.1fs" % (time.time() - t0))
self.logger.info("Job done in %.1fs", time.time() - t0)
if not self.dir_in:

@ -865,7 +865,7 @@ def check_any_text_region_in_model_one_is_main_or_header(
contours_only_text_parent,
all_box_coord, all_found_textline_polygons,
slopes,
contours_only_text_parent_d_ordered):
contours_only_text_parent_d_ordered, conf_contours):
cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin = \
find_new_features_of_contours(contours_only_text_parent)
@ -926,14 +926,17 @@ def check_any_text_region_in_model_one_is_main_or_header(
slopes_main,
slopes_head,
contours_only_text_parent_main_d,
contours_only_text_parent_head_d)
contours_only_text_parent_head_d,
None,
None)
def check_any_text_region_in_model_one_is_main_or_header_light(
regions_model_1, regions_model_full,
contours_only_text_parent,
all_box_coord, all_found_textline_polygons,
slopes,
contours_only_text_parent_d_ordered):
contours_only_text_parent_d_ordered,
conf_contours):
### to make it faster
h_o = regions_model_1.shape[0]
@ -965,6 +968,9 @@ def check_any_text_region_in_model_one_is_main_or_header_light(
contours_only_text_parent_main=[]
contours_only_text_parent_head=[]
conf_contours_main=[]
conf_contours_head=[]
contours_only_text_parent_main_d=[]
contours_only_text_parent_head_d=[]
@ -987,9 +993,11 @@ def check_any_text_region_in_model_one_is_main_or_header_light(
all_box_coord_head.append(all_box_coord[ii])
slopes_head.append(slopes[ii])
all_found_textline_polygons_head.append(all_found_textline_polygons[ii])
conf_contours_head.append(None)
else:
regions_model_1[:,:][(regions_model_1[:,:]==1) & (img[:,:,0]==255) ]=1
contours_only_text_parent_main.append(con)
conf_contours_main.append(conf_contours[ii])
if contours_only_text_parent_d_ordered is not None:
contours_only_text_parent_main_d.append(contours_only_text_parent_d_ordered[ii])
all_box_coord_main.append(all_box_coord[ii])
@ -1017,7 +1025,9 @@ def check_any_text_region_in_model_one_is_main_or_header_light(
slopes_main,
slopes_head,
contours_only_text_parent_main_d,
contours_only_text_parent_head_d)
contours_only_text_parent_head_d,
conf_contours_main,
conf_contours_head)
def small_textlines_to_parent_adherence2(textlines_con, textline_iamge, num_col):
# print(textlines_con)

@ -227,9 +227,12 @@ def get_textregion_contours_in_org_image_light_old(cnts, img, slope_first):
return cnts_org
def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first):
def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first, confidence_matrix):
img_copy = np.zeros(img.shape)
img_copy = cv2.fillPoly(img_copy, pts=[contour_par], color=(1, 1, 1))
confidence_matrix_mapped_with_contour = confidence_matrix * img_copy[:,:,0]
confidence_contour = np.sum(confidence_matrix_mapped_with_contour) / float(np.sum(img_copy[:,:,0]))
img_copy = rotation_image_new(img_copy, -slope_first).astype(np.uint8)
imgray = cv2.cvtColor(img_copy, cv2.COLOR_BGR2GRAY)
@ -239,11 +242,13 @@ def do_back_rotation_and_get_cnt_back(contour_par, index_r_con, img, slope_first
cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
# print(np.shape(cont_int[0]))
return cont_int[0], index_r_con
return cont_int[0], index_r_con, confidence_contour
def get_textregion_contours_in_org_image_light(cnts, img, slope_first, map=map):
def get_textregion_contours_in_org_image_light(cnts, img, slope_first, confidence_matrix, map=map):
if not len(cnts):
return []
return [], []
confidence_matrix = cv2.resize(confidence_matrix, (int(img.shape[1]/6), int(img.shape[0]/6)), interpolation=cv2.INTER_NEAREST)
img = cv2.resize(img, (int(img.shape[1]/6), int(img.shape[0]/6)), interpolation=cv2.INTER_NEAREST)
##cnts = list( (np.array(cnts)/2).astype(np.int16) )
#cnts = cnts/2
@ -251,10 +256,11 @@ def get_textregion_contours_in_org_image_light(cnts, img, slope_first, map=map):
results = map(partial(do_back_rotation_and_get_cnt_back,
img=img,
slope_first=slope_first,
confidence_matrix=confidence_matrix,
),
cnts, range(len(cnts)))
contours, indexes = tuple(zip(*results))
return [i*6 for i in contours]
contours, indexes, conf_contours = tuple(zip(*results))
return [i*6 for i in contours], list(conf_contours)
def return_contours_of_interested_textline(region_pre_p, pixel):
# pixels of images are identified by 5

@ -168,7 +168,7 @@ class EynollahXmlWriter():
with open(self.output_filename, 'w') as f:
f.write(to_xml(pcgts))
def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines):
def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines, conf_contours_textregion):
self.logger.debug('enter build_pagexml_no_full_layout')
# create the file structure
@ -184,8 +184,9 @@ class EynollahXmlWriter():
for mm in range(len(found_polygons_text_region)):
textregion = TextRegionType(id=counter.next_region_id, type_='paragraph',
Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord)),
Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord), conf=conf_contours_textregion[mm]),
)
#textregion.set_conf(conf_contours_textregion[mm])
page.add_TextRegion(textregion)
if ocr_all_textlines:
ocr_textlines = ocr_all_textlines[mm]
@ -241,7 +242,7 @@ class EynollahXmlWriter():
return pcgts
def build_pagexml_full_layout(self, found_polygons_text_region, found_polygons_text_region_h, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h, found_polygons_text_region_img, found_polygons_tables, found_polygons_drop_capitals, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_h, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, ocr_all_textlines):
def build_pagexml_full_layout(self, found_polygons_text_region, found_polygons_text_region_h, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h, found_polygons_text_region_img, found_polygons_tables, found_polygons_drop_capitals, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_h, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, ocr_all_textlines, conf_contours_textregion, conf_contours_textregion_h):
self.logger.debug('enter build_pagexml_full_layout')
# create the file structure
@ -256,7 +257,7 @@ class EynollahXmlWriter():
for mm in range(len(found_polygons_text_region)):
textregion = TextRegionType(id=counter.next_region_id, type_='paragraph',
Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord)))
Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord), conf=conf_contours_textregion[mm]))
page.add_TextRegion(textregion)
if ocr_all_textlines:

Loading…
Cancel
Save