Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git, synced 2025-06-09 20:00:05 +02:00
Commit: inference for reading order
parent 356da4cc53 · commit 2e7c69f2ac
2 changed files with 227 additions and 103 deletions
inference.py (196 changed lines shown below)
@@ -11,13 +11,11 @@ from tensorflow.keras import layers
 import tensorflow.keras.losses
 from tensorflow.keras.layers import *
 from models import *
+from gt_gen_utils import *
 import click
 import json
 from tensorflow.python.keras import backend as tensorflow_backend
-
-
-
-
+import xml.etree.ElementTree as ET
 
 
 with warnings.catch_warnings():
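Note on the new imports: `gt_gen_utils` supplies the helpers the reading-order branch below relies on (`read_xml`, `find_new_features_of_contours`, `filter_contours_area_of_image`), and `xml.etree.ElementTree` is used to write the `ReadingOrder` element back into the PAGE-XML tree. For readers unfamiliar with that format, here is a hypothetical sketch — not the actual `gt_gen_utils.read_xml`, whose real return signature appears further down in this diff — of how PAGE-XML text regions map to the contours and region ids the code works with:

    # Hypothetical sketch only: illustrates the PAGE-XML structure that
    # gt_gen_utils.read_xml parses; element/attribute names follow the PAGE schema.
    import xml.etree.ElementTree as ET
    import numpy as np

    def parse_text_regions(xml_path):
        tree = ET.parse(xml_path)
        root = tree.getroot()
        ns = root.tag.split('}')[0] + '}'   # namespace prefix, e.g. '{http://...pagecontent...}'
        page = root.find(ns + 'Page')
        contours, ids = [], []
        for region in page.iter(ns + 'TextRegion'):
            points = region.find(ns + 'Coords').attrib['points']  # "x1,y1 x2,y2 ..."
            contours.append(np.array([[int(v) for v in p.split(',')]
                                      for p in points.split()]))
            ids.append(region.attrib['id'])
        return tree, root, contours, ids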
@@ -29,7 +27,7 @@ Tool to load model and predict for given image.
 """
 
 class sbb_predict:
-    def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
+    def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file):
         self.image=image
         self.patches=patches
         self.save=save
@@ -37,6 +35,7 @@ class sbb_predict:
         self.ground_truth=ground_truth
         self.task=task
         self.config_params_model=config_params_model
+        self.xml_file = xml_file
 
     def resize_image(self,img_in,input_height,input_width):
         return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
@@ -166,7 +165,7 @@ class sbb_predict:
         ##if self.weights_dir!=None:
             ##self.model.load_weights(self.weights_dir)
 
-        if self.task != 'classification':
+        if (self.task != 'classification' and self.task != 'reading_order'):
             self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
             self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
             self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
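Aside: `self.model.layers[len(self.model.layers)-1]` is simply the last layer; if this block is revisited, negative indexing says the same thing more directly (a sketch, behavior unchanged, assuming the usual (batch, height, width, classes) output shape these models have):

    last_layer = self.model.layers[-1]   # same layer as layers[len(layers)-1]
    _, self.img_height, self.img_width, self.n_classes = last_layer.output_shape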
@@ -233,6 +232,178 @@ class sbb_predict:
             index_class = np.argmax(label_p_pred[0])
 
             print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
+        elif self.task == 'reading_order':
+            img_height = self.config_params_model['input_height']
+            img_width = self.config_params_model['input_width']
+
+            tree_xml, root_xml, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file)
+            _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header)
+
+            img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
+
+            for j in range(len(cy_main)):
+                img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1
+
+            co_text_all = co_text_paragraph + co_text_header
+            id_all_text = id_paragraph + id_header
+
+            ##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ]
+            ##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
+            texts_corr_order_index_int = list(np.array(range(len(co_text_all))))
+
+            min_area = 0
+            max_area = 1
+
+            co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area)
+
+            labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8')
+            for i in range(len(co_text_all)):
+                img_label = np.zeros((y_len,x_len,3),dtype='uint8')
+                img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1))
+                labels_con[:,:,i] = img_label[:,:,0]
+
+            img3= np.copy(img_poly)
+            labels_con = resize_image(labels_con, img_height, img_width)
+
+            img_header_and_sep = resize_image(img_header_and_sep, img_height, img_width)
+
+            img3= resize_image (img3, img_height, img_width)
+            img3 = img3.astype(np.uint16)
+
+            inference_bs = 1#4
+
+            input_1= np.zeros( (inference_bs, img_height, img_width,3))
+
+
+            starting_list_of_regions = []
+            starting_list_of_regions.append( list(range(labels_con.shape[2])) )
+
+            index_update = 0
+            index_selected = starting_list_of_regions[0]
+
+            scalibility_num = 0
+            while index_update>=0:
+                ij_list = starting_list_of_regions[index_update]
+                i = ij_list[0]
+                ij_list.pop(0)
+
+
+                pr_list = []
+                post_list = []
+
+                batch_counter = 0
+                tot_counter = 1
+
+                tot_iteration = len(ij_list)
+                full_bs_ite= tot_iteration//inference_bs
+                last_bs = tot_iteration % inference_bs
+
+                jbatch_indexer =[]
+                for j in ij_list:
+                    img1= np.repeat(labels_con[:,:,i][:, :, np.newaxis], 3, axis=2)
+                    img2 = np.repeat(labels_con[:,:,j][:, :, np.newaxis], 3, axis=2)
+
+
+                    img2[:,:,0][img3[:,:,0]==5] = 2
+                    img2[:,:,0][img_header_and_sep[:,:]==1] = 3
+
+
+
+                    img1[:,:,0][img3[:,:,0]==5] = 2
+                    img1[:,:,0][img_header_and_sep[:,:]==1] = 3
+
+                    #input_1= np.zeros( (height1, width1,3))
+
+
+                    jbatch_indexer.append(j)
+
+                    input_1[batch_counter,:,:,0] = img1[:,:,0]/3.
+                    input_1[batch_counter,:,:,2] = img2[:,:,0]/3.
+                    input_1[batch_counter,:,:,1] = img3[:,:,0]/5.
+                    #input_1[batch_counter,:,:,:]= np.zeros( (batch_counter, height1, width1,3))
+                    batch_counter = batch_counter+1
+
+                    #input_1[:,:,0] = img1[:,:,0]/3.
+                    #input_1[:,:,2] = img2[:,:,0]/3.
+                    #input_1[:,:,1] = img3[:,:,0]/5.
+
+                    if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
+                        y_pr = self.model.predict(input_1 , verbose=0)
+                        scalibility_num = scalibility_num+1
+
+                        if batch_counter==inference_bs:
+                            iteration_batches = inference_bs
+                        else:
+                            iteration_batches = last_bs
+                        for jb in range(iteration_batches):
+                            if y_pr[jb][0]>=0.5:
+                                post_list.append(jbatch_indexer[jb])
+                            else:
+                                pr_list.append(jbatch_indexer[jb])
+
+                        batch_counter = 0
+                        jbatch_indexer = []
+
+                    tot_counter = tot_counter+1
+
+                starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions)
+
+            index_sort = [i[0] for i in starting_list_of_regions ]
+
+
+            alltags=[elem.tag for elem in root_xml.iter()]
+
+
+
+            link=alltags[0].split('}')[0]+'}'
+            name_space = alltags[0].split('}')[0]
+            name_space = name_space.split('{')[1]
+
+            page_element = root_xml.find(link+'Page')
+
+            """
+            ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
+            #print(page_element, 'page_element')
+
+            #new_element = ET.Element('ReadingOrder')
+
+            new_element_element = ET.Element('OrderedGroup')
+            new_element_element.set('id', "ro357564684568544579089")
+
+            for index, id_text in enumerate(id_all_text):
+                new_element_2 = ET.Element('RegionRefIndexed')
+                new_element_2.set('regionRef', id_all_text[index])
+                new_element_2.set('index', str(index_sort[index]))
+
+                new_element_element.append(new_element_2)
+
+            ro_subelement.append(new_element_element)
+            """
+            ##ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
+
+            ro_subelement = ET.Element('ReadingOrder')
+
+            ro_subelement2 = ET.SubElement(ro_subelement, 'OrderedGroup')
+            ro_subelement2.set('id', "ro357564684568544579089")
+
+            for index, id_text in enumerate(id_all_text):
+                new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed')
+                new_element_2.set('regionRef', id_all_text[index])
+                new_element_2.set('index', str(index_sort[index]))
+
+            if link+'PrintSpace' in alltags:
+                page_element.insert(1, ro_subelement)
+            else:
+                page_element.insert(0, ro_subelement)
+
+            #page_element[0].append(new_element)
+            #root_xml.append(new_element)
+            alltags=[elem.tag for elem in root_xml.iter()]
+
+            ET.register_namespace("",name_space)
+            tree_xml.write('library2.xml',xml_declaration=True,method='xml',encoding="utf8",default_namespace=None)
+            #tree_xml.write('library2.xml')
+
         else:
             if self.patches:
                 #def textline_contours(img,input_width,input_height,n_classes,model):
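What the added loop does: the regions are ordered by repeated pairwise queries to the model. A pivot region `i` is popped from the current candidate list, every remaining region `j` is paired with it, and the network's scalar output decides the relation: `y_pr[jb][0] >= 0.5` is read as "j comes after i" (into `post_list`), otherwise "j comes before i" (into `pr_list`); `update_list_and_return_first_with_length_bigger_than_one` then replaces the candidate list with the two halves and iterates until every sublist is a single region — quicksort with the model as comparator. Stripped of the batching bookkeeping, a single query looks like this (a sketch over the arrays built above, keeping the same channel layout and the /3., /5. normalizations):

    # One pairwise reading-order query, batch size 1 (mirrors the batched loop).
    def reads_after(model, labels_con, img3, img_header_and_sep, i, j, img_height, img_width):
        img1 = np.repeat(labels_con[:, :, i][:, :, np.newaxis], 3, axis=2)  # pivot mask
        img2 = np.repeat(labels_con[:, :, j][:, :, np.newaxis], 3, axis=2)  # candidate mask
        for im in (img1, img2):
            im[:, :, 0][img3[:, :, 0] == 5] = 2             # separators from the layout image
            im[:, :, 0][img_header_and_sep[:, :] == 1] = 3  # header/separator strips
        inp = np.zeros((1, img_height, img_width, 3))
        inp[0, :, :, 0] = img1[:, :, 0] / 3.   # channel 0: pivot region i
        inp[0, :, :, 1] = img3[:, :, 0] / 5.   # channel 1: full layout image
        inp[0, :, :, 2] = img2[:, :, 0] / 3.   # channel 2: candidate region j
        return model.predict(inp, verbose=0)[0][0] >= 0.5   # True: j reads after i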
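The live (non-commented) branch then assembles a PAGE-XML `ReadingOrder` element and inserts it into `Page`, after `PrintSpace` when the file has one. For a three-region page the inserted fragment would look roughly like this (the region ids here are placeholders; the indices come from `index_sort`):

    <ReadingOrder>
        <OrderedGroup id="ro357564684568544579089">
            <RegionRefIndexed regionRef="r1" index="0"/>
            <RegionRefIndexed regionRef="r2" index="2"/>
            <RegionRefIndexed regionRef="r3" index="1"/>
        </OrderedGroup>
    </ReadingOrder>

As committed, the tree is written to a fixed 'library2.xml' in the working directory rather than back into self.xml_file, which is worth noting against the --xml_file help text below.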
@@ -356,7 +527,7 @@ class sbb_predict:
 
     def run(self):
         res=self.predict()
-        if self.task == 'classification':
+        if (self.task == 'classification' or self.task == 'reading_order'):
             pass
         else:
             img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
@@ -397,15 +568,20 @@ class sbb_predict:
     "-gt",
     help="ground truth directory if you want to see the iou of prediction.",
 )
-def main(image, model, patches, save, ground_truth):
+@click.option(
+    "--xml_file",
+    "-xml",
+    help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
+)
+def main(image, model, patches, save, ground_truth, xml_file):
     with open(os.path.join(model,'config.json')) as f:
         config_params_model = json.load(f)
     task = config_params_model['task']
-    if task != 'classification':
+    if (task != 'classification' and task != 'reading_order'):
         if not save:
             print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
             sys.exit(1)
-    x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
+    x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file)
     x.run()
 
 if __name__=="__main__":
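With the new option wired in, a reading-order run points the script at a model directory whose config.json declares task 'reading_order' and at the PAGE-XML file — hypothetically something like the following, where only -xml is the flag added in this commit; the model and image option spellings live in the unchanged part of the file and are assumed here:

    python inference.py -m ./model_reading_order -i ./page_image.png -xml ./page.xml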