diff --git a/src/eynollah/training/extract_line_gt.py b/src/eynollah/training/extract_line_gt.py index 819bac1..4600e79 100644 --- a/src/eynollah/training/extract_line_gt.py +++ b/src/eynollah/training/extract_line_gt.py @@ -56,12 +56,6 @@ from ..utils import is_image_filename is_flag=True, help="if this parameter set to true, vertical textline images will be excluded.", ) -@click.option( - "--page_alto", - "-alto", - is_flag=True, - help="If this parameter is set to True, text line image cropping and text extraction are performed using PAGE/ALTO files. Otherwise, the default method for PAGE XML files is used.", -) def linegt_cli( image, dir_in, @@ -70,7 +64,6 @@ def linegt_cli( pref_of_dataset, do_not_mask_with_textline_contour, exclude_vertical_lines, - page_alto, ): assert bool(dir_in) ^ bool(image), "Set --dir-in or --image-filename, not both" if dir_in: @@ -86,147 +79,69 @@ def linegt_cli( dir_xml = os.path.join(dir_xmls, file_name + '.xml') img = cv2.imread(dir_img) - if page_alto: - h, w = img.shape[:2] - - tree = ET.parse(dir_xml) - root = tree.getroot() + total_bb_coordinates = [] - NS = {'alto': root.tag.split('}')[0].strip('{')}#{"alto": "http://www.loc.gov/standards/alto/ns-v4#"} + tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8")) + root = tree.getroot() + alltags = [elem.tag for elem in root.iter()] - results = [] - - indexer_textlines = 0 - for line in root.findall(".//alto:TextLine", NS): - string_el = line.find("alto:String", NS) - textline_text = string_el.attrib["CONTENT"] if string_el is not None else None + name_space = alltags[0].split('}')[0] + name_space = name_space.split('{')[1] - polygon_el = line.find("alto:Shape/alto:Polygon", NS) - if polygon_el is None: - continue + region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')]) - points = polygon_el.attrib["POINTS"].split() - coords = [ - (int(points[i]), int(points[i + 1])) - for i in range(0, len(points), 2) - ] - - coords = np.array(coords, dtype=np.int32) - x, y, w, h = cv2.boundingRect(coords) - - - if exclude_vertical_lines and h > 1.4 * w: - img_crop = None - continue - - img_poly_on_img = np.copy(img) + cropped_lines_region_indexer = [] - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[coords], color=(1, 1, 1)) + indexer_text_region = 0 + indexer_textlines = 0 + # FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether + for nn in root.iter(region_tags): + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h = child_textlines.attrib['points'].split(' ') + textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]) - mask_poly = mask_poly[y : y + h, x : x + w, :] - img_crop = img_poly_on_img[y : y + h, x : x + w, :] - - if not do_not_mask_with_textline_contour: - img_crop[mask_poly == 0] = 255 - - if img_crop.shape[0] == 0 or img_crop.shape[1] == 0: - img_crop = None - continue - - if textline_text and img_crop is not None: - base_name = os.path.join( - dir_out, file_name + '_line_' + str(indexer_textlines) - ) - if pref_of_dataset: - base_name += '_' + pref_of_dataset - if not do_not_mask_with_textline_contour: - base_name += '_masked' - - with open(base_name + '.txt', 'w') as text_file: - text_file.write(textline_text) - cv2.imwrite(base_name + '.png', img_crop) - indexer_textlines += 1 - - - - - - - - - - - - - - - - - else: - total_bb_coordinates = [] - - tree = ET.parse(dir_xml, parser=ET.XMLParser(encoding="utf-8")) - root = tree.getroot() - alltags = [elem.tag for elem in root.iter()] - - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] - - region_tags = np.unique([x for x in alltags if x.endswith('TextRegion')]) - - cropped_lines_region_indexer = [] - - indexer_text_region = 0 - indexer_textlines = 0 - # FIXME: non recursive, use OCR-D PAGE generateDS API. Or use an existing tool for this purpose altogether - for nn in root.iter(region_tags): - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h = child_textlines.attrib['points'].split(' ') - textline_coords = np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]) - - x, y, w, h = cv2.boundingRect(textline_coords) - - if exclude_vertical_lines and h > 1.4 * w: - img_crop = None - continue - - total_bb_coordinates.append([x, y, w, h]) - - img_poly_on_img = np.copy(img) - - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - mask_poly = mask_poly[y : y + h, x : x + w, :] - img_crop = img_poly_on_img[y : y + h, x : x + w, :] - - if not do_not_mask_with_textline_contour: - img_crop[mask_poly == 0] = 255 - - if img_crop.shape[0] == 0 or img_crop.shape[1] == 0: - img_crop = None - continue + x, y, w, h = cv2.boundingRect(textline_coords) - - if child_textlines.tag.endswith("TextEquiv"): - for cheild_text in child_textlines: - if cheild_text.tag.endswith("Unicode"): - textline_text = cheild_text.text - if textline_text and img_crop is not None: - base_name = os.path.join( - dir_out, file_name + '_line_' + str(indexer_textlines) - ) - if pref_of_dataset: - base_name += '_' + pref_of_dataset - if not do_not_mask_with_textline_contour: - base_name += '_masked' + if exclude_vertical_lines and h > 1.4 * w: + img_crop = None + continue - with open(base_name + '.txt', 'w') as text_file: - text_file.write(textline_text) - cv2.imwrite(base_name + '.png', img_crop) - indexer_textlines += 1 + total_bb_coordinates.append([x, y, w, h]) + + img_poly_on_img = np.copy(img) + + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + mask_poly = mask_poly[y : y + h, x : x + w, :] + img_crop = img_poly_on_img[y : y + h, x : x + w, :] + + if not do_not_mask_with_textline_contour: + img_crop[mask_poly == 0] = 255 + + if img_crop.shape[0] == 0 or img_crop.shape[1] == 0: + img_crop = None + continue + + + if child_textlines.tag.endswith("TextEquiv"): + for cheild_text in child_textlines: + if cheild_text.tag.endswith("Unicode"): + textline_text = cheild_text.text + if textline_text and img_crop is not None: + base_name = os.path.join( + dir_out, file_name + '_line_' + str(indexer_textlines) + ) + if pref_of_dataset: + base_name += '_' + pref_of_dataset + if not do_not_mask_with_textline_contour: + base_name += '_masked' + + with open(base_name + '.txt', 'w') as text_file: + text_file.write(textline_text) + cv2.imwrite(base_name + '.png', img_crop) + indexer_textlines += 1 diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py index 899675e..1e820f0 100644 --- a/src/eynollah/training/generate_gt_for_training.py +++ b/src/eynollah/training/generate_gt_for_training.py @@ -73,14 +73,8 @@ def main(): is_flag=True, help="if this parameter set to true, generated labels and in the case of provided org images cropping will be imposed and cropped labels and images will be written in output directories.", ) -@click.option( - "--page_alto", - "-alto", - is_flag=True, - help="If this parameter is set to True, textline label generation is performed using PAGE/ALTO files. Otherwise, the default method for PAGE XML files is used.", -) -def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images, page_alto): +def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images): if config: with open(config) as f: config_params = json.load(f) @@ -88,7 +82,7 @@ def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, di print("passed") config_params = None gt_list = get_content_of_dir(dir_xml) - get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images, page_alto) + get_images_of_ground_truth(gt_list,dir_xml,dir_out,type_output, config, config_params, printspace, dir_images, dir_out_images) @main.command() @click.option( diff --git a/src/eynollah/training/gt_gen_utils.py b/src/eynollah/training/gt_gen_utils.py index 717865f..1e5f51a 100644 --- a/src/eynollah/training/gt_gen_utils.py +++ b/src/eynollah/training/gt_gen_utils.py @@ -686,7 +686,7 @@ def get_layout_contours_for_visualization(xml_file): co_noise.append(np.array(c_t_in)) return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_music, co_noise, y_len, x_len -def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images, page_alto=False): +def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images): """ Reading the page xml files and write the ground truth images into given output directory. """ @@ -699,94 +699,81 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ print(gt_list[index]) try: - if page_alto: - tree = ET.parse(dir_in+'/'+gt_list[index]) - root = tree.getroot() - - NS = {'alto': root.tag.split('}')[0].strip('{')}#{"alto": "http://www.loc.gov/standards/alto/ns-v4#"} - x_len, y_len = 0, 0 - - page = root.find('.//alto:Page', NS) - - x_len = int( page.get("WIDTH") ) - y_len = int( page.get("HEIGHT") ) - - else: - tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8')) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - - x_len, y_len = 0, 0 - for jj in root1.iter(link+'Page'): - y_len=int(jj.attrib['imageHeight']) - x_len=int(jj.attrib['imageWidth']) - - if 'columns_width' in list(config_params.keys()): - columns_width_dict = config_params['columns_width'] - metadata_element = root1.find(link+'Metadata') - num_col = None - for child in metadata_element: - tag2 = child.tag - if tag2.endswith('}Comments') or tag2.endswith('}comments'): - text_comments = child.text - num_col = int(text_comments.split('num_col')[1]) - - if num_col: - x_new = columns_width_dict[str(num_col)] - y_new = int ( x_new * (y_len / float(x_len)) ) - - if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): - region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) - co_use_case = [] - - for tag in region_tags: - tag_endings = ['}PrintSpace','}Border'] - - if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): - for nn in root1.iter(tag): - c_t_in = [] - sumi = 0 - for vv in nn.iter(): - # check the format of coords - if vv.tag == link + 'Coords': - coords = bool(vv.attrib) - if coords: - p_h = vv.attrib['points'].split(' ') - c_t_in.append( - np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) - break - else: - pass - - if vv.tag == link + 'Point': - c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) - sumi += 1 - elif vv.tag != link + 'Point' and sumi >= 1: - break - co_use_case.append(np.array(c_t_in)) + tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding='utf-8')) + root1=tree1.getroot() + alltags=[elem.tag for elem in root1.iter()] + link=alltags[0].split('}')[0]+'}' - img = np.zeros((y_len, x_len, 3)) + + x_len, y_len = 0, 0 + for jj in root1.iter(link+'Page'): + y_len=int(jj.attrib['imageHeight']) + x_len=int(jj.attrib['imageWidth']) + + if 'columns_width' in list(config_params.keys()): + columns_width_dict = config_params['columns_width'] + metadata_element = root1.find(link+'Metadata') + num_col = None + for child in metadata_element: + tag2 = child.tag + if tag2.endswith('}Comments') or tag2.endswith('}comments'): + text_comments = child.text + num_col = int(text_comments.split('num_col')[1]) - img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) - - img_poly = img_poly.astype(np.uint8) - - imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY) - _, thresh = cv2.threshold(imgray, 0, 255, 0) + if num_col: + x_new = columns_width_dict[str(num_col)] + y_new = int ( x_new * (y_len / float(x_len)) ) + + if printspace or "printspace_as_class_in_layout" in list(config_params.keys()): + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')]) + co_use_case = [] - contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) - - cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) - - try: - cnt = contours[np.argmax(cnt_size)] - x, y, w, h = cv2.boundingRect(cnt) - except: - x, y , w, h = 0, 0, x_len, y_len - - bb_xywh = [x, y, w, h] + for tag in region_tags: + tag_endings = ['}PrintSpace','}Border'] + + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) + break + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) + + img = np.zeros((y_len, x_len, 3)) + + img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1)) + + img_poly = img_poly.astype(np.uint8) + + imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) + + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))]) + + try: + cnt = contours[np.argmax(cnt_size)] + x, y, w, h = cv2.boundingRect(cnt) + except: + x, y , w, h = 0, 0, x_len, y_len + + bb_xywh = [x, y, w, h] if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'): @@ -797,67 +784,49 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_ textline_rgb_color = (255, 0, 0) - if page_alto: - co_use_case = [] - for line in root.findall(".//alto:TextLine", NS): - string_el = line.find("alto:String", NS) - textline_text = string_el.attrib["CONTENT"] if string_el is not None else None + if config_params['use_case']=='textline': + region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + elif config_params['use_case']=='word': + region_tags = np.unique([x for x in alltags if x.endswith('Word')]) + elif config_params['use_case']=='glyph': + region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) + elif config_params['use_case']=='printspace': + region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + + co_use_case = [] - polygon_el = line.find("alto:Shape/alto:Polygon", NS) - if polygon_el is None: - continue - - points = polygon_el.attrib["POINTS"].split() - coords = [ - (int(points[i]), int(points[i + 1])) - for i in range(0, len(points), 2) - ] - - co_use_case.append( np.array(coords, dtype=np.int32) ) - else: + for tag in region_tags: if config_params['use_case']=='textline': - region_tags = np.unique([x for x in alltags if x.endswith('TextLine')]) + tag_endings = ['}TextLine','}textline'] elif config_params['use_case']=='word': - region_tags = np.unique([x for x in alltags if x.endswith('Word')]) + tag_endings = ['}Word','}word'] elif config_params['use_case']=='glyph': - region_tags = np.unique([x for x in alltags if x.endswith('Glyph')]) + tag_endings = ['}Glyph','}glyph'] elif config_params['use_case']=='printspace': - region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace')]) + tag_endings = ['}PrintSpace','}printspace'] - co_use_case = [] - - for tag in region_tags: - if config_params['use_case']=='textline': - tag_endings = ['}TextLine','}textline'] - elif config_params['use_case']=='word': - tag_endings = ['}Word','}word'] - elif config_params['use_case']=='glyph': - tag_endings = ['}Glyph','}glyph'] - elif config_params['use_case']=='printspace': - tag_endings = ['}PrintSpace','}printspace'] - - if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): - for nn in root1.iter(tag): - c_t_in = [] - sumi = 0 - for vv in nn.iter(): - # check the format of coords - if vv.tag == link + 'Coords': - coords = bool(vv.attrib) - if coords: - p_h = vv.attrib['points'].split(' ') - c_t_in.append( - np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) - break - else: - pass - - if vv.tag == link + 'Point': - c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) - sumi += 1 - elif vv.tag != link + 'Point' and sumi >= 1: + if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]): + for nn in root1.iter(tag): + c_t_in = [] + sumi = 0 + for vv in nn.iter(): + # check the format of coords + if vv.tag == link + 'Coords': + coords = bool(vv.attrib) + if coords: + p_h = vv.attrib['points'].split(' ') + c_t_in.append( + np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h])) break - co_use_case.append(np.array(c_t_in)) + else: + pass + + if vv.tag == link + 'Point': + c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))]) + sumi += 1 + elif vv.tag != link + 'Point' and sumi >= 1: + break + co_use_case.append(np.array(c_t_in)) if "artificial_class_label" in keys: