inference for reading order

This commit is contained in:
vahidrezanezhad 2024-05-28 16:48:51 +02:00
parent 356da4cc53
commit 2e7c69f2ac
2 changed files with 227 additions and 103 deletions

View file

@ -11,13 +11,11 @@ from tensorflow.keras import layers
import tensorflow.keras.losses
from tensorflow.keras.layers import *
from models import *
from gt_gen_utils import *
import click
import json
from tensorflow.python.keras import backend as tensorflow_backend
import xml.etree.ElementTree as ET
with warnings.catch_warnings():
@ -29,7 +27,7 @@ Tool to load model and predict for given image.
"""
class sbb_predict:
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file):
self.image=image
self.patches=patches
self.save=save
@ -37,6 +35,7 @@ class sbb_predict:
self.ground_truth=ground_truth
self.task=task
self.config_params_model=config_params_model
self.xml_file = xml_file
def resize_image(self,img_in,input_height,input_width):
return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
@ -166,7 +165,7 @@ class sbb_predict:
##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir)
if self.task != 'classification':
if (self.task != 'classification' and self.task != 'reading_order'):
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
@ -233,6 +232,178 @@ class sbb_predict:
index_class = np.argmax(label_p_pred[0])
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
elif self.task == 'reading_order':
img_height = self.config_params_model['input_height']
img_width = self.config_params_model['input_width']
tree_xml, root_xml, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file)
_, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header)
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
for j in range(len(cy_main)):
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1
co_text_all = co_text_paragraph + co_text_header
id_all_text = id_paragraph + id_header
##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ]
##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
texts_corr_order_index_int = list(np.array(range(len(co_text_all))))
min_area = 0
max_area = 1
co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area)
labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8')
for i in range(len(co_text_all)):
img_label = np.zeros((y_len,x_len,3),dtype='uint8')
img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1))
labels_con[:,:,i] = img_label[:,:,0]
img3= np.copy(img_poly)
labels_con = resize_image(labels_con, img_height, img_width)
img_header_and_sep = resize_image(img_header_and_sep, img_height, img_width)
img3= resize_image (img3, img_height, img_width)
img3 = img3.astype(np.uint16)
inference_bs = 1#4
input_1= np.zeros( (inference_bs, img_height, img_width,3))
starting_list_of_regions = []
starting_list_of_regions.append( list(range(labels_con.shape[2])) )
index_update = 0
index_selected = starting_list_of_regions[0]
scalibility_num = 0
while index_update>=0:
ij_list = starting_list_of_regions[index_update]
i = ij_list[0]
ij_list.pop(0)
pr_list = []
post_list = []
batch_counter = 0
tot_counter = 1
tot_iteration = len(ij_list)
full_bs_ite= tot_iteration//inference_bs
last_bs = tot_iteration % inference_bs
jbatch_indexer =[]
for j in ij_list:
img1= np.repeat(labels_con[:,:,i][:, :, np.newaxis], 3, axis=2)
img2 = np.repeat(labels_con[:,:,j][:, :, np.newaxis], 3, axis=2)
img2[:,:,0][img3[:,:,0]==5] = 2
img2[:,:,0][img_header_and_sep[:,:]==1] = 3
img1[:,:,0][img3[:,:,0]==5] = 2
img1[:,:,0][img_header_and_sep[:,:]==1] = 3
#input_1= np.zeros( (height1, width1,3))
jbatch_indexer.append(j)
input_1[batch_counter,:,:,0] = img1[:,:,0]/3.
input_1[batch_counter,:,:,2] = img2[:,:,0]/3.
input_1[batch_counter,:,:,1] = img3[:,:,0]/5.
#input_1[batch_counter,:,:,:]= np.zeros( (batch_counter, height1, width1,3))
batch_counter = batch_counter+1
#input_1[:,:,0] = img1[:,:,0]/3.
#input_1[:,:,2] = img2[:,:,0]/3.
#input_1[:,:,1] = img3[:,:,0]/5.
if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
y_pr = self.model.predict(input_1 , verbose=0)
scalibility_num = scalibility_num+1
if batch_counter==inference_bs:
iteration_batches = inference_bs
else:
iteration_batches = last_bs
for jb in range(iteration_batches):
if y_pr[jb][0]>=0.5:
post_list.append(jbatch_indexer[jb])
else:
pr_list.append(jbatch_indexer[jb])
batch_counter = 0
jbatch_indexer = []
tot_counter = tot_counter+1
starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions)
index_sort = [i[0] for i in starting_list_of_regions ]
alltags=[elem.tag for elem in root_xml.iter()]
link=alltags[0].split('}')[0]+'}'
name_space = alltags[0].split('}')[0]
name_space = name_space.split('{')[1]
page_element = root_xml.find(link+'Page')
"""
ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
#print(page_element, 'page_element')
#new_element = ET.Element('ReadingOrder')
new_element_element = ET.Element('OrderedGroup')
new_element_element.set('id', "ro357564684568544579089")
for index, id_text in enumerate(id_all_text):
new_element_2 = ET.Element('RegionRefIndexed')
new_element_2.set('regionRef', id_all_text[index])
new_element_2.set('index', str(index_sort[index]))
new_element_element.append(new_element_2)
ro_subelement.append(new_element_element)
"""
##ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
ro_subelement = ET.Element('ReadingOrder')
ro_subelement2 = ET.SubElement(ro_subelement, 'OrderedGroup')
ro_subelement2.set('id', "ro357564684568544579089")
for index, id_text in enumerate(id_all_text):
new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed')
new_element_2.set('regionRef', id_all_text[index])
new_element_2.set('index', str(index_sort[index]))
if link+'PrintSpace' in alltags:
page_element.insert(1, ro_subelement)
else:
page_element.insert(0, ro_subelement)
#page_element[0].append(new_element)
#root_xml.append(new_element)
alltags=[elem.tag for elem in root_xml.iter()]
ET.register_namespace("",name_space)
tree_xml.write('library2.xml',xml_declaration=True,method='xml',encoding="utf8",default_namespace=None)
#tree_xml.write('library2.xml')
else:
if self.patches:
#def textline_contours(img,input_width,input_height,n_classes,model):
@ -356,7 +527,7 @@ class sbb_predict:
def run(self):
res=self.predict()
if self.task == 'classification':
if (self.task == 'classification' or self.task == 'reading_order'):
pass
else:
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
@ -397,15 +568,20 @@ class sbb_predict:
"-gt",
help="ground truth directory if you want to see the iou of prediction.",
)
def main(image, model, patches, save, ground_truth):
@click.option(
"--xml_file",
"-xml",
help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
)
def main(image, model, patches, save, ground_truth, xml_file):
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
task = config_params_model['task']
if task != 'classification':
if (task != 'classification' and task != 'reading_order'):
if not save:
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
sys.exit(1)
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file)
x.run()
if __name__=="__main__":