diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c14b1f1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.egg-info
+__pycache__
diff --git a/requirements.txt b/requirements.txt
index 0f60ff5..4bc6dad 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
-numpy
+ocrd >= 2.23.2
 pandas
-click
 requests
 matplotlib
diff --git a/tsvtools/cli.py b/tsvtools/cli.py
index 00208b8..a2b5cb8 100644
--- a/tsvtools/cli.py
+++ b/tsvtools/cli.py
@@ -1,13 +1,16 @@
+import json
+import glob
+import re
+import os
+from io import StringIO
+
 import numpy as np
 import click
 import pandas as pd
-from io import StringIO
-import os
-import xml.etree.ElementTree as ET
 import requests
-import json
-import glob
-import re
+
+from ocrd_models.ocrd_page import parse
+from ocrd_utils import bbox_from_points
 
 from .ned import ned
 from .ner import ner
@@ -75,12 +78,10 @@ def annotate_tsv(tsv_file, annotated_tsv_file):
 @click.option('--max-confidence', type=float, default=None)
 def page2tsv(page_xml_file, tsv_out_file, purpose, image_url, ner_rest_endpoint, ned_rest_endpoint,
              noproxy, scale_factor, ned_threshold, min_confidence, max_confidence):
-
     if purpose == "NERD":
         out_columns = ['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', 'left', 'right', 'top', 'bottom', 'conf']
     elif purpose == "OCR":
         out_columns = ['TEXT', 'url_id', 'left', 'right', 'top', 'bottom', 'conf']
-
         if min_confidence is not None and max_confidence is not None:
             out_columns += ['ocrconf']
     else:
@@ -89,57 +90,36 @@ def page2tsv(page_xml_file, tsv_out_file, purpose, image_url, ner_rest_endpoint,
     if noproxy:
         os.environ['no_proxy'] = '*'
 
-    tree = ET.parse(page_xml_file)
-    xmlns = tree.getroot().tag.split('}')[0].strip('{')
-
     urls = []
     if os.path.exists(tsv_out_file):
         parts = extract_doc_links(tsv_out_file)
-
         urls = [part['url'] for part in parts]
     else:
         pd.DataFrame([], columns=out_columns).to_csv(tsv_out_file, sep="\t", quoting=3, index=False)
 
+    pcgts = parse(page_xml_file)
     tsv = []
     line_info = []
 
-    for rgn_number, region in enumerate(tree.findall('.//{%s}TextRegion' % xmlns)):
-
-        for text_line in region.findall('.//{%s}TextLine' % xmlns):
-
-            points = [int(scale_factor * float(pos)) for coords in text_line.findall('./{%s}Coords' % xmlns) for p in
-                      coords.attrib['points'].split(' ') for pos in p.split(',')]
-
-            x_points, y_points = points[0::2], points[1::2]
-            left, right, top, bottom = min(x_points), max(x_points), min(y_points), max(y_points)
+    for region_idx, region in enumerate(pcgts.get_Page().get_AllRegions(classes=['Text'], order='reading-order')):
+        for text_line in region.get_TextLine():
+            left, top, right, bottom = [int(scale_factor * x) for x in bbox_from_points(text_line.get_Coords().points)]
 
             if min_confidence is not None and max_confidence is not None:
-                conf = np.max([float(text.attrib['conf']) for text in text_line.findall('./{%s}TextEquiv' % xmlns)])
+                conf = np.max([textequiv.conf for textequiv in text_line.get_TextEquiv()])
             else:
                 conf = np.nan
 
             line_info.append((len(urls), left, right, top, bottom, conf))
 
-            for word in text_line.findall('./{%s}Word' % xmlns):
-
-                for text_equiv in word.findall('./{%s}TextEquiv/{%s}Unicode' % (xmlns, xmlns)):
-                    text = text_equiv.text
-
-                    points = []
-
-                    for coords in word.findall('./{%s}Coords' % xmlns):
-
-                        # transform OCR coordinates using `scale_factor` to derive
-                        # correct coordinates for the web presentation image
-                        points += [int(scale_factor * float(pos))
-                                   for p in coords.attrib['points'].split(' ') for pos in p.split(',')]
-
-                    x_points, y_points = points[0::2], points[1::2]
-
-                    left, right, top, bottom = min(x_points), max(x_points), min(y_points), max(y_points)
+            for word in text_line.get_Word():
+                for text_equiv in word.get_TextEquiv():
+                    # transform OCR coordinates using `scale_factor` to derive
+                    # correct coordinates for the web presentation image
+                    left, top, right, bottom = [int(scale_factor * x) for x in bbox_from_points(word.get_Coords().points)]
 
-                    tsv.append((rgn_number, len(line_info)-1, left + (right - left) / 2.0, text,
-                                len(urls), left, right, top, bottom))
+                    tsv.append((region_idx, len(line_info) - 1, left + (right - left) / 2.0,
+                                text_equiv.get_Unicode(), len(urls), left, right, top, bottom))
 
     line_info = pd.DataFrame(line_info, columns=['url_id', 'left', 'right', 'top', 'bottom', 'conf'])
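
Reviewer note: the patch replaces manual ElementTree namespace handling with the generated OCR-D PAGE bindings, and bbox_from_points collapses a PAGE Coords/@points string into an axis-aligned box, which is why the unpacking order changes from (left, right, top, bottom) to (left, top, right, bottom). Below is a minimal sketch of how the new calls compose; the file name page.xml and the scale_factor value are illustrative assumptions, not part of this change.

# Sketch only (not part of the diff): iterate a PAGE-XML file the way the
# rewritten page2tsv loop does. 'page.xml' and scale_factor are assumed values.
from ocrd_models.ocrd_page import parse
from ocrd_utils import bbox_from_points

pcgts = parse('page.xml')      # returns the PcGts document object
scale_factor = 1.0             # mirrors the scale_factor parameter of page2tsv

for region in pcgts.get_Page().get_AllRegions(classes=['Text'], order='reading-order'):
    for line in region.get_TextLine():
        for word in line.get_Word():
            # bbox_from_points("1,2 5,2 5,8 1,8") -> (1, 2, 5, 8),
            # i.e. (x_min, y_min, x_max, y_max) == (left, top, right, bottom)
            left, top, right, bottom = [int(scale_factor * v)
                                        for v in bbox_from_points(word.get_Coords().points)]
            for textequiv in word.get_TextEquiv():
                print(textequiv.get_Unicode(), left, top, right, bottom)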