From f1d8257496307aa4a7428596408dad2780303b68 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 3 Mar 2026 21:12:20 +0100 Subject: [PATCH] page alto label generation activated for textline --- src/eynollah/training/extract_line_gt.py | 2 +- .../training/generate_gt_for_training.py | 10 +- src/eynollah/training/gt_gen_utils.py | 370 ++++++++++-------- 3 files changed, 211 insertions(+), 171 deletions(-) diff --git a/src/eynollah/training/extract_line_gt.py b/src/eynollah/training/extract_line_gt.py index 3d508bc..819bac1 100644 --- a/src/eynollah/training/extract_line_gt.py +++ b/src/eynollah/training/extract_line_gt.py @@ -92,7 +92,7 @@ def linegt_cli( tree = ET.parse(dir_xml) root = tree.getroot() - NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"} + NS = {'alto': root.tag.split('}')[0].strip('{')}#{"alto": "http://www.loc.gov/standards/alto/ns-v4#"} results = [] diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py index 1e820f0..899675e 100644 --- a/src/eynollah/training/generate_gt_for_training.py +++ b/src/eynollah/training/generate_gt_for_training.py @@ -73,8 +73,14 @@ def main(): is_flag=True, help="if this parameter set to true, generated labels and in the case of provided org images cropping will be imposed and cropped labels and images will be written in output directories.", ) +@click.option( + "--page_alto", + "-alto", + is_flag=True, + help="If this parameter is set to True, textline label generation is performed using PAGE/ALTO files. Otherwise, the default method for PAGE XML files is used.", +) -def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images): +def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images, page_alto): if config: with open(config) as f: config_params = json.load(f) @@ -82,7 +88,7 @@ def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, di print("passed") config_params = None gt_list = get_content_of_dir(dir_xml) - get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images) + get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images, page_alto) @main.command() @click.option( diff --git a/src/eynollah/training/gt_gen_utils.py b/src/eynollah/training/gt_gen_utils.py index 70d48ae..717865f 100644 --- a/src/eynollah/training/gt_gen_utils.py +++ b/src/eynollah/training/gt_gen_utils.py @@ -686,7 +686,7 @@ def get_layout_contours_for_visualization(xml_file): co_noise.append(np.array(c_t_in)) return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len -def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images): +def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images, page_alto=False): """ Reading the page xml files and write the ground truth images into given output directory. """ @@ -696,190 +696,224 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ ls_org_imgs = os.listdir(dir_images) ls_org_imgs_stem = [os.path.splitext(item)[0] for item in ls_org_imgs] for index in tqdm(range(len(gt_list))): - #try: print(gt_list[index]) - tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8')) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - x_len, y_len = 0, 0 - for jj in root1.iter(link+'Page'): - y_len=int(jj.attrib['imageHeight']) - x_len=int(jj.attrib['imageWidth']) - - if 'columns_width' in list(config_params.keys()): - columns_width_dict = config_params['columns_width'] - metadata_element = root1.find(link+'Metadata') - num_col = None - for child in metadata_element: - tag2 = child.tag - if tag2.endswith('}Comments') or tag2.endswith('}comments'): - text_comments = child.text - num_col = int(text_comments.split('num_col')[1]) - - if num_col: - x_new = columns_width_dict[str(num_col)] - y_new = int ( x_new * (y_len / float(x_len)) ) - - if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): - region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) - co_use_case = [] + try: + if page_alto: + tree = ET.parse(dir_in+'/'+gt_list[index]) + root = tree.getroot() - for tag in region_tags: - tag_endings = ['}PrintSpace','}Border'] + NS = {'alto': root.tag.split('}')[0].strip('{')}#{"alto": "http://www.loc.gov/standards/alto/ns-v4#"} + x_len, y_len = 0, 0 + + page = root.find('.//alto:Page', NS) + + x_len = int( page.get("WIDTH") ) + y_len = int( page.get("HEIGHT") ) + + else: + tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' + + + x_len, y_len = 0, 0 + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) - if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): - for nn in root1.iter(tag): - c_t_in = [] - sumi = 0 - for vv in nn.iter(): - # check the format of coords - if vv.tag == link + 'Coords': - coords = bool(vv.attrib) - if coords: - p_h = vv.attrib['points'].split(' ') - c_t_in.append( - np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) - break - else: - pass - - if vv.tag == link + 'Point': - c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) - sumi += 1 - elif vv.tag != link + 'Point' and sumi >= 1: - break - co_use_case.append(np.array(c_t_in)) + if 'columns_width' in list(config_params.keys()): + columns_width_dict = config_params['columns_width'] + metadata_element = root1.find(link+'Metadata') + num_col = None + for child in metadata_element: + tag2 = child.tag + if tag2.endswith('}Comments') or tag2.endswith('}comments'): + text_comments = child.text + num_col = int(text_comments.split('num_col')[1]) - img = np.zeros((y_len, x_len, 3)) - - img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) - - img_poly = img_poly.astype(np.uint8) - - imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY) - _, thresh = cv2.threshold(imgray, 0, 255, 0) - - contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) - - cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) - - try: - cnt = contours[np.argmax(cnt_size)] - x, y, w, h = cv2.boundingRect(cnt) - except: - x, y , w, h = 0, 0, x_len, y_len - - bb_xywh = [x, y, w, h] - - - if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): - keys = list(config_params.keys()) - if "artificial_class_label" in keys: - artificial_class_rgb_color = (255,255,0) - artificial_class_label = config_params['artificial_class_label'] - - textline_rgb_color = (255, 0, 0) - - if config_params['use_case']=='textline': - region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) - elif config_params['use_case']=='word': - region_tags = np.unique([x for x in alltags if x.endswith('Word')]) - elif config_params['use_case']=='glyph': - region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) - elif config_params['use_case']=='printspace': - region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) - - co_use_case = [] - - for tag in region_tags: - if config_params['use_case']=='textline': - tag_endings = ['}TextLine','}textline'] - elif config_params['use_case']=='word': - tag_endings = ['}Word','}word'] - elif config_params['use_case']=='glyph': - tag_endings = ['}Glyph','}glyph'] - elif config_params['use_case']=='printspace': - tag_endings = ['}PrintSpace','}printspace'] + if num_col: + x_new = columns_width_dict[str(num_col)] + y_new = int ( x_new * (y_len / float(x_len)) ) - if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): - for nn in root1.iter(tag): - c_t_in = [] - sumi = 0 - for vv in nn.iter(): - # check the format of coords - if vv.tag == link + 'Coords': - coords = bool(vv.attrib) - if coords: - p_h = vv.attrib['points'].split(' ') - c_t_in.append( - np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) - break - else: - pass + if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) + co_use_case = [] - if vv.tag == link + 'Point': - c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) - sumi += 1 - elif vv.tag != link + 'Point' and sumi >= 1: - break - co_use_case.append(np.array(c_t_in)) - - - if "artificial_class_label" in keys: - img_boundary = np.zeros((y_len, x_len)) - erosion_rate = 0#1 - dilation_rate = 2 - dilation_early = 0 - erosion_early = 2 - co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early) - + for tag in region_tags: + tag_endings = ['}PrintSpace','}Border'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + + img = np.zeros((y_len, x_len, 3)) + + img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + + img_poly = img_poly.astype(np.uint8) + + imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) + + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) + + try: + cnt = contours[np.argmax(cnt_size)] + x, y, w, h = cv2.boundingRect(cnt) + except: + x, y , w, h = 0, 0, x_len, y_len + + bb_xywh = [x, y, w, h] - img = np.zeros((y_len, x_len, 3)) - if output_type == '2d': - img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + + if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): + keys = list(config_params.keys()) if "artificial_class_label" in keys: - img_mask = np.copy(img_poly) - ##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label - img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label - elif output_type == '3d': - img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color) - if "artificial_class_label" in keys: - img_mask = np.copy(img_poly) - img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0] - img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1] - img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2] + artificial_class_rgb_color = (255,255,0) + artificial_class_label = config_params['artificial_class_label'] + textline_rgb_color = (255, 0, 0) - if printspace and config_params['use_case']!='printspace': - img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] - - - if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': - img_poly = resize_image(img_poly, y_new, x_new) + if page_alto: + co_use_case = [] + for line in root.findall(".//alto:TextLine", NS): + string_el = line.find("alto:String", NS) + textline_text = string_el.attrib["CONTENT"] if string_el is not None else None - try: - xml_file_stem = os.path.splitext(gt_list[index])[0] - cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) - except: - xml_file_stem = os.path.splitext(gt_list[index])[0] - cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) - - if dir_images: - org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] - img_org = cv2.imread(os.path.join(dir_images, org_image_name)) + polygon_el = line.find("alto:Shape/alto:Polygon", NS) + if polygon_el is None: + continue + + points = polygon_el.attrib["POINTS"].split() + coords = [ + (int(points[i]), int(points[i + 1])) + for i in range(0, len(points), 2) + ] + + co_use_case.append( np.array(coords, dtype=np.int32) ) + else: + if config_params['use_case']=='textline': + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + elif config_params['use_case']=='word': + region_tags = np.unique([x for x in alltags if x.endswith('Word')]) + elif config_params['use_case']=='glyph': + region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) + elif config_params['use_case']=='printspace': + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + + co_use_case = [] + + for tag in region_tags: + if config_params['use_case']=='textline': + tag_endings = ['}TextLine','}textline'] + elif config_params['use_case']=='word': + tag_endings = ['}Word','}word'] + elif config_params['use_case']=='glyph': + tag_endings = ['}Glyph','}glyph'] + elif config_params['use_case']=='printspace': + tag_endings = ['}PrintSpace','}printspace'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + + + if "artificial_class_label" in keys: + img_boundary = np.zeros((y_len, x_len)) + erosion_rate = 0#1 + dilation_rate = 2 + dilation_early = 0 + erosion_early = 2 + co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early) + + img = np.zeros((y_len, x_len, 3)) + if output_type == '2d': + img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + if "artificial_class_label" in keys: + img_mask = np.copy(img_poly) + ##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label + img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label + elif output_type == '3d': + img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color) + if "artificial_class_label" in keys: + img_mask = np.copy(img_poly) + img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0] + img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1] + img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2] + + if printspace and config_params['use_case']!='printspace': - img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': - img_org = resize_image(img_org, y_new, x_new) - - cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org) + img_poly = resize_image(img_poly, y_new, x_new) - + try: + xml_file_stem = os.path.splitext(gt_list[index])[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) + except: + xml_file_stem = os.path.splitext(gt_list[index])[0] + cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) + + if dir_images: + org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] + img_org = cv2.imread(os.path.join(dir_images, org_image_name)) + + if printspace and config_params['use_case']!='printspace': + img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] + + if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': + img_org = resize_image(img_org, y_new, x_new) + + cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org) + + except: + pass + if config_file and config_params['use_case']=='layout': keys = list(config_params.keys())