page alto label generation activated for textline

This commit is contained in:
vahidrezanezhad 2026-03-03 21:12:20 +01:00
parent 4b80e45d91
commit f1d8257496
3 changed files with 211 additions and 171 deletions

View file

@ -92,7 +92,7 @@ def linegt_cli(
tree = ET.parse(dir_xml) tree = ET.parse(dir_xml)
root = tree.getroot() root = tree.getroot()
NS = {"alto": "http://www.loc.gov/standards/alto/ns-v4#"} NS = {'alto': root.tag.split('}')[0].strip('{')}#{"alto": "http://www.loc.gov/standards/alto/ns-v4#"}
results = [] results = []

View file

@ -73,8 +73,14 @@ def main():
is_flag=True, is_flag=True,
help="if this parameter set to true, generated labels and in the case of provided org images cropping will be imposed and cropped labels and images will be written in output directories.", help="if this parameter set to true, generated labels and in the case of provided org images cropping will be imposed and cropped labels and images will be written in output directories.",
) )
@click.option(
"--page_alto",
"-alto",
is_flag=True,
help="If this parameter is set to True, textline label generation is performed using PAGE/ALTO files. Otherwise, the default method for PAGE XML files is used.",
)
def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images): def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images, page_alto):
if config: if config:
with open(config) as f: with open(config) as f:
config_params = json.load(f) config_params = json.load(f)
@ -82,7 +88,7 @@ def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, di
print("passed") print("passed")
config_params = None config_params = None
gt_list = get_content_of_dir(dir_xml) gt_list = get_content_of_dir(dir_xml)
get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images) get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images, page_alto)
@main.command() @main.command()
@click.option( @click.option(

View file

@ -686,7 +686,7 @@ def get_layout_contours_for_visualization(xml_file):
co_noise.append(np.array(c_t_in)) co_noise.append(np.array(c_t_in))
return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len
def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images): def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images, page_alto=False):
""" """
Reading the page xml files and write the ground truth images into given output directory. Reading the page xml files and write the ground truth images into given output directory.
""" """
@ -696,190 +696,224 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
ls_org_imgs = os.listdir(dir_images) ls_org_imgs = os.listdir(dir_images)
ls_org_imgs_stem = [os.path.splitext(item)[0] for item in ls_org_imgs] ls_org_imgs_stem = [os.path.splitext(item)[0] for item in ls_org_imgs]
for index in tqdm(range(len(gt_list))): for index in tqdm(range(len(gt_list))):
#try:
print(gt_list[index]) print(gt_list[index])
tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8'))
root1=tree1.getroot()
alltags=[elem.tag for elem in root1.iter()]
link=alltags[0].split('}')[0]+'}'
x_len, y_len = 0, 0 try:
for jj in root1.iter(link+'Page'): if page_alto:
y_len=int(jj.attrib['imageHeight']) tree = ET.parse(dir_in+'/'+gt_list[index])
x_len=int(jj.attrib['imageWidth']) root = tree.getroot()
if 'columns_width' in list(config_params.keys()):
columns_width_dict = config_params['columns_width']
metadata_element = root1.find(link+'Metadata')
num_col = None
for child in metadata_element:
tag2 = child.tag
if tag2.endswith('}Comments') or tag2.endswith('}comments'):
text_comments = child.text
num_col = int(text_comments.split('num_col')[1])
if num_col:
x_new = columns_width_dict[str(num_col)]
y_new = int ( x_new * (y_len / float(x_len)) )
if printspace or "printspace_as_class_in_layout" in list(config_params.keys()):
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')])
co_use_case = []
for tag in region_tags: NS = {'alto': root.tag.split('}')[0].strip('{')}#{"alto": "http://www.loc.gov/standards/alto/ns-v4#"}
tag_endings = ['}PrintSpace','}Border'] x_len, y_len = 0, 0
page = root.find('.//alto:Page', NS)
x_len = int( page.get("WIDTH") )
y_len = int( page.get("HEIGHT") )
else:
tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8'))
root1=tree1.getroot()
alltags=[elem.tag for elem in root1.iter()]
link=alltags[0].split('}')[0]+'}'
x_len, y_len = 0, 0
for jj in root1.iter(link+'Page'):
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): if 'columns_width' in list(config_params.keys()):
for nn in root1.iter(tag): columns_width_dict = config_params['columns_width']
c_t_in = [] metadata_element = root1.find(link+'Metadata')
sumi = 0 num_col = None
for vv in nn.iter(): for child in metadata_element:
# check the format of coords tag2 = child.tag
if vv.tag == link + 'Coords': if tag2.endswith('}Comments') or tag2.endswith('}comments'):
coords = bool(vv.attrib) text_comments = child.text
if coords: num_col = int(text_comments.split('num_col')[1])
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
img = np.zeros((y_len, x_len, 3)) if num_col:
x_new = columns_width_dict[str(num_col)]
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) y_new = int ( x_new * (y_len / float(x_len)) )
img_poly = img_poly.astype(np.uint8)
imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
try:
cnt = contours[np.argmax(cnt_size)]
x, y, w, h = cv2.boundingRect(cnt)
except:
x, y , w, h = 0, 0, x_len, y_len
bb_xywh = [x, y, w, h]
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
keys = list(config_params.keys())
if "artificial_class_label" in keys:
artificial_class_rgb_color = (255,255,0)
artificial_class_label = config_params['artificial_class_label']
textline_rgb_color = (255, 0, 0)
if config_params['use_case']=='textline':
region_tags = np.unique([x for x in alltags if x.endswith('TextLine')])
elif config_params['use_case']=='word':
region_tags = np.unique([x for x in alltags if x.endswith('Word')])
elif config_params['use_case']=='glyph':
region_tags = np.unique([x for x in alltags if x.endswith('Glyph')])
elif config_params['use_case']=='printspace':
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')])
co_use_case = []
for tag in region_tags:
if config_params['use_case']=='textline':
tag_endings = ['}TextLine','}textline']
elif config_params['use_case']=='word':
tag_endings = ['}Word','}word']
elif config_params['use_case']=='glyph':
tag_endings = ['}Glyph','}glyph']
elif config_params['use_case']=='printspace':
tag_endings = ['}PrintSpace','}printspace']
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): if printspace or "printspace_as_class_in_layout" in list(config_params.keys()):
for nn in root1.iter(tag): region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')])
c_t_in = [] co_use_case = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point': for tag in region_tags:
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) tag_endings = ['}PrintSpace','}Border']
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1: if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
break for nn in root1.iter(tag):
co_use_case.append(np.array(c_t_in)) c_t_in = []
sumi = 0
for vv in nn.iter():
if "artificial_class_label" in keys: # check the format of coords
img_boundary = np.zeros((y_len, x_len)) if vv.tag == link + 'Coords':
erosion_rate = 0#1 coords = bool(vv.attrib)
dilation_rate = 2 if coords:
dilation_early = 0 p_h = vv.attrib['points'].split(' ')
erosion_early = 2 c_t_in.append(
co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early) np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
img = np.zeros((y_len, x_len, 3))
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
img_poly = img_poly.astype(np.uint8)
imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
try:
cnt = contours[np.argmax(cnt_size)]
x, y, w, h = cv2.boundingRect(cnt)
except:
x, y , w, h = 0, 0, x_len, y_len
bb_xywh = [x, y, w, h]
img = np.zeros((y_len, x_len, 3))
if output_type == '2d': if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) keys = list(config_params.keys())
if "artificial_class_label" in keys: if "artificial_class_label" in keys:
img_mask = np.copy(img_poly) artificial_class_rgb_color = (255,255,0)
##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label artificial_class_label = config_params['artificial_class_label']
img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label
elif output_type == '3d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color)
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0]
img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1]
img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2]
textline_rgb_color = (255, 0, 0)
if printspace and config_params['use_case']!='printspace': if page_alto:
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] co_use_case = []
for line in root.findall(".//alto:TextLine", NS):
string_el = line.find("alto:String", NS)
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': textline_text = string_el.attrib["CONTENT"] if string_el is not None else None
img_poly = resize_image(img_poly, y_new, x_new)
try: polygon_el = line.find("alto:Shape/alto:Polygon", NS)
xml_file_stem = os.path.splitext(gt_list[index])[0] if polygon_el is None:
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) continue
except:
xml_file_stem = os.path.splitext(gt_list[index])[0] points = polygon_el.attrib["POINTS"].split()
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly) coords = [
(int(points[i]), int(points[i + 1]))
if dir_images: for i in range(0, len(points), 2)
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)] ]
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
co_use_case.append( np.array(coords, dtype=np.int32) )
else:
if config_params['use_case']=='textline':
region_tags = np.unique([x for x in alltags if x.endswith('TextLine')])
elif config_params['use_case']=='word':
region_tags = np.unique([x for x in alltags if x.endswith('Word')])
elif config_params['use_case']=='glyph':
region_tags = np.unique([x for x in alltags if x.endswith('Glyph')])
elif config_params['use_case']=='printspace':
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')])
co_use_case = []
for tag in region_tags:
if config_params['use_case']=='textline':
tag_endings = ['}TextLine','}textline']
elif config_params['use_case']=='word':
tag_endings = ['}Word','}word']
elif config_params['use_case']=='glyph':
tag_endings = ['}Glyph','}glyph']
elif config_params['use_case']=='printspace':
tag_endings = ['}PrintSpace','}printspace']
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
if "artificial_class_label" in keys:
img_boundary = np.zeros((y_len, x_len))
erosion_rate = 0#1
dilation_rate = 2
dilation_early = 0
erosion_early = 2
co_use_case, img_boundary = update_region_contours(co_use_case, img_boundary, erosion_rate, dilation_rate, y_len, x_len, dilation_early=dilation_early, erosion_early=erosion_early)
img = np.zeros((y_len, x_len, 3))
if output_type == '2d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
##img_poly[:,:][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=1)] = artificial_class_label
img_poly[:,:][img_boundary[:,:]==1] = artificial_class_label
elif output_type == '3d':
img_poly = cv2.fillPoly(img, pts=co_use_case, color=textline_rgb_color)
if "artificial_class_label" in keys:
img_mask = np.copy(img_poly)
img_poly[:,:,0][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[0]
img_poly[:,:,1][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[1]
img_poly[:,:,2][(img_boundary[:,:]==1) & (img_mask[:,:,0]!=255)] = artificial_class_rgb_color[2]
if printspace and config_params['use_case']!='printspace': if printspace and config_params['use_case']!='printspace':
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :] img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace': if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_org = resize_image(img_org, y_new, x_new) img_poly = resize_image(img_poly, y_new, x_new)
cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)
try:
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
except:
xml_file_stem = os.path.splitext(gt_list[index])[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
if dir_images:
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace and config_params['use_case']!='printspace':
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_org = resize_image(img_org, y_new, x_new)
cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)
except:
pass
if config_file and config_params['use_case']=='layout': if config_file and config_params['use_case']=='layout':
keys = list(config_params.keys()) keys = list(config_params.keys())