mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
inference for reading order
This commit is contained in:
parent
356da4cc53
commit
2e7c69f2ac
2 changed files with 227 additions and 103 deletions
134
gt_gen_utils.py
134
gt_gen_utils.py
|
@ -38,11 +38,8 @@ def filter_contours_area_of_image_tables(image, contours, hierarchy, max_area, m
|
|||
polygon = geometry.Polygon([point[0] for point in c])
|
||||
# area = cv2.contourArea(c)
|
||||
area = polygon.area
|
||||
##print(np.prod(thresh.shape[:2]))
|
||||
# Check that polygon has area greater than minimal area
|
||||
# print(hierarchy[0][jv][3],hierarchy )
|
||||
if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 :
|
||||
# print(c[0][0][1])
|
||||
found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.int32))
|
||||
jv += 1
|
||||
return found_polygons_early
|
||||
|
@ -52,15 +49,12 @@ def filter_contours_area_of_image(image, contours, order_index, max_area, min_ar
|
|||
order_index_filtered = list()
|
||||
#jv = 0
|
||||
for jv, c in enumerate(contours):
|
||||
#print(len(c[0]))
|
||||
c = c[0]
|
||||
if len(c) < 3: # A polygon cannot have less than 3 points
|
||||
continue
|
||||
c_e = [point for point in c]
|
||||
#print(c_e)
|
||||
polygon = geometry.Polygon(c_e)
|
||||
area = polygon.area
|
||||
#print(area,'area')
|
||||
if area >= min_area * np.prod(image.shape[:2]) and area <= max_area * np.prod(image.shape[:2]): # and hierarchy[0][jv][3]==-1 :
|
||||
found_polygons_early.append(np.array([[point] for point in polygon.exterior.coords], dtype=np.uint))
|
||||
order_index_filtered.append(order_index[jv])
|
||||
|
@ -88,12 +82,8 @@ def return_contours_of_interested_region(region_pre_p, pixel, min_area=0.0002):
|
|||
def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y_len, x_len):
|
||||
co_text_eroded = []
|
||||
for con in co_text:
|
||||
#try:
|
||||
img_boundary_in = np.zeros( (y_len,x_len) )
|
||||
img_boundary_in = cv2.fillPoly(img_boundary_in, pts=[con], color=(1, 1, 1))
|
||||
#print('bidiahhhhaaa')
|
||||
|
||||
|
||||
|
||||
#img_boundary_in = cv2.erode(img_boundary_in[:,:], KERNEL, iterations=7)#asiatica
|
||||
if erosion_rate > 0:
|
||||
|
@ -626,8 +616,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
|||
|
||||
|
||||
def find_new_features_of_contours(contours_main):
|
||||
|
||||
#print(contours_main[0][0][:, 0])
|
||||
|
||||
areas_main = np.array([cv2.contourArea(contours_main[j]) for j in range(len(contours_main))])
|
||||
M_main = [cv2.moments(contours_main[j]) for j in range(len(contours_main))]
|
||||
|
@ -658,8 +646,6 @@ def find_new_features_of_contours(contours_main):
|
|||
y_min_main = np.array([np.min(contours_main[j][:, 1]) for j in range(len(contours_main))])
|
||||
y_max_main = np.array([np.max(contours_main[j][:, 1]) for j in range(len(contours_main))])
|
||||
|
||||
# dis_x=np.abs(x_max_main-x_min_main)
|
||||
|
||||
return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin
|
||||
def read_xml(xml_file):
|
||||
file_name = Path(xml_file).stem
|
||||
|
@ -675,13 +661,11 @@ def read_xml(xml_file):
|
|||
y_len=int(jj.attrib['imageHeight'])
|
||||
x_len=int(jj.attrib['imageWidth'])
|
||||
|
||||
|
||||
for jj in root1.iter(link+'RegionRefIndexed'):
|
||||
index_tot_regions.append(jj.attrib['index'])
|
||||
tot_region_ref.append(jj.attrib['regionRef'])
|
||||
|
||||
region_tags=np.unique([x for x in alltags if x.endswith('Region')])
|
||||
#print(region_tags)
|
||||
co_text_paragraph=[]
|
||||
co_text_drop=[]
|
||||
co_text_heading=[]
|
||||
|
@ -698,7 +682,6 @@ def read_xml(xml_file):
|
|||
co_graphic_decoration=[]
|
||||
co_noise=[]
|
||||
|
||||
|
||||
co_text_paragraph_text=[]
|
||||
co_text_drop_text=[]
|
||||
co_text_heading_text=[]
|
||||
|
@ -715,7 +698,6 @@ def read_xml(xml_file):
|
|||
co_graphic_decoration_text=[]
|
||||
co_noise_text=[]
|
||||
|
||||
|
||||
id_paragraph = []
|
||||
id_header = []
|
||||
id_heading = []
|
||||
|
@ -726,14 +708,8 @@ def read_xml(xml_file):
|
|||
for nn in root1.iter(tag):
|
||||
for child2 in nn:
|
||||
tag2 = child2.tag
|
||||
#print(child2.tag)
|
||||
if tag2.endswith('}TextEquiv') or tag2.endswith('}TextEquiv'):
|
||||
#children2 = childtext.getchildren()
|
||||
#rank = child2.find('Unicode').text
|
||||
for childtext2 in child2:
|
||||
#rank = childtext2.find('Unicode').text
|
||||
#if childtext2.tag.endswith('}PlainText') or childtext2.tag.endswith('}PlainText'):
|
||||
#print(childtext2.text)
|
||||
if childtext2.tag.endswith('}Unicode') or childtext2.tag.endswith('}Unicode'):
|
||||
if "type" in nn.attrib and nn.attrib['type']=='drop-capital':
|
||||
co_text_drop_text.append(childtext2.text)
|
||||
|
@ -743,10 +719,10 @@ def read_xml(xml_file):
|
|||
co_text_signature_mark_text.append(childtext2.text)
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='header':
|
||||
co_text_header_text.append(childtext2.text)
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
|
||||
co_text_catch_text.append(childtext2.text)
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='page-number':
|
||||
co_text_page_number_text.append(childtext2.text)
|
||||
###elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
|
||||
###co_text_catch_text.append(childtext2.text)
|
||||
###elif "type" in nn.attrib and nn.attrib['type']=='page-number':
|
||||
###co_text_page_number_text.append(childtext2.text)
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='marginalia':
|
||||
co_text_marginalia_text.append(childtext2.text)
|
||||
else:
|
||||
|
@ -774,7 +750,6 @@ def read_xml(xml_file):
|
|||
|
||||
|
||||
if "type" in nn.attrib and nn.attrib['type']=='drop-capital':
|
||||
#if nn.attrib['type']=='paragraph':
|
||||
|
||||
c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
|
||||
|
@ -792,27 +767,22 @@ def read_xml(xml_file):
|
|||
c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
|
||||
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
|
||||
c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
###elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
|
||||
###c_t_in_catch.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
|
||||
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='page-number':
|
||||
###elif "type" in nn.attrib and nn.attrib['type']=='page-number':
|
||||
|
||||
c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
#print(c_t_in_paragraph)
|
||||
###c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='marginalia':
|
||||
id_marginalia.append(nn.attrib['id'])
|
||||
|
||||
c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
#print(c_t_in_paragraph)
|
||||
else:
|
||||
#print(nn.attrib['id'])
|
||||
|
||||
id_paragraph.append(nn.attrib['id'])
|
||||
|
||||
c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
#print(c_t_in_paragraph)
|
||||
|
||||
break
|
||||
else:
|
||||
|
@ -821,7 +791,6 @@ def read_xml(xml_file):
|
|||
|
||||
if vv.tag==link+'Point':
|
||||
if "type" in nn.attrib and nn.attrib['type']=='drop-capital':
|
||||
#if nn.attrib['type']=='paragraph':
|
||||
|
||||
c_t_in_drop.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
|
@ -835,7 +804,6 @@ def read_xml(xml_file):
|
|||
elif "type" in nn.attrib and nn.attrib['type']=='signature-mark':
|
||||
|
||||
c_t_in_signature_mark.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
#print(c_t_in_paragraph)
|
||||
sumi+=1
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='header':
|
||||
id_header.append(nn.attrib['id'])
|
||||
|
@ -843,33 +811,26 @@ def read_xml(xml_file):
|
|||
sumi+=1
|
||||
|
||||
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
|
||||
c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
###elif "type" in nn.attrib and nn.attrib['type']=='catch-word':
|
||||
###c_t_in_catch.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
###sumi+=1
|
||||
|
||||
###elif "type" in nn.attrib and nn.attrib['type']=='page-number':
|
||||
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='page-number':
|
||||
|
||||
c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
#print(c_t_in_paragraph)
|
||||
sumi+=1
|
||||
###c_t_in_page_number.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
###sumi+=1
|
||||
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='marginalia':
|
||||
id_marginalia.append(nn.attrib['id'])
|
||||
|
||||
c_t_in_marginalia.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
#print(c_t_in_paragraph)
|
||||
sumi+=1
|
||||
|
||||
else:
|
||||
id_paragraph.append(nn.attrib['id'])
|
||||
c_t_in_paragraph.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
#print(c_t_in_paragraph)
|
||||
sumi+=1
|
||||
|
||||
#c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
|
||||
#print(vv.tag,'in')
|
||||
elif vv.tag!=link+'Point' and sumi>=1:
|
||||
break
|
||||
|
||||
|
@ -895,7 +856,6 @@ def read_xml(xml_file):
|
|||
|
||||
|
||||
elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'):
|
||||
#print('sth')
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in=[]
|
||||
c_t_in_text_annotation=[]
|
||||
|
@ -907,40 +867,31 @@ def read_xml(xml_file):
|
|||
coords=bool(vv.attrib)
|
||||
if coords:
|
||||
p_h=vv.attrib['points'].split(' ')
|
||||
#c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
|
||||
if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation':
|
||||
#if nn.attrib['type']=='paragraph':
|
||||
|
||||
c_t_in_text_annotation.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
|
||||
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='decoration':
|
||||
|
||||
c_t_in_decoration.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
#print(c_t_in_paragraph)
|
||||
|
||||
else:
|
||||
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||
|
||||
|
||||
|
||||
break
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
if vv.tag==link+'Point':
|
||||
|
||||
if "type" in nn.attrib and nn.attrib['type']=='handwritten-annotation':
|
||||
#if nn.attrib['type']=='paragraph':
|
||||
|
||||
c_t_in_text_annotation.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
|
||||
elif "type" in nn.attrib and nn.attrib['type']=='decoration':
|
||||
|
||||
c_t_in_decoration.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
#print(c_t_in_paragraph)
|
||||
sumi+=1
|
||||
|
||||
else:
|
||||
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
|
@ -955,7 +906,6 @@ def read_xml(xml_file):
|
|||
|
||||
|
||||
elif tag.endswith('}ImageRegion') or tag.endswith('}imageregion'):
|
||||
#print('sth')
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in=[]
|
||||
sumi=0
|
||||
|
@ -974,7 +924,6 @@ def read_xml(xml_file):
|
|||
if vv.tag==link+'Point':
|
||||
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
#print(vv.tag,'in')
|
||||
elif vv.tag!=link+'Point' and sumi>=1:
|
||||
break
|
||||
co_img.append(np.array(c_t_in))
|
||||
|
@ -982,7 +931,6 @@ def read_xml(xml_file):
|
|||
|
||||
|
||||
elif tag.endswith('}SeparatorRegion') or tag.endswith('}separatorregion'):
|
||||
#print('sth')
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in=[]
|
||||
sumi=0
|
||||
|
@ -1001,7 +949,6 @@ def read_xml(xml_file):
|
|||
if vv.tag==link+'Point':
|
||||
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
#print(vv.tag,'in')
|
||||
elif vv.tag!=link+'Point' and sumi>=1:
|
||||
break
|
||||
co_sep.append(np.array(c_t_in))
|
||||
|
@ -1009,7 +956,6 @@ def read_xml(xml_file):
|
|||
|
||||
|
||||
elif tag.endswith('}TableRegion') or tag.endswith('}tableregion'):
|
||||
#print('sth')
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in=[]
|
||||
sumi=0
|
||||
|
@ -1028,14 +974,13 @@ def read_xml(xml_file):
|
|||
if vv.tag==link+'Point':
|
||||
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
#print(vv.tag,'in')
|
||||
|
||||
elif vv.tag!=link+'Point' and sumi>=1:
|
||||
break
|
||||
co_table.append(np.array(c_t_in))
|
||||
co_table_text.append(' ')
|
||||
|
||||
elif tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
|
||||
#print('sth')
|
||||
for nn in root1.iter(tag):
|
||||
c_t_in=[]
|
||||
sumi=0
|
||||
|
@ -1054,40 +999,22 @@ def read_xml(xml_file):
|
|||
if vv.tag==link+'Point':
|
||||
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||
sumi+=1
|
||||
#print(vv.tag,'in')
|
||||
|
||||
elif vv.tag!=link+'Point' and sumi>=1:
|
||||
break
|
||||
co_noise.append(np.array(c_t_in))
|
||||
co_noise_text.append(' ')
|
||||
|
||||
|
||||
img = np.zeros( (y_len,x_len,3) )
|
||||
|
||||
img_poly=cv2.fillPoly(img, pts =co_text_paragraph, color=(1,1,1))
|
||||
|
||||
img_poly=cv2.fillPoly(img, pts =co_text_heading, color=(2,2,2))
|
||||
img_poly=cv2.fillPoly(img, pts =co_text_header, color=(2,2,2))
|
||||
#img_poly=cv2.fillPoly(img, pts =co_text_catch, color=(125,255,125))
|
||||
#img_poly=cv2.fillPoly(img, pts =co_text_signature_mark, color=(125,125,0))
|
||||
#img_poly=cv2.fillPoly(img, pts =co_graphic_decoration, color=(1,125,255))
|
||||
#img_poly=cv2.fillPoly(img, pts =co_text_page_number, color=(1,125,0))
|
||||
img_poly=cv2.fillPoly(img, pts =co_text_marginalia, color=(3,3,3))
|
||||
#img_poly=cv2.fillPoly(img, pts =co_text_drop, color=(1,125,255))
|
||||
|
||||
#img_poly=cv2.fillPoly(img, pts =co_graphic_text_annotation, color=(125,0,125))
|
||||
img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4))
|
||||
img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5))
|
||||
#img_poly=cv2.fillPoly(img, pts =co_table, color=(1,255,255))
|
||||
#img_poly=cv2.fillPoly(img, pts =co_graphic, color=(255,125,125))
|
||||
#img_poly=cv2.fillPoly(img, pts =co_noise, color=(255,0,255))
|
||||
|
||||
#print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg')
|
||||
###try:
|
||||
####print('yazdimmm',self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg')
|
||||
###cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('-')[1].split('.')[0]+'.jpg',img_poly )
|
||||
###except:
|
||||
###cv2.imwrite(self.output_dir+'/'+self.gt_list[index].split('.')[0]+'.jpg',img_poly )
|
||||
return file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\
|
||||
return tree1, root1, file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\
|
||||
tot_region_ref,x_len, y_len,index_tot_regions, img_poly
|
||||
|
||||
|
||||
|
@ -1113,3 +1040,24 @@ def make_image_from_bb(width_l, height_l, bb_all):
|
|||
for i in range(bb_all.shape[0]):
|
||||
img_remade[bb_all[i,1]:bb_all[i,1]+bb_all[i,3],bb_all[i,0]:bb_all[i,0]+bb_all[i,2] ] = 1
|
||||
return img_remade
|
||||
|
||||
def update_list_and_return_first_with_length_bigger_than_one(index_element_to_be_updated, innner_index_pr_pos, pr_list, pos_list,list_inp):
|
||||
list_inp.pop(index_element_to_be_updated)
|
||||
if len(pr_list)>0:
|
||||
list_inp.insert(index_element_to_be_updated, pr_list)
|
||||
else:
|
||||
index_element_to_be_updated = index_element_to_be_updated -1
|
||||
|
||||
list_inp.insert(index_element_to_be_updated+1, [innner_index_pr_pos])
|
||||
if len(pos_list)>0:
|
||||
list_inp.insert(index_element_to_be_updated+2, pos_list)
|
||||
|
||||
len_all_elements = [len(i) for i in list_inp]
|
||||
list_len_bigger_1 = np.where(np.array(len_all_elements)>1)
|
||||
list_len_bigger_1 = list_len_bigger_1[0]
|
||||
|
||||
if len(list_len_bigger_1)>0:
|
||||
early_list_bigger_than_one = list_len_bigger_1[0]
|
||||
else:
|
||||
early_list_bigger_than_one = -20
|
||||
return list_inp, early_list_bigger_than_one
|
||||
|
|
196
inference.py
196
inference.py
|
@ -11,13 +11,11 @@ from tensorflow.keras import layers
|
|||
import tensorflow.keras.losses
|
||||
from tensorflow.keras.layers import *
|
||||
from models import *
|
||||
from gt_gen_utils import *
|
||||
import click
|
||||
import json
|
||||
from tensorflow.python.keras import backend as tensorflow_backend
|
||||
|
||||
|
||||
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
|
||||
with warnings.catch_warnings():
|
||||
|
@ -29,7 +27,7 @@ Tool to load model and predict for given image.
|
|||
"""
|
||||
|
||||
class sbb_predict:
|
||||
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
|
||||
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file):
|
||||
self.image=image
|
||||
self.patches=patches
|
||||
self.save=save
|
||||
|
@ -37,6 +35,7 @@ class sbb_predict:
|
|||
self.ground_truth=ground_truth
|
||||
self.task=task
|
||||
self.config_params_model=config_params_model
|
||||
self.xml_file = xml_file
|
||||
|
||||
def resize_image(self,img_in,input_height,input_width):
|
||||
return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
|
||||
|
@ -166,7 +165,7 @@ class sbb_predict:
|
|||
##if self.weights_dir!=None:
|
||||
##self.model.load_weights(self.weights_dir)
|
||||
|
||||
if self.task != 'classification':
|
||||
if (self.task != 'classification' and self.task != 'reading_order'):
|
||||
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
|
||||
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
|
||||
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
|
||||
|
@ -233,6 +232,178 @@ class sbb_predict:
|
|||
index_class = np.argmax(label_p_pred[0])
|
||||
|
||||
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
|
||||
elif self.task == 'reading_order':
|
||||
img_height = self.config_params_model['input_height']
|
||||
img_width = self.config_params_model['input_width']
|
||||
|
||||
tree_xml, root_xml, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file)
|
||||
_, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header)
|
||||
|
||||
img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8')
|
||||
|
||||
for j in range(len(cy_main)):
|
||||
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1
|
||||
|
||||
co_text_all = co_text_paragraph + co_text_header
|
||||
id_all_text = id_paragraph + id_header
|
||||
|
||||
##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ]
|
||||
##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index]
|
||||
texts_corr_order_index_int = list(np.array(range(len(co_text_all))))
|
||||
|
||||
min_area = 0
|
||||
max_area = 1
|
||||
|
||||
co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area)
|
||||
|
||||
labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8')
|
||||
for i in range(len(co_text_all)):
|
||||
img_label = np.zeros((y_len,x_len,3),dtype='uint8')
|
||||
img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1))
|
||||
labels_con[:,:,i] = img_label[:,:,0]
|
||||
|
||||
img3= np.copy(img_poly)
|
||||
labels_con = resize_image(labels_con, img_height, img_width)
|
||||
|
||||
img_header_and_sep = resize_image(img_header_and_sep, img_height, img_width)
|
||||
|
||||
img3= resize_image (img3, img_height, img_width)
|
||||
img3 = img3.astype(np.uint16)
|
||||
|
||||
inference_bs = 1#4
|
||||
|
||||
input_1= np.zeros( (inference_bs, img_height, img_width,3))
|
||||
|
||||
|
||||
starting_list_of_regions = []
|
||||
starting_list_of_regions.append( list(range(labels_con.shape[2])) )
|
||||
|
||||
index_update = 0
|
||||
index_selected = starting_list_of_regions[0]
|
||||
|
||||
scalibility_num = 0
|
||||
while index_update>=0:
|
||||
ij_list = starting_list_of_regions[index_update]
|
||||
i = ij_list[0]
|
||||
ij_list.pop(0)
|
||||
|
||||
|
||||
pr_list = []
|
||||
post_list = []
|
||||
|
||||
batch_counter = 0
|
||||
tot_counter = 1
|
||||
|
||||
tot_iteration = len(ij_list)
|
||||
full_bs_ite= tot_iteration//inference_bs
|
||||
last_bs = tot_iteration % inference_bs
|
||||
|
||||
jbatch_indexer =[]
|
||||
for j in ij_list:
|
||||
img1= np.repeat(labels_con[:,:,i][:, :, np.newaxis], 3, axis=2)
|
||||
img2 = np.repeat(labels_con[:,:,j][:, :, np.newaxis], 3, axis=2)
|
||||
|
||||
|
||||
img2[:,:,0][img3[:,:,0]==5] = 2
|
||||
img2[:,:,0][img_header_and_sep[:,:]==1] = 3
|
||||
|
||||
|
||||
|
||||
img1[:,:,0][img3[:,:,0]==5] = 2
|
||||
img1[:,:,0][img_header_and_sep[:,:]==1] = 3
|
||||
|
||||
#input_1= np.zeros( (height1, width1,3))
|
||||
|
||||
|
||||
jbatch_indexer.append(j)
|
||||
|
||||
input_1[batch_counter,:,:,0] = img1[:,:,0]/3.
|
||||
input_1[batch_counter,:,:,2] = img2[:,:,0]/3.
|
||||
input_1[batch_counter,:,:,1] = img3[:,:,0]/5.
|
||||
#input_1[batch_counter,:,:,:]= np.zeros( (batch_counter, height1, width1,3))
|
||||
batch_counter = batch_counter+1
|
||||
|
||||
#input_1[:,:,0] = img1[:,:,0]/3.
|
||||
#input_1[:,:,2] = img2[:,:,0]/3.
|
||||
#input_1[:,:,1] = img3[:,:,0]/5.
|
||||
|
||||
if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
|
||||
y_pr = self.model.predict(input_1 , verbose=0)
|
||||
scalibility_num = scalibility_num+1
|
||||
|
||||
if batch_counter==inference_bs:
|
||||
iteration_batches = inference_bs
|
||||
else:
|
||||
iteration_batches = last_bs
|
||||
for jb in range(iteration_batches):
|
||||
if y_pr[jb][0]>=0.5:
|
||||
post_list.append(jbatch_indexer[jb])
|
||||
else:
|
||||
pr_list.append(jbatch_indexer[jb])
|
||||
|
||||
batch_counter = 0
|
||||
jbatch_indexer = []
|
||||
|
||||
tot_counter = tot_counter+1
|
||||
|
||||
starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions)
|
||||
|
||||
index_sort = [i[0] for i in starting_list_of_regions ]
|
||||
|
||||
|
||||
alltags=[elem.tag for elem in root_xml.iter()]
|
||||
|
||||
|
||||
|
||||
link=alltags[0].split('}')[0]+'}'
|
||||
name_space = alltags[0].split('}')[0]
|
||||
name_space = name_space.split('{')[1]
|
||||
|
||||
page_element = root_xml.find(link+'Page')
|
||||
|
||||
"""
|
||||
ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
|
||||
#print(page_element, 'page_element')
|
||||
|
||||
#new_element = ET.Element('ReadingOrder')
|
||||
|
||||
new_element_element = ET.Element('OrderedGroup')
|
||||
new_element_element.set('id', "ro357564684568544579089")
|
||||
|
||||
for index, id_text in enumerate(id_all_text):
|
||||
new_element_2 = ET.Element('RegionRefIndexed')
|
||||
new_element_2.set('regionRef', id_all_text[index])
|
||||
new_element_2.set('index', str(index_sort[index]))
|
||||
|
||||
new_element_element.append(new_element_2)
|
||||
|
||||
ro_subelement.append(new_element_element)
|
||||
"""
|
||||
##ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
|
||||
|
||||
ro_subelement = ET.Element('ReadingOrder')
|
||||
|
||||
ro_subelement2 = ET.SubElement(ro_subelement, 'OrderedGroup')
|
||||
ro_subelement2.set('id', "ro357564684568544579089")
|
||||
|
||||
for index, id_text in enumerate(id_all_text):
|
||||
new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed')
|
||||
new_element_2.set('regionRef', id_all_text[index])
|
||||
new_element_2.set('index', str(index_sort[index]))
|
||||
|
||||
if link+'PrintSpace' in alltags:
|
||||
page_element.insert(1, ro_subelement)
|
||||
else:
|
||||
page_element.insert(0, ro_subelement)
|
||||
|
||||
#page_element[0].append(new_element)
|
||||
#root_xml.append(new_element)
|
||||
alltags=[elem.tag for elem in root_xml.iter()]
|
||||
|
||||
ET.register_namespace("",name_space)
|
||||
tree_xml.write('library2.xml',xml_declaration=True,method='xml',encoding="utf8",default_namespace=None)
|
||||
#tree_xml.write('library2.xml')
|
||||
|
||||
else:
|
||||
if self.patches:
|
||||
#def textline_contours(img,input_width,input_height,n_classes,model):
|
||||
|
@ -356,7 +527,7 @@ class sbb_predict:
|
|||
|
||||
def run(self):
|
||||
res=self.predict()
|
||||
if self.task == 'classification':
|
||||
if (self.task == 'classification' or self.task == 'reading_order'):
|
||||
pass
|
||||
else:
|
||||
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
|
||||
|
@ -397,15 +568,20 @@ class sbb_predict:
|
|||
"-gt",
|
||||
help="ground truth directory if you want to see the iou of prediction.",
|
||||
)
|
||||
def main(image, model, patches, save, ground_truth):
|
||||
@click.option(
|
||||
"--xml_file",
|
||||
"-xml",
|
||||
help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
|
||||
)
|
||||
def main(image, model, patches, save, ground_truth, xml_file):
|
||||
with open(os.path.join(model,'config.json')) as f:
|
||||
config_params_model = json.load(f)
|
||||
task = config_params_model['task']
|
||||
if task != 'classification':
|
||||
if (task != 'classification' and task != 'reading_order'):
|
||||
if not save:
|
||||
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
|
||||
sys.exit(1)
|
||||
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
|
||||
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file)
|
||||
x.run()
|
||||
|
||||
if __name__=="__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue