inference for reading order

pull/18/head
vahidrezanezhad 7 months ago
parent 356da4cc53
commit 2e7c69f2ac

@ -38,11 +38,8 @@ def filter_contours_area_of_image_tables(image, contours, hierarchy, max_area, m
polygon = geometry.Polygon([point[0] for point in c])
# area = cv2.contourArea(c)
area = polygon.area
##print(np.prod(thresh.shape[:2]))
# Check that polygon has area greater than minimal area
# print(hierarchy[0][jv][3],hierarchy )
if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 :
# print(c[0][0][1])
found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.int32))
jv += 1
return found_polygons_early
@ -52,15 +49,12 @@ def filter_contours_area_of_image(image, contours, order_index, max_area, min_ar
order_index_filtered = list()
#jv = 0
for jv, c in enumerate(contours):
#print(len(c[0]))
c = c[0]
if len(c) < 3: # A polygon cannot have less than 3 points
continue
c_e = [point for point in c]
#print(c_e)
polygon = geometry.Polygon(c_e)
area = polygon.area
#print(area,'area')
if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 :
found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.uint))
order_index_filtered.append(order_index[jv])
@ -88,12 +82,8 @@ def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002):
def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len):
co_text_eroded = []
for con in co_text:
#try:
img_boundary_in = np.zeros( (y_len,x_len) )
img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1))
#print('bidiahhhhaaa')
#img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica
if erosion_rate > 0:
@ -627,8 +617,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
def find_new_features_of_contours(contours_main):
#print(contours_main[0][0][:, 0])
areas_main = np.array([cv2.contourArea(contours_main[j]) for j in range(len(contours_main))])
M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))]
cx_main = [(M_main[j]["m10"] / (M_main[j]["m00"] + 1e-32)) for j in range(len(M_main))]
@ -658,8 +646,6 @@ def find_new_features_of_contours(contours_main):
y_min_main = np.array([np.min(contours_main[j][:, 1]) for j in range(len(contours_main))])
y_max_main = np.array([np.max(contours_main[j][:, 1]) for j in range(len(contours_main))])
# dis_x=np.abs(x_max_main-x_min_main)
return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin
def read_xml(xml_file):
file_name = Path(xml_file).stem
@ -675,13 +661,11 @@ def read_xml(xml_file):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
for jj in root1.iter(link+'RegionRefIndexed'):
index_tot_regions.append(jj.attrib['index'])
tot_region_ref.append(jj.attrib['regionRef'])
region_tags=np.unique([x for x in alltags if x.endswith('Region')])
#print(region_tags)
co_text_paragraph=[]
co_text_drop=[]
co_text_heading=[]
@ -698,7 +682,6 @@ def read_xml(xml_file):
co_graphic_decoration=[]
co_noise=[]
co_text_paragraph_text=[]
co_text_drop_text=[]
co_text_heading_text=[]
@ -715,7 +698,6 @@ def read_xml(xml_file):
co_graphic_decoration_text=[]
co_noise_text=[]
id_paragraph = []
id_header = []
id_heading = []
@ -726,14 +708,8 @@ def read_xml(xml_file):
for nn in root1.iter(tag):
for child2 in nn:
tag2 = child2.tag
#print(child2.tag)
if tag2.endswith('}TextEquiv') or tag2.endswith('}TextEquiv'):
#children2 = childtext.getchildren()
#rank = child2.find('Unicode').text
for childtext2 in child2:
#rank = childtext2.find('Unicode').text
#if childtext2.tag.endswith('}PlainText') or childtext2.tag.endswith('}PlainText'):
#print(childtext2.text)
if childtext2.tag.endswith('}Unicode') or childtext2.tag.endswith('}Unicode'):
if "type" in nn.attrib and nn.attrib['type']=='drop-capital':
co_text_drop_text.append(childtext2.text)
@ -743,10 +719,10 @@ def read_xml(xml_file):
co_text_signature_mark_text.append(childtext2.text)
elif "type" in nn.attrib and nn.attrib['type']=='header':
co_text_header_text.append(childtext2.text)
elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
co_text_catch_text.append(childtext2.text)
elif "type" in nn.attrib and nn.attrib['type']=='page-number':
co_text_page_number_text.append(childtext2.text)
###elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
###co_text_catch_text.append(childtext2.text)
###elif "type" in nn.attrib and nn.attrib['type']=='page-number':
###co_text_page_number_text.append(childtext2.text)
elif "type" in nn.attrib and nn.attrib['type']=='marginalia':
co_text_marginalia_text.append(childtext2.text)
else:
@ -774,7 +750,6 @@ def read_xml(xml_file):
if "type" in nn.attrib and nn.attrib['type']=='drop-capital':
#if nn.attrib['type']=='paragraph':
c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
@ -792,27 +767,22 @@ def read_xml(xml_file):
c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
###elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
###c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
elif "type" in nn.attrib and nn.attrib['type']=='page-number':
###elif "type" in nn.attrib and nn.attrib['type']=='page-number':
c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
#print(c_t_in_paragraph)
###c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
elif "type" in nn.attrib and nn.attrib['type']=='marginalia':
id_marginalia.append(nn.attrib['id'])
c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
#print(c_t_in_paragraph)
else:
#print(nn.attrib['id'])
id_paragraph.append(nn.attrib['id'])
c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
#print(c_t_in_paragraph)
break
else:
@ -821,7 +791,6 @@ def read_xml(xml_file):
if vv.tag==link+'Point':
if "type" in nn.attrib and nn.attrib['type']=='drop-capital':
#if nn.attrib['type']=='paragraph':
c_t_in_drop.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
@ -835,7 +804,6 @@ def read_xml(xml_file):
elif "type" in nn.attrib and nn.attrib['type']=='signature-mark':
c_t_in_signature_mark.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
#print(c_t_in_paragraph)
sumi+=1
elif "type" in nn.attrib and nn.attrib['type']=='header':
id_header.append(nn.attrib['id'])
@ -843,33 +811,26 @@ def read_xml(xml_file):
sumi+=1
elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
###elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
###c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
###sumi+=1
###elif "type" in nn.attrib and nn.attrib['type']=='page-number':
elif "type" in nn.attrib and nn.attrib['type']=='page-number':
c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
#print(c_t_in_paragraph)
sumi+=1
###c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
###sumi+=1
elif "type" in nn.attrib and nn.attrib['type']=='marginalia':
id_marginalia.append(nn.attrib['id'])
c_t_in_marginalia.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
#print(c_t_in_paragraph)
sumi+=1
else:
id_paragraph.append(nn.attrib['id'])
c_t_in_paragraph.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
#print(c_t_in_paragraph)
sumi+=1
#c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
@ -895,7 +856,6 @@ def read_xml(xml_file):
elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
c_t_in_text_annotation=[]
@ -907,40 +867,31 @@ def read_xml(xml_file):
coords=bool(vv.attrib)
if coords:
p_h=vv.attrib['points'].split(' ')
#c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation':
#if nn.attrib['type']=='paragraph':
c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
elif "type" in nn.attrib and nn.attrib['type']=='decoration':
c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
#print(c_t_in_paragraph)
else:
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
break
else:
pass
if vv.tag==link+'Point':
if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation':
#if nn.attrib['type']=='paragraph':
c_t_in_text_annotation.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
elif "type" in nn.attrib and nn.attrib['type']=='decoration':
c_t_in_decoration.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
#print(c_t_in_paragraph)
sumi+=1
else:
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
@ -955,7 +906,6 @@ def read_xml(xml_file):
elif tag.endswith('}ImageRegion') or tag.endswith('}imageregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
@ -974,7 +924,6 @@ def read_xml(xml_file):
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_img.append(np.array(c_t_in))
@ -982,7 +931,6 @@ def read_xml(xml_file):
elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
@ -1001,7 +949,6 @@ def read_xml(xml_file):
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_sep.append(np.array(c_t_in))
@ -1009,7 +956,6 @@ def read_xml(xml_file):
elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
@ -1028,14 +974,13 @@ def read_xml(xml_file):
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_table.append(np.array(c_t_in))
co_table_text.append(' ')
elif tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
#print('sth')
for nn in root1.iter(tag):
c_t_in=[]
sumi=0
@ -1054,40 +999,22 @@ def read_xml(xml_file):
if vv.tag==link+'Point':
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
sumi+=1
#print(vv.tag,'in')
elif vv.tag!=link+'Point' and sumi>=1:
break
co_noise.append(np.array(c_t_in))
co_noise_text.append(' ')
img = np.zeros( (y_len,x_len,3) )
img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(1,1,1))
img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(2,2,2))
img_poly=cv2.fillPoly(img, pts =co_text_header, color=(2,2,2))
#img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(125,255,125))
#img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(125,125,0))
#img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(1,125,255))
#img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(1,125,0))
img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(3,3,3))
#img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(1,125,255))
#img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(125,0,125))
img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4))
img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5))
#img_poly=cv2.fillPoly(img, pts =co_table, color=(1,255,255))
#img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125))
#img_poly=cv2.fillPoly(img, pts =co_noise, color=(255,0,255))
#print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg')
###try:
####print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg')
###cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.jpg',img_poly )
###except:
###cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg',img_poly )
return file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\
return tree1, root1, file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\
tot_region_ref,x_len, y_len,index_tot_regions, img_poly
@ -1113,3 +1040,24 @@ def make_image_from_bb(width_l, height_l, bb_all):
for i in range(bb_all.shape[0]):
img_remade[bb_all[i,1]:bb_all[i,1]+bb_all[i,3],bb_all[i,0]:bb_all[i,0]+bb_all[i,2] ] = 1
return img_remade
def update_list_and_return_first_with_length_bigger_than_one(index_element_to_be_updated, innner_index_pr_pos, pr_list, pos_list,list_inp):
list_inp.pop(index_element_to_be_updated)
if len(pr_list)>0:
list_inp.insert(index_element_to_be_updated, pr_list)
else:
index_element_to_be_updated = index_element_to_be_updated -1
list_inp.insert(index_element_to_be_updated+1, [innner_index_pr_pos])
if len(pos_list)>0:
list_inp.insert(index_element_to_be_updated+2, pos_list)
len_all_elements = [len(i) for i in list_inp]
list_len_bigger_1 = np.where(np.array(len_all_elements)>1)
list_len_bigger_1 = list_len_bigger_1[0]
if len(list_len_bigger_1)>0:
early_list_bigger_than_one = list_len_bigger_1[0]
else:
early_list_bigger_than_one = -20
return list_inp, early_list_bigger_than_one

@ -11,13 +11,11 @@ from tensorflow.keras import layers
import tensorflow.keras.losses
from tensorflow.keras.layers import *
from models import *
from gt_gen_utils import *
import click
import json
from tensorflow.python.keras import backend as tensorflow_backend
import xml.etree.ElementTree as ET
with warnings.catch_warnings():
@ -29,7 +27,7 @@ Tool to load model and predict for given image.
"""
class sbb_predict:
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file):
self.image=image
self.patches=patches
self.save=save
@ -37,6 +35,7 @@ class sbb_predict:
self.ground_truth=ground_truth
self.task=task
self.config_params_model=config_params_model
self.xml_file = xml_file
def resize_image(self,img_in,input_height,input_width):
return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
@ -166,7 +165,7 @@ class sbb_predict:
##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir)
if self.task != 'classification':
if (self.task != 'classification' and self.task != 'reading_order'):
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
@ -233,6 +232,178 @@ class sbb_predict:
index_class = np.argmax(label_p_pred[0])
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
elif self.task == 'reading_order':
img_height = self.config_params_model['input_height']
img_width = self.config_params_model['input_width']
tree_xml, root_xml, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file)
_, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header)
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
for j in range(len(cy_main)):
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1
co_text_all = co_text_paragraph + co_text_header
id_all_text = id_paragraph + id_header
##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ]
##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
texts_corr_order_index_int = list(np.array(range(len(co_text_all))))
min_area = 0
max_area = 1
co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area)
labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8')
for i in range(len(co_text_all)):
img_label = np.zeros((y_len,x_len,3),dtype='uint8')
img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1))
labels_con[:,:,i] = img_label[:,:,0]
img3= np.copy(img_poly)
labels_con = resize_image(labels_con, img_height, img_width)
img_header_and_sep = resize_image(img_header_and_sep, img_height, img_width)
img3= resize_image (img3, img_height, img_width)
img3 = img3.astype(np.uint16)
inference_bs = 1#4
input_1= np.zeros( (inference_bs, img_height, img_width,3))
starting_list_of_regions = []
starting_list_of_regions.append( list(range(labels_con.shape[2])) )
index_update = 0
index_selected = starting_list_of_regions[0]
scalibility_num = 0
while index_update>=0:
ij_list = starting_list_of_regions[index_update]
i = ij_list[0]
ij_list.pop(0)
pr_list = []
post_list = []
batch_counter = 0
tot_counter = 1
tot_iteration = len(ij_list)
full_bs_ite= tot_iteration//inference_bs
last_bs = tot_iteration % inference_bs
jbatch_indexer =[]
for j in ij_list:
img1= np.repeat(labels_con[:,:,i][:, :, np.newaxis], 3, axis=2)
img2 = np.repeat(labels_con[:,:,j][:, :, np.newaxis], 3, axis=2)
img2[:,:,0][img3[:,:,0]==5] = 2
img2[:,:,0][img_header_and_sep[:,:]==1] = 3
img1[:,:,0][img3[:,:,0]==5] = 2
img1[:,:,0][img_header_and_sep[:,:]==1] = 3
#input_1= np.zeros( (height1, width1,3))
jbatch_indexer.append(j)
input_1[batch_counter,:,:,0] = img1[:,:,0]/3.
input_1[batch_counter,:,:,2] = img2[:,:,0]/3.
input_1[batch_counter,:,:,1] = img3[:,:,0]/5.
#input_1[batch_counter,:,:,:]= np.zeros( (batch_counter, height1, width1,3))
batch_counter = batch_counter+1
#input_1[:,:,0] = img1[:,:,0]/3.
#input_1[:,:,2] = img2[:,:,0]/3.
#input_1[:,:,1] = img3[:,:,0]/5.
if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
y_pr = self.model.predict(input_1 , verbose=0)
scalibility_num = scalibility_num+1
if batch_counter==inference_bs:
iteration_batches = inference_bs
else:
iteration_batches = last_bs
for jb in range(iteration_batches):
if y_pr[jb][0]>=0.5:
post_list.append(jbatch_indexer[jb])
else:
pr_list.append(jbatch_indexer[jb])
batch_counter = 0
jbatch_indexer = []
tot_counter = tot_counter+1
starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions)
index_sort = [i[0] for i in starting_list_of_regions ]
alltags=[elem.tag for elem in root_xml.iter()]
link=alltags[0].split('}')[0]+'}'
name_space = alltags[0].split('}')[0]
name_space = name_space.split('{')[1]
page_element = root_xml.find(link+'Page')
"""
ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
#print(page_element, 'page_element')
#new_element = ET.Element('ReadingOrder')
new_element_element = ET.Element('OrderedGroup')
new_element_element.set('id', "ro357564684568544579089")
for index, id_text in enumerate(id_all_text):
new_element_2 = ET.Element('RegionRefIndexed')
new_element_2.set('regionRef', id_all_text[index])
new_element_2.set('index', str(index_sort[index]))
new_element_element.append(new_element_2)
ro_subelement.append(new_element_element)
"""
##ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
ro_subelement = ET.Element('ReadingOrder')
ro_subelement2 = ET.SubElement(ro_subelement, 'OrderedGroup')
ro_subelement2.set('id', "ro357564684568544579089")
for index, id_text in enumerate(id_all_text):
new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed')
new_element_2.set('regionRef', id_all_text[index])
new_element_2.set('index', str(index_sort[index]))
if link+'PrintSpace' in alltags:
page_element.insert(1, ro_subelement)
else:
page_element.insert(0, ro_subelement)
#page_element[0].append(new_element)
#root_xml.append(new_element)
alltags=[elem.tag for elem in root_xml.iter()]
ET.register_namespace("",name_space)
tree_xml.write('library2.xml',xml_declaration=True,method='xml',encoding="utf8",default_namespace=None)
#tree_xml.write('library2.xml')
else:
if self.patches:
#def textline_contours(img,input_width,input_height,n_classes,model):
@ -356,7 +527,7 @@ class sbb_predict:
def run(self):
res=self.predict()
if self.task == 'classification':
if (self.task == 'classification' or self.task == 'reading_order'):
pass
else:
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
@ -397,15 +568,20 @@ class sbb_predict:
"-gt",
help="ground truth directory if you want to see the iou of prediction.",
)
def main(image, model, patches, save, ground_truth):
@click.option(
"--xml_file",
"-xml",
help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
)
def main(image, model, patches, save, ground_truth, xml_file):
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
task = config_params_model['task']
if task != 'classification':
if (task != 'classification' and task != 'reading_order'):
if not save:
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
sys.exit(1)
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file)
x.run()
if __name__=="__main__":

Loading…
Cancel
Save