diff --git a/gt_gen_utils.py b/gt_gen_utils.py index 0286ac7..8f72fb8 100644 --- a/gt_gen_utils.py +++ b/gt_gen_utils.py @@ -664,6 +664,58 @@ def read_xml(xml_file): for jj in root1.iter(link+'RegionRefIndexed'): index_tot_regions.append(jj.attrib['index']) tot_region_ref.append(jj.attrib['regionRef']) + + if (link+'PrintSpace' in alltags) or (link+'Border' in alltags): + co_printspace = [] + if link+'PrintSpace' in alltags: + region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + elif link+'Border' in alltags: + region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')]) + + for tag in region_tags_printspace: + if link+'PrintSpace' in alltags: + tag_endings_printspace = ['}PrintSpace','}printspace'] + elif link+'Border' in alltags: + tag_endings_printspace = ['}Border','}border'] + + if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_printspace.append(np.array(c_t_in)) + img_printspace = np.zeros( (y_len,x_len,3) ) + img_printspace=cv2.fillPoly(img_printspace, pts =co_printspace, color=(1,1,1)) + img_printspace = img_printspace.astype(np.uint8) + + imgray = cv2.cvtColor(img_printspace, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) + cnt = contours[np.argmax(cnt_size)] + x, y, w, h = cv2.boundingRect(cnt) + + bb_coord_printspace = [x, y, w, h] + + else: + bb_coord_printspace = None + region_tags=np.unique([x for x in alltags if x.endswith('Region')]) co_text_paragraph=[] @@ -754,7 +806,7 @@ def read_xml(xml_file): c_t_in_drop.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) elif "type" in nn.attrib and nn.attrib['type']=='heading': - id_heading.append(nn.attrib['id']) + ##id_heading.append(nn.attrib['id']) c_t_in_heading.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) @@ -763,7 +815,7 @@ def read_xml(xml_file): c_t_in_signature_mark.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) #print(c_t_in_paragraph) elif "type" in nn.attrib and nn.attrib['type']=='header': - id_header.append(nn.attrib['id']) + #id_header.append(nn.attrib['id']) c_t_in_header.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) @@ -776,11 +828,11 @@ def read_xml(xml_file): ###c_t_in_page_number.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) elif "type" in nn.attrib and nn.attrib['type']=='marginalia': - id_marginalia.append(nn.attrib['id']) + #id_marginalia.append(nn.attrib['id']) c_t_in_marginalia.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) else: - id_paragraph.append(nn.attrib['id']) + #id_paragraph.append(nn.attrib['id']) c_t_in_paragraph.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) @@ -796,7 +848,7 @@ def read_xml(xml_file): sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='heading': - id_heading.append(nn.attrib['id']) + #id_heading.append(nn.attrib['id']) c_t_in_heading.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -806,7 +858,7 @@ def read_xml(xml_file): c_t_in_signature_mark.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='header': - id_header.append(nn.attrib['id']) + #id_header.append(nn.attrib['id']) c_t_in_header.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -821,13 +873,13 @@ def read_xml(xml_file): ###sumi+=1 elif "type" in nn.attrib and nn.attrib['type']=='marginalia': - id_marginalia.append(nn.attrib['id']) + #id_marginalia.append(nn.attrib['id']) c_t_in_marginalia.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 else: - id_paragraph.append(nn.attrib['id']) + #id_paragraph.append(nn.attrib['id']) c_t_in_paragraph.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) sumi+=1 @@ -838,11 +890,14 @@ def read_xml(xml_file): co_text_drop.append(np.array(c_t_in_drop)) if len(c_t_in_paragraph)>0: co_text_paragraph.append(np.array(c_t_in_paragraph)) + id_paragraph.append(nn.attrib['id']) if len(c_t_in_heading)>0: co_text_heading.append(np.array(c_t_in_heading)) + id_heading.append(nn.attrib['id']) if len(c_t_in_header)>0: co_text_header.append(np.array(c_t_in_header)) + id_header.append(nn.attrib['id']) if len(c_t_in_page_number)>0: co_text_page_number.append(np.array(c_t_in_page_number)) if len(c_t_in_catch)>0: @@ -853,6 +908,7 @@ def read_xml(xml_file): if len(c_t_in_marginalia)>0: co_text_marginalia.append(np.array(c_t_in_marginalia)) + id_marginalia.append(nn.attrib['id']) elif tag.endswith('}GraphicRegion') or tag.endswith('}graphicregion'): @@ -1014,7 +1070,7 @@ def read_xml(xml_file): img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4)) img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5)) - return tree1, root1, file_name, id_paragraph, id_header,co_text_paragraph, co_text_header,\ + return tree1, root1, bb_coord_printspace, file_name, id_paragraph, id_header+id_heading, co_text_paragraph, co_text_header+co_text_heading,\ tot_region_ref,x_len, y_len,index_tot_regions, img_poly diff --git a/inference.py b/inference.py index 73b4ed8..28445e8 100644 --- a/inference.py +++ b/inference.py @@ -16,6 +16,7 @@ import click import json from tensorflow.python.keras import backend as tensorflow_backend import xml.etree.ElementTree as ET +import matplotlib.pyplot as plt with warnings.catch_warnings(): @@ -27,7 +28,7 @@ Tool to load model and predict for given image. """ class sbb_predict: - def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file): + def __init__(self,image, model, task, config_params_model, patches, save, ground_truth, xml_file, out): self.image=image self.patches=patches self.save=save @@ -36,6 +37,7 @@ class sbb_predict: self.task=task self.config_params_model=config_params_model self.xml_file = xml_file + self.out = out def resize_image(self,img_in,input_height,input_width): return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST) @@ -236,16 +238,18 @@ class sbb_predict: img_height = self.config_params_model['input_height'] img_width = self.config_params_model['input_width'] - tree_xml, root_xml, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file) + tree_xml, root_xml, bb_coord_printspace, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = read_xml(self.xml_file) _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header) img_header_and_sep = np.zeros((y_len,x_len), dtype='uint8') + for j in range(len(cy_main)): img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,int(x_min_main[j]):int(x_max_main[j]) ] = 1 co_text_all = co_text_paragraph + co_text_header id_all_text = id_paragraph + id_header + ##texts_corr_order_index = [index_tot_regions[tot_region_ref.index(i)] for i in id_all_text ] ##texts_corr_order_index_int = [int(x) for x in texts_corr_order_index] @@ -253,8 +257,9 @@ class sbb_predict: min_area = 0 max_area = 1 + - co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area) + ##co_text_all, texts_corr_order_index_int = filter_contours_area_of_image(img_poly, co_text_all, texts_corr_order_index_int, max_area, min_area) labels_con = np.zeros((y_len,x_len,len(co_text_all)),dtype='uint8') for i in range(len(co_text_all)): @@ -262,6 +267,18 @@ class sbb_predict: img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1)) labels_con[:,:,i] = img_label[:,:,0] + if bb_coord_printspace: + #bb_coord_printspace[x,y,w,h,_,_] + x = bb_coord_printspace[0] + y = bb_coord_printspace[1] + w = bb_coord_printspace[2] + h = bb_coord_printspace[3] + labels_con = labels_con[y:y+h, x:x+w, :] + img_poly = img_poly[y:y+h, x:x+w, :] + img_header_and_sep = img_header_and_sep[y:y+h, x:x+w] + + + img3= np.copy(img_poly) labels_con = resize_image(labels_con, img_height, img_width) @@ -347,9 +364,11 @@ class sbb_predict: tot_counter = tot_counter+1 starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(index_update, i, pr_list, post_list,starting_list_of_regions) - + + index_sort = [i[0] for i in starting_list_of_regions ] + id_all_text = np.array(id_all_text)[index_sort] alltags=[elem.tag for elem in root_xml.iter()] @@ -389,19 +408,17 @@ class sbb_predict: for index, id_text in enumerate(id_all_text): new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed') new_element_2.set('regionRef', id_all_text[index]) - new_element_2.set('index', str(index_sort[index])) + new_element_2.set('index', str(index)) - if link+'PrintSpace' in alltags: + if (link+'PrintSpace' in alltags) or (link+'Border' in alltags): page_element.insert(1, ro_subelement) else: page_element.insert(0, ro_subelement) - #page_element[0].append(new_element) - #root_xml.append(new_element) alltags=[elem.tag for elem in root_xml.iter()] ET.register_namespace("",name_space) - tree_xml.write('library2.xml',xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) + tree_xml.write(os.path.join(self.out, file_name+'.xml'),xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) #tree_xml.write('library2.xml') else: @@ -545,6 +562,12 @@ class sbb_predict: help="image filename", type=click.Path(exists=True, dir_okay=False), ) +@click.option( + "--out", + "-o", + help="output directory where xml with detected reading order will be written.", + type=click.Path(exists=True, file_okay=False), +) @click.option( "--patches/--no-patches", "-p/-nop", @@ -573,7 +596,7 @@ class sbb_predict: "-xml", help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.", ) -def main(image, model, patches, save, ground_truth, xml_file): +def main(image, model, patches, save, ground_truth, xml_file, out): with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] @@ -581,7 +604,7 @@ def main(image, model, patches, save, ground_truth, xml_file): if not save: print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") sys.exit(1) - x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file) + x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file, out) x.run() if __name__=="__main__":