diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py index d819944..a848b65 100644 --- a/src/eynollah/training/generate_gt_for_training.py +++ b/src/eynollah/training/generate_gt_for_training.py @@ -521,9 +521,9 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs): img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name) img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format)) - co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) + co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file) - added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, img) + added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, co_map, co_music, img) cv2.imwrite(os.path.join(dir_out, f_name+'.png'), added_image) except: diff --git a/src/eynollah/training/gt_gen_utils.py b/src/eynollah/training/gt_gen_utils.py index f21ee53..e085fa0 100644 --- a/src/eynollah/training/gt_gen_utils.py +++ b/src/eynollah/training/gt_gen_utils.py @@ -18,7 +18,7 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") -def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, img): +def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, co_music, img): alpha = 0.5 blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255 @@ -32,6 +32,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_ col_marginal = (106, 90, 205) col_table = (0, 90, 205) col_map = (90, 90, 205) + col_music = (90, 90, 0) if len(co_image)>0: cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour @@ -59,6 +60,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_ if len(co_map)>0: cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour + + if len(co_music)>0: + cv2.drawContours(blank_image, co_music, -1, col_music, thickness=cv2.FILLED) # Fill the contour img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB) @@ -389,6 +393,7 @@ def get_layout_contours_for_visualization(xml_file): co_img=[] co_table=[] co_map=[] + co_music=[] co_noise=[] types_text = [] @@ -630,6 +635,31 @@ def get_layout_contours_for_visualization(xml_file): elif vv.tag!=link+'Point' and sumi>=1: break co_map.append(np.array(c_t_in)) + + if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_music.append(np.array(c_t_in)) if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): @@ -656,7 +686,7 @@ def get_layout_contours_for_visualization(xml_file): elif vv.tag!=link+'Point' and sumi>=1: break co_noise.append(np.array(c_t_in)) - return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len + return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len def get_images_of_ground_truth( gt_list, @@ -870,7 +900,7 @@ def get_images_of_ground_truth( types_graphic_label = list(types_graphic_dict.values()) - labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255)] + labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255), (125,125,255)] region_tags=np.unique([x for x in alltags if x.endswith('Region')]) @@ -882,6 +912,7 @@ def get_images_of_ground_truth( co_img=[] co_table=[] co_map=[] + co_music=[] co_noise=[] for tag in region_tags: @@ -1123,6 +1154,32 @@ def get_images_of_ground_truth( elif vv.tag!=link+'Point' and sumi>=1: break co_map.append(np.array(c_t_in)) + + if 'musicregion' in keys: + if tag.endswith('}MusicRegion') or tag.endswith('}musicregion'): + #print('sth') + for nn in root1.iter(tag): + c_t_in=[] + sumi=0 + for vv in nn.iter(): + # check the format of coords + if vv.tag==link+'Coords': + coords=bool(vv.attrib) + if coords: + p_h=vv.attrib['points'].split(' ') + c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) ) + break + else: + pass + + + if vv.tag==link+'Point': + c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ]) + sumi+=1 + #print(vv.tag,'in') + elif vv.tag!=link+'Point' and sumi>=1: + break + co_music.append(np.array(c_t_in)) if 'noiseregion' in keys: if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'): @@ -1200,6 +1257,10 @@ def get_images_of_ground_truth( erosion_rate = 0#2 dilation_rate = 3#4 co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) + if "musicregion" in elements_with_artificial_class: + erosion_rate = 0#2 + dilation_rate = 3#4 + co_music, img_boundary = update_region_contours(co_music, img_boundary, erosion_rate, dilation_rate, y_len, x_len ) @@ -1227,6 +1288,8 @@ def get_images_of_ground_truth( img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']]) if 'mapregion' in keys: img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']]) + if 'musicregion' in keys: + img_poly=cv2.fillPoly(img, pts =co_music, color=labels_rgb_color[ config_params['musicregion']]) if 'noiseregion' in keys: img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']]) @@ -1291,6 +1354,9 @@ def get_images_of_ground_truth( if 'mapregion' in keys: color_label = config_params['mapregion'] img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label)) + if 'musicregion' in keys: + color_label = config_params['musicregion'] + img_poly=cv2.fillPoly(img, pts =co_music, color=(color_label,color_label,color_label)) if 'noiseregion' in keys: color_label = config_params['noiseregion'] img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label))