diff --git a/custom_config_page2label.json b/custom_config_page2label.json index e4c02cb..9116ce3 100644 --- a/custom_config_page2label.json +++ b/custom_config_page2label.json @@ -1,9 +1,8 @@ { -"use_case": "layout", +"use_case": "textline", "textregions":{ "rest_as_paragraph": 1, "header":2 , "heading":2 , "marginalia":3 }, "imageregion":4, "separatorregion":5, "graphicregions" :{"rest_as_decoration":6}, -"artificial_class_on_boundry": ["paragraph"], -"artificial_class_label":7 +"columns_width":{"1":1000, "2":1300, "3":1600, "4":2000, "5":2300, "6":2500} } diff --git a/generate_gt_for_training.py b/generate_gt_for_training.py index cf2b2a6..752090c 100644 --- a/generate_gt_for_training.py +++ b/generate_gt_for_training.py @@ -14,10 +14,22 @@ def main(): help="directory of GT page-xml files", type=click.Path(exists=True, file_okay=False), ) +@click.option( + "--dir_images", + "-di", + help="directory of org images. If print space cropping or scaling is needed for labels it would be great to provide the original images to apply the same function on them. So if -ps is not set true or in config files no columns_width key is given this argumnet can be ignored. File stems in this directory should be the same as those in dir_xml.", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_out_images", + "-doi", + help="directory where the output org images after undergoing a process (like print space cropping or scaling) will be written.", + type=click.Path(exists=True, file_okay=False), +) @click.option( "--dir_out", "-do", - help="directory where ground truth images would be written", + help="directory where ground truth label images would be written", type=click.Path(exists=True, file_okay=False), ) @@ -33,8 +45,14 @@ def main(): "-to", help="this defines how output should be. A 2d image array or a 3d image array encoded with RGB color. Just pass 2d or 3d. The file will be saved one directory up. 2D image array is 3d but only information of one channel would be enough since all channels have the same values.", ) +@click.option( + "--printspace", + "-ps", + is_flag=True, + help="if this parameter set to true, generated labels and in the case of provided org images cropping will be imposed and cropped labels and images will be written in output directories.", +) -def pagexml2label(dir_xml,dir_out,type_output,config): +def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images): if config: with open(config) as f: config_params = json.load(f) @@ -42,7 +60,7 @@ def pagexml2label(dir_xml,dir_out,type_output,config): print("passed") config_params = None gt_list = get_content_of_dir(dir_xml) - get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params) + get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images) @main.command() @click.option( @@ -181,7 +199,7 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i for i in range(len(texts_corr_order_index_int)): for j in range(len(texts_corr_order_index_int)): if i!=j: - input_matrix = np.zeros((input_height,input_width,3)).astype(np.int8) + input_multi_visual_modal = np.zeros((input_height,input_width,3)).astype(np.int8) final_f_name = f_name+'_'+str(indexer+indexer_start) order_class_condition = texts_corr_order_index_int[i]-texts_corr_order_index_int[j] if order_class_condition<0: @@ -189,13 +207,13 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i else: class_type = 0 - input_matrix[:,:,0] = resize_image(labels_con[:,:,i], input_height, input_width) - input_matrix[:,:,1] = resize_image(img_poly[:,:,0], input_height, input_width) - input_matrix[:,:,2] = resize_image(labels_con[:,:,j], input_height, input_width) + input_multi_visual_modal[:,:,0] = resize_image(labels_con[:,:,i], input_height, input_width) + input_multi_visual_modal[:,:,1] = resize_image(img_poly[:,:,0], input_height, input_width) + input_multi_visual_modal[:,:,2] = resize_image(labels_con[:,:,j], input_height, input_width) np.save(os.path.join(dir_out_classes,final_f_name+'.npy' ), class_type) - cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_matrix) + cv2.imwrite(os.path.join(dir_out_modal_image,final_f_name+'.png' ), input_multi_visual_modal) indexer = indexer+1 diff --git a/gt_gen_utils.py b/gt_gen_utils.py index debaf15..d3e95e8 100644 --- a/gt_gen_utils.py +++ b/gt_gen_utils.py @@ -115,11 +115,15 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y img_boundary[:,:][boundary[:,:]==1] =1 return co_text_eroded, img_boundary -def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params): +def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images): """ Reading the page xml files and write the ground truth images into given output directory. """ ## to do: add footnote to text regions + + if dir_images: + ls_org_imgs = os.listdir(dir_images) + ls_org_imgs_stem = [item.split('.')[0] for item in ls_org_imgs] for index in tqdm(range(len(gt_list))): #try: tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding = 'iso-8859-5')) @@ -133,6 +137,72 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ y_len=int(jj.attrib['imageHeight']) x_len=int(jj.attrib['imageWidth']) + if 'columns_width' in list(config_params.keys()): + columns_width_dict = config_params['columns_width'] + metadata_element = root1.find(link+'Metadata') + comment_is_sub_element = False + for child in metadata_element: + tag2 = child.tag + if tag2.endswith('}Comments') or tag2.endswith('}comments'): + text_comments = child.text + num_col = int(text_comments.split('num_col')[1]) + comment_is_sub_element = True + if not comment_is_sub_element: + num_col = None + + if num_col: + x_new = columns_width_dict[str(num_col)] + y_new = int ( x_new * (y_len / float(x_len)) ) + + if printspace: + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) + co_use_case = [] + + for tag in region_tags: + tag_endings = ['}PrintSpace','}Border'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + + img = np.zeros((y_len, x_len, 3)) + + img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + + img_poly = img_poly.astype(np.uint8) + + imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) + + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) + + cnt = contours[np.argmax(cnt_size)] + + x, y, w, h = cv2.boundingRect(cnt) + bb_xywh = [x, y, w, h] + + if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): keys = list(config_params.keys()) if "artificial_class_label" in keys: @@ -186,7 +256,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ co_use_case.append(np.array(c_t_in)) - if "artificial_class_label" in keys: img_boundary = np.zeros((y_len, x_len)) erosion_rate = 1 @@ -205,12 +274,32 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0] img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1] img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2] + + + if printspace and config_params['use_case']!='printspace': + img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': + img_poly = resize_image(img_poly, y_new, x_new) try: - cv2.imwrite(output_dir + '/' + gt_list[index].split('-')[1].split('.')[0] + '.png', - img_poly) + xml_file_stem = gt_list[index].split('-')[1].split('.')[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) except: - cv2.imwrite(output_dir + '/' + gt_list[index].split('.')[0] + '.png', img_poly) + xml_file_stem = gt_list[index].split('.')[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) + + if dir_images: + org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] + img_org = cv2.imread(os.path.join(dir_images, org_image_name)) + + if printspace and config_params['use_case']!='printspace': + img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': + img_org = resize_image(img_org, y_new, x_new) + + cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org) if config_file and config_params['use_case']=='layout': @@ -616,11 +705,31 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ + if printspace: + img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col: + img_poly = resize_image(img_poly, y_new, x_new) - try: - cv2.imwrite(output_dir+'/'+gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly ) + try: + xml_file_stem = gt_list[index].split('-')[1].split('.')[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) except: - cv2.imwrite(output_dir+'/'+gt_list[index].split('.')[0]+'.png',img_poly ) + xml_file_stem = gt_list[index].split('.')[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) + + + if dir_images: + org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] + img_org = cv2.imread(os.path.join(dir_images, org_image_name)) + + if printspace: + img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col: + img_org = resize_image(img_org, y_new, x_new) + + cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)