From 2e7c69f2ac4d3c1aa68498fae409e8d6f25ebf8b Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 28 May 2024 16:48:51 +0200 Subject: [PATCH] inference for reading order --- gt_gen_utils.py | 136 +++++++++++---------------------- inference.py | 196 +++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 228 insertions(+), 104 deletions(-) diff --git a/gt_gen_utils.py b/gt_gen_utils.py index 9dc8377..0286ac7 100644 --- a/gt_gen_utils.py +++ b/gt_gen_utils.py @@ -38,11 +38,8 @@ def filter_contours_area_of_image_tables(image, contours, hierarchy, max_area, m polygon = geometry.Polygon([point[0] for point in c]) # area = cv2.contourArea(c) area = polygon.area - ##print(np.prod(thresh.shape[:2])) # Check that polygon has area greater than minimal area - # print(hierarchy[0][jv][3],hierarchy ) if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : - # print(c[0][0][1]) found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.int32)) jv += 1 return found_polygons_early @@ -52,15 +49,12 @@ def filter_contours_area_of_image(image, contours, order_index, max_area, min_ar order_index_filtered = list() #jv = 0 for jv, c in enumerate(contours): - #print(len(c[0])) c = c[0] if len(c) < 3: # A polygon cannot have less than 3 points continue c_e = [point for point in c] - #print(c_e) polygon = geometry.Polygon(c_e) area = polygon.area - #print(area,'area') if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 : found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.uint)) order_index_filtered.append(order_index[jv]) @@ -88,12 +82,8 @@ def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002): def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len): co_text_eroded = [] for con in co_text: - #try: img_boundary_in = np.zeros( (y_len,x_len) ) img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1)) - #print('bidiahhhhaaa') - - #img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica if erosion_rate > 0: @@ -626,8 +616,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ def find_new_features_of_contours(contours_main): - - #print(contours_main[0][0][:, 0]) areas_main = np.array([cv2.contourArea(contours_main[j]) for j in range(len(contours_main))]) M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))] @@ -658,8 +646,6 @@ def find_new_features_of_contours(contours_main): y_min_main = np.array([np.min(contours_main[j][:, 1]) for j in range(len(contours_main))]) y_max_main = np.array([np.max(contours_main[j][:, 1]) for j in range(len(contours_main))]) - # dis_x=np.abs(x_max_main-x_min_main) - return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin def read_xml(xml_file): file_name = Path(xml_file).stem @@ -675,13 +661,11 @@ def read_xml(xml_file): y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) - for jj in root1.iter(link+'RegionRefIndexed'): index_tot_regions.append(jj.attrib['index']) tot_region_ref.append(jj.attrib['regionRef']) region_tags=np.unique([x for x in alltags if x.endswith('Region')]) - #print(region_tags) co_text_paragraph=[] co_text_drop=[] co_text_heading=[] @@ -698,7 +682,6 @@ def read_xml(xml_file): co_graphic_decoration=[] co_noise=[] - co_text_paragraph_text=[] co_text_drop_text=[] co_text_heading_text=[] @@ -715,7 +698,6 @@ def read_xml(xml_file): co_graphic_decoration_text=[] co_noise_text=[] - id_paragraph = [] id_header = [] id_heading = [] @@ -726,14 +708,8 @@ def read_xml(xml_file): for nn in root1.iter(tag): for child2 in nn: tag2 = child2.tag - #print(child2.tag) if tag2.endswith('}TextEquiv') or tag2.endswith('}TextEquiv'): - #children2 = childtext.getchildren() - #rank = child2.find('Unicode').text for childtext2 in child2: - #rank = childtext2.find('Unicode').text - #if childtext2.tag.endswith('}PlainText') or childtext2.tag.endswith('}PlainText'): - #print(childtext2.text) if childtext2.tag.endswith('}Unicode') or childtext2.tag.endswith('}Unicode'): if "type" in nn.attrib and nn.attrib['type']=='drop-capital': co_text_drop_text.append(childtext2.text) @@ -743,10 +719,10 @@ def read_xml(xml_file): co_text_signature_mark_text.append(childtext2.text) elif "type" in nn.attrib and nn.attrib['type']=='header': co_text_header_text.append(childtext2.text) - elif "type" in nn.attrib and nn.attrib['type']=='catch-word': - co_text_catch_text.append(childtext2.text) - elif "type" in nn.attrib and nn.attrib['type']=='page-number': - co_text_page_number_text.append(childtext2.text) + ###elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + ###co_text_catch_text.append(childtext2.text) + ###elif "type" in nn.attrib and nn.attrib['type']=='page-number': + ###co_text_page_number_text.append(childtext2.text) elif "type" in nn.attrib and nn.attrib['type']=='marginalia': co_text_marginalia_text.append(childtext2.text) else: @@ -774,7 +750,6 @@ def read_xml(xml_file): if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - #if nn.attrib['type']=='paragraph': c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) @@ -792,27 +767,22 @@ def read_xml(xml_file): c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - elif "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + ###elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + ###c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - elif "type" in nn.attrib and nn.attrib['type']=='page-number': + ###elif "type" in nn.attrib and nn.attrib['type']=='page-number': - c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) + ###c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) elif "type" in nn.attrib and nn.attrib['type']=='marginalia': id_marginalia.append(nn.attrib['id']) c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) else: - #print(nn.attrib['id']) - id_paragraph.append(nn.attrib['id']) c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) break else: @@ -821,7 +791,6 @@ def read_xml(xml_file): if vv.tag==link+'Point': if "type" in nn.attrib and nn.attrib['type']=='drop-capital': - #if nn.attrib['type']=='paragraph': c_t_in_drop.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -835,7 +804,6 @@ def read_xml(xml_file): elif "type" in nn.attrib and nn.attrib['type']=='signature-mark': c_t_in_signature_mark.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='header': id_header.append(nn.attrib['id']) @@ -843,33 +811,26 @@ def read_xml(xml_file): sumi+=1 - elif "type" in nn.attrib and nn.attrib['type']=='catch-word': - c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - sumi+=1 + ###elif "type" in nn.attrib and nn.attrib['type']=='catch-word': + ###c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + ###sumi+=1 + ###elif "type" in nn.attrib and nn.attrib['type']=='page-number': - elif "type" in nn.attrib and nn.attrib['type']=='page-number': - - c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) - sumi+=1 + ###c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + ###sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='marginalia': id_marginalia.append(nn.attrib['id']) c_t_in_marginalia.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) sumi+=1 else: id_paragraph.append(nn.attrib['id']) c_t_in_paragraph.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) sumi+=1 - #c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - - #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: break @@ -895,7 +856,6 @@ def read_xml(xml_file): elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] c_t_in_text_annotation=[] @@ -907,40 +867,31 @@ def read_xml(xml_file): coords=bool(vv.attrib) if coords: p_h=vv.attrib['points'].split(' ') - #c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - #if nn.attrib['type']=='paragraph': - c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - + elif "type" in nn.attrib and nn.attrib['type']=='decoration': - c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - #print(c_t_in_paragraph) + else: c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) - break else: pass if vv.tag==link+'Point': - if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation': - #if nn.attrib['type']=='paragraph': - c_t_in_text_annotation.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='decoration': - c_t_in_decoration.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) - #print(c_t_in_paragraph) sumi+=1 + else: c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -955,7 +906,6 @@ def read_xml(xml_file): elif tag.endswith('}ImageRegion') or tag.endswith('}imageregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] sumi=0 @@ -974,7 +924,6 @@ def read_xml(xml_file): if vv.tag==link+'Point': c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 - #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: break co_img.append(np.array(c_t_in)) @@ -982,7 +931,6 @@ def read_xml(xml_file): elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] sumi=0 @@ -1001,7 +949,6 @@ def read_xml(xml_file): if vv.tag==link+'Point': c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 - #print(vv.tag,'in') elif vv.tag!=link+'Point' and sumi>=1: break co_sep.append(np.array(c_t_in)) @@ -1009,7 +956,6 @@ def read_xml(xml_file): elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] sumi=0 @@ -1028,14 +974,13 @@ def read_xml(xml_file): if vv.tag==link+'Point': c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 - #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: break co_table.append(np.array(c_t_in)) co_table_text.append(' ') elif tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): - #print('sth') for nn in root1.iter(tag): c_t_in=[] sumi=0 @@ -1054,40 +999,22 @@ def read_xml(xml_file): if vv.tag==link+'Point': c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 - #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: break co_noise.append(np.array(c_t_in)) co_noise_text.append(' ') - img = np.zeros( (y_len,x_len,3) ) - img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(1,1,1)) img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(2,2,2)) img_poly=cv2.fillPoly(img, pts =co_text_header, color=(2,2,2)) - #img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(125,255,125)) - #img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(125,125,0)) - #img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(1,125,255)) - #img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(1,125,0)) img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(3,3,3)) - #img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(1,125,255)) - - #img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(125,0,125)) img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4)) img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5)) - #img_poly=cv2.fillPoly(img, pts =co_table, color=(1,255,255)) - #img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125)) - #img_poly=cv2.fillPoly(img, pts =co_noise, color=(255,0,255)) - - #print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg') - ###try: - ####print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg') - ###cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.jpg',img_poly ) - ###except: - ###cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg',img_poly ) - return file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\ + + return tree1, root1, file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\ tot_region_ref,x_len, y_len,index_tot_regions, img_poly @@ -1113,3 +1040,24 @@ def make_image_from_bb(width_l, height_l, bb_all): for i in range(bb_all.shape[0]): img_remade[bb_all[i,1]:bb_all[i,1]+bb_all[i,3],bb_all[i,0]:bb_all[i,0]+bb_all[i,2] ] = 1 return img_remade + +def update_list_and_return_first_with_length_bigger_than_one(index_element_to_be_updated, innner_index_pr_pos, pr_list, pos_list,list_inp): + list_inp.pop(index_element_to_be_updated) + if len(pr_list)>0: + list_inp.insert(index_element_to_be_updated, pr_list) + else: + index_element_to_be_updated = index_element_to_be_updated -1 + + list_inp.insert(index_element_to_be_updated+1, [innner_index_pr_pos]) + if len(pos_list)>0: + list_inp.insert(index_element_to_be_updated+2, pos_list) + + len_all_elements = [len(i) for i in list_inp] + list_len_bigger_1 = np.where(np.array(len_all_elements)>1) + list_len_bigger_1 = list_len_bigger_1[0] + + if len(list_len_bigger_1)>0: + early_list_bigger_than_one = list_len_bigger_1[0] + else: + early_list_bigger_than_one = -20 + return list_inp, early_list_bigger_than_one diff --git a/inference.py b/inference.py index 94e318d..73b4ed8 100644 --- a/inference.py +++ b/inference.py @@ -11,13 +11,11 @@ from tensorflow.keras import layers import tensorflow.keras.losses from tensorflow.keras.layers import * from models import * +from gt_gen_utils import * import click import json from tensorflow.python.keras import backend as tensorflow_backend - - - - +import xml.etree.ElementTree as ET with warnings.catch_warnings(): @@ -29,7 +27,7 @@ Tool to load model and predict for given image. """ class sbb_predict: - def __init__(self,image, model, task, config_params_model, patches, save, ground_truth): + def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file): self.image=image self.patches=patches self.save=save @@ -37,6 +35,7 @@ class sbb_predict: self.ground_truth=ground_truth self.task=task self.config_params_model=config_params_model + self.xml_file = xml_file def resize_image(self,img_in,input_height,input_width): return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST) @@ -166,7 +165,7 @@ class sbb_predict: ##if self.weights_dir!=None: ##self.model.load_weights(self.weights_dir) - if self.task != 'classification': + if (self.task != 'classification' and self.task != 'reading_order'): self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1] self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] @@ -233,6 +232,178 @@ class sbb_predict: index_class = np.argmax(label_p_pred[0]) print("Predicted Class: {}".format(classes_names[str(int(index_class))])) + elif self.task == 'reading_order': + img_height = self.config_params_model['input_height'] + img_width = self.config_params_model['input_width'] + + tree_xml, root_xml, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file) + _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header) + + img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8') + + for j in range(len(cy_main)): + img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1 + + co_text_all = co_text_paragraph + co_text_header + id_all_text = id_paragraph + id_header + + ##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ] + ##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] + texts_corr_order_index_int = list(np.array(range(len(co_text_all)))) + + min_area = 0 + max_area = 1 + + co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area) + + labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8') + for i in range(len(co_text_all)): + img_label = np.zeros((y_len,x_len,3),dtype='uint8') + img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1)) + labels_con[:,:,i] = img_label[:,:,0] + + img3= np.copy(img_poly) + labels_con = resize_image(labels_con, img_height, img_width) + + img_header_and_sep = resize_image(img_header_and_sep, img_height, img_width) + + img3= resize_image (img3, img_height, img_width) + img3 = img3.astype(np.uint16) + + inference_bs = 1#4 + + input_1= np.zeros( (inference_bs, img_height, img_width,3)) + + + starting_list_of_regions = [] + starting_list_of_regions.append( list(range(labels_con.shape[2])) ) + + index_update = 0 + index_selected = starting_list_of_regions[0] + + scalibility_num = 0 + while index_update>=0: + ij_list = starting_list_of_regions[index_update] + i = ij_list[0] + ij_list.pop(0) + + + pr_list = [] + post_list = [] + + batch_counter = 0 + tot_counter = 1 + + tot_iteration = len(ij_list) + full_bs_ite= tot_iteration//inference_bs + last_bs = tot_iteration % inference_bs + + jbatch_indexer =[] + for j in ij_list: + img1= np.repeat(labels_con[:,:,i][:, :, np.newaxis], 3, axis=2) + img2 = np.repeat(labels_con[:,:,j][:, :, np.newaxis], 3, axis=2) + + + img2[:,:,0][img3[:,:,0]==5] = 2 + img2[:,:,0][img_header_and_sep[:,:]==1] = 3 + + + + img1[:,:,0][img3[:,:,0]==5] = 2 + img1[:,:,0][img_header_and_sep[:,:]==1] = 3 + + #input_1= np.zeros( (height1, width1,3)) + + + jbatch_indexer.append(j) + + input_1[batch_counter,:,:,0] = img1[:,:,0]/3. + input_1[batch_counter,:,:,2] = img2[:,:,0]/3. + input_1[batch_counter,:,:,1] = img3[:,:,0]/5. + #input_1[batch_counter,:,:,:]= np.zeros( (batch_counter, height1, width1,3)) + batch_counter = batch_counter+1 + + #input_1[:,:,0] = img1[:,:,0]/3. + #input_1[:,:,2] = img2[:,:,0]/3. + #input_1[:,:,1] = img3[:,:,0]/5. + + if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs): + y_pr = self.model.predict(input_1 , verbose=0) + scalibility_num = scalibility_num+1 + + if batch_counter==inference_bs: + iteration_batches = inference_bs + else: + iteration_batches = last_bs + for jb in range(iteration_batches): + if y_pr[jb][0]>=0.5: + post_list.append(jbatch_indexer[jb]) + else: + pr_list.append(jbatch_indexer[jb]) + + batch_counter = 0 + jbatch_indexer = [] + + tot_counter = tot_counter+1 + + starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions) + + index_sort = [i[0] for i in starting_list_of_regions ] + + + alltags=[elem.tag for elem in root_xml.iter()] + + + + link=alltags[0].split('}')[0]+'}' + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] + + page_element = root_xml.find(link+'Page') + + """ + ro_subelement = ET.SubElement(page_element, 'ReadingOrder') + #print(page_element, 'page_element') + + #new_element = ET.Element('ReadingOrder') + + new_element_element = ET.Element('OrderedGroup') + new_element_element.set('id', "ro357564684568544579089") + + for index, id_text in enumerate(id_all_text): + new_element_2 = ET.Element('RegionRefIndexed') + new_element_2.set('regionRef', id_all_text[index]) + new_element_2.set('index', str(index_sort[index])) + + new_element_element.append(new_element_2) + + ro_subelement.append(new_element_element) + """ + ##ro_subelement = ET.SubElement(page_element, 'ReadingOrder') + + ro_subelement = ET.Element('ReadingOrder') + + ro_subelement2 = ET.SubElement(ro_subelement, 'OrderedGroup') + ro_subelement2.set('id', "ro357564684568544579089") + + for index, id_text in enumerate(id_all_text): + new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed') + new_element_2.set('regionRef', id_all_text[index]) + new_element_2.set('index', str(index_sort[index])) + + if link+'PrintSpace' in alltags: + page_element.insert(1, ro_subelement) + else: + page_element.insert(0, ro_subelement) + + #page_element[0].append(new_element) + #root_xml.append(new_element) + alltags=[elem.tag for elem in root_xml.iter()] + + ET.register_namespace("",name_space) + tree_xml.write('library2.xml',xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) + #tree_xml.write('library2.xml') + else: if self.patches: #def textline_contours(img,input_width,input_height,n_classes,model): @@ -356,7 +527,7 @@ class sbb_predict: def run(self): res=self.predict() - if self.task == 'classification': + if (self.task == 'classification' or self.task == 'reading_order'): pass else: img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task) @@ -397,15 +568,20 @@ class sbb_predict: "-gt", help="ground truth directory if you want to see the iou of prediction.", ) -def main(image, model, patches, save, ground_truth): +@click.option( + "--xml_file", + "-xml", + help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.", +) +def main(image, model, patches, save, ground_truth, xml_file): with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] - if task != 'classification': + if (task != 'classification' and task != 'reading_order'): if not save: print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") sys.exit(1) - x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth) + x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file) x.run() if __name__=="__main__":