diff --git a/requirements.txt b/requirements.txt index e9b9289..b7eb78b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ ocrd >= 2.23.2 pandas matplotlib +qurator-sbb-tools \ No newline at end of file diff --git a/setup.py b/setup.py index 9b8d638..eba3415 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,6 @@ setup( "annotate-tsv=tsvtools.cli:annotate_tsv", "page2tsv=tsvtools.cli:page2tsv", "tsv2page=tsvtools.cli:tsv2page", - "find-entities=tsvtools.cli:find_entities", "make-page2tsv-commands=tsvtools.cli:make_page2tsv_commands" ] }, diff --git a/tsvtools/cli.py b/tsvtools/cli.py index e986b7b..6c26008 100644 --- a/tsvtools/cli.py +++ b/tsvtools/cli.py @@ -1,4 +1,3 @@ -import json import glob import re import os @@ -14,12 +13,9 @@ from lxml import etree as ET from ocrd_models.ocrd_page import parse from ocrd_utils import bbox_from_points -from .ned import ned -from .ner import ner -from .tsv import read_tsv, write_tsv, extract_doc_links +from qurator.utils.tsv import read_tsv, write_tsv, extract_doc_links from .ocr import get_conf_color - @click.command() @click.argument('tsv-file', type=click.Path(exists=True), required=True, nargs=1) @click.argument('url-file', type=click.Path(exists=False), required=True, nargs=1) @@ -218,59 +214,6 @@ def tsv2page(output_filename, keep_words, page_file, tsv_file): f.write(ET.tostring(tree, pretty_print=True).decode('utf-8')) -@click.command() -@click.argument('tsv-file', type=click.Path(exists=True), required=True, nargs=1) -@click.argument('tsv-out-file', type=click.Path(), required=True, nargs=1) -@click.option('--ner-rest-endpoint', type=str, default=None, - help="REST endpoint of sbb_ner service. See https://github.com/qurator-spk/sbb_ner for details.") -@click.option('--ned-rest-endpoint', type=str, default=None, - help="REST endpoint of sbb_ned service. See https://github.com/qurator-spk/sbb_ned for details.") -@click.option('--ned-json-file', type=str, default=None) -@click.option('--noproxy', type=bool, is_flag=True, help='disable proxy. default: proxy is enabled.') -@click.option('--ned-threshold', type=float, default=None) -@click.option('--ned-priority', type=int, default=1) -def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, ned_json_file, noproxy, ned_threshold, - ned_priority): - - if noproxy: - os.environ['no_proxy'] = '*' - - tsv, urls = read_tsv(tsv_file) - - try: - if ner_rest_endpoint is not None: - - tsv, ner_result = ner(tsv, ner_rest_endpoint) - - elif os.path.exists(tsv_file): - - print('Using NER information that is already contained in file: {}'.format(tsv_file)) - - tmp = tsv.copy() - tmp['sen'] = (tmp['No.'] == 0).cumsum() - tmp.loc[~tmp['NE-TAG'].isin(['O', 'B-PER', 'B-LOC', 'B-ORG', 'I-PER', 'I-LOC', 'I-ORG']), 'NE-TAG'] = 'O' - - ner_result = [[{'word': str(row.TOKEN), 'prediction': row['NE-TAG']} for _, row in sen.iterrows()] - for _, sen in tmp.groupby('sen')] - else: - raise RuntimeError("Either NER rest endpoint or NER-TAG information within tsv_file required.") - - if ned_rest_endpoint is not None: - - tsv, ned_result = ned(tsv, ner_result, ned_rest_endpoint, json_file=ned_json_file, threshold=ned_threshold, - priority=ned_priority) - - if ned_json_file is not None and not os.path.exists(ned_json_file): - - with open(ned_json_file, "w") as fp_json: - json.dump(ned_result, fp_json, indent=2, separators=(',', ': ')) - - write_tsv(tsv, urls, tsv_out_file) - - except requests.HTTPError as e: - print(e) - - @click.command() @click.option('--xls-file', type=click.Path(exists=True), default=None, help="Read parameters from xls-file. Expected columns: Filename, iiif_url, scale_factor.") diff --git a/tsvtools/ned.py b/tsvtools/ned.py deleted file mode 100644 index 144c66b..0000000 --- a/tsvtools/ned.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -import requests -import json - - -def ned(tsv, ner_result, ned_rest_endpoint, json_file=None, threshold=None, priority=None): - - if json_file is not None and os.path.exists(json_file): - - print('Loading {}'.format(json_file)) - - with open(json_file, "r") as fp: - ned_result = json.load(fp) - - else: - - resp = requests.post(url=ned_rest_endpoint + '/parse', json=ner_result) - - resp.raise_for_status() - - ner_parsed = json.loads(resp.content) - - ned_rest_endpoint = ned_rest_endpoint + '/ned?return_full=' + str(int(json_file is not None)).lower() - - if priority is not None: - ned_rest_endpoint += "&priority=" + str(int(priority)) - - resp = requests.post(url=ned_rest_endpoint, json=ner_parsed, timeout=3600000) - - resp.raise_for_status() - - ned_result = json.loads(resp.content) - - rids = [] - entity = "" - entity_type = None - tsv['ID'] = '-' - tsv['conf'] = '-' - - def check_entity(tag): - nonlocal entity, entity_type, rids - - if (entity != "") and ((tag == 'O') or tag.startswith('B-') or (tag[2:] != entity_type)): - - eid = entity + "-" + entity_type - - if eid in ned_result: - if 'ranking' in ned_result[eid]: - ranking = ned_result[eid]['ranking'] - - # tsv.loc[rids, 'ID'] = ranking[0][1]['wikidata'] - # if threshold is None or ranking[0][1]['proba_1'] >= threshold else '' - - tmp = "|".join([ranking[i][1]['wikidata'] - for i in range(len(ranking)) - if threshold is None or ranking[i][1]['proba_1'] >= threshold]) - tsv.loc[rids, 'ID'] = tmp if len(tmp) > 0 else '-' - - tmp = ",".join([str(ranking[i][1]['proba_1']) - for i in range(len(ranking)) - if threshold is None or ranking[i][1]['proba_1'] >= threshold]) - - tsv.loc[rids, 'conf'] = tmp if len(tmp) > 0 else '-' - - rids = [] - entity = "" - entity_type = None - - ner_tmp = tsv.copy() - ner_tmp.loc[~ner_tmp['NE-TAG'].isin(['O', 'B-PER', 'B-LOC', 'B-ORG', 'I-PER', 'I-LOC', 'I-ORG']), 'NE-TAG'] = 'O' - - for rid, row in ner_tmp.iterrows(): - - check_entity(row['NE-TAG']) - - if row['NE-TAG'] != 'O': - - entity_type = row['NE-TAG'][2:] - - entity += " " if entity != "" else "" - - entity += str(row['TOKEN']) - - rids.append(rid) - - check_entity('O') - - return tsv, ned_result diff --git a/tsvtools/ner.py b/tsvtools/ner.py deleted file mode 100644 index 8c33c4d..0000000 --- a/tsvtools/ner.py +++ /dev/null @@ -1,49 +0,0 @@ -import pandas as pd -import requests -import unicodedata -import json - - -def ner(tsv, ner_rest_endpoint): - - resp = requests.post(url=ner_rest_endpoint, json={'text': " ".join(tsv.TOKEN.astype(str).tolist())}) - - resp.raise_for_status() - - def iterate_ner_results(result_sentences): - - for sen in result_sentences: - - for token in sen: - - yield unicodedata.normalize('NFC', token['word']), token['prediction'], False - - yield '', '', True - - ner_result = json.loads(resp.content) - - result_sequence = iterate_ner_results(ner_result) - - tsv_result = [] - for idx, row in tsv.iterrows(): - - row_token = unicodedata.normalize('NFC', str(row.TOKEN).replace(' ', '')) - - ner_token_concat = '' - while row_token != ner_token_concat: - - ner_token, ner_tag, sentence_break = next(result_sequence) - ner_token_concat += ner_token - - assert len(row_token) >= len(ner_token_concat) - - if sentence_break: - tsv_result.append((0, '', 'O', 'O', '-', row.url_id, row.left, row.right, row.top, row.bottom)) - else: - tsv_result.append((0, ner_token, ner_tag, 'O', '-', row.url_id, row.left, row.right, row.top, - row.bottom)) - - return pd.DataFrame(tsv_result, columns=['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', - 'left', 'right', 'top', 'bottom']), ner_result - - diff --git a/tsvtools/tsv.py b/tsvtools/tsv.py deleted file mode 100644 index aeafb8a..0000000 --- a/tsvtools/tsv.py +++ /dev/null @@ -1,87 +0,0 @@ -import pandas as pd -import re - - -def read_tsv(tsv_file): - - tsv = pd.read_csv(tsv_file, sep='\t', comment='#', quoting=3).rename(columns={'GND-ID': 'ID'}) - - parts = extract_doc_links(tsv_file) - - urls = [part['url'] for part in parts] - - return tsv, urls - - -def write_tsv(tsv, urls, tsv_out_file): - - if 'conf' in tsv.columns: - out_columns = ['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', 'left', 'right', 'top', 'bottom', 'conf'] - else: - out_columns = ['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', 'left', 'right', 'top', 'bottom'] - - if len(urls) == 0: - print('Writing to {}...'.format(tsv_out_file)) - - tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False) - else: - pd.DataFrame([], columns=out_columns).to_csv(tsv_out_file, sep="\t", quoting=3, index=False) - - for url_id, part in tsv.groupby('url_id'): - with open(tsv_out_file, 'a') as f: - f.write('# ' + urls[int(url_id)] + '\n') - - part.to_csv(tsv_out_file, sep="\t", quoting=3, index=False, mode='a', header=False) - - -def extract_doc_links(tsv_file): - parts = [] - - header = None - - with open(tsv_file, 'r') as f: - - text = [] - url = None - - for line in f: - - if header is None: - header = "\t".join(line.split()) + '\n' - continue - - urls = [url for url in - re.findall(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', line)] - - if len(urls) > 0: - if url is not None: - parts.append({"url": url, 'header': header, 'text': "".join(text)}) - text = [] - - url = urls[-1] - else: - if url is None: - continue - - line = '\t'.join(line.split()) - - if line.count('\t') == 2: - line = "\t" + line - - if line.count('\t') >= 3: - text.append(line + '\n') - - continue - - if line.startswith('#'): - continue - - if len(line) == 0: - continue - - print('Line error: |', line, '|Number of Tabs: ', line.count('\t')) - - if url is not None: - parts.append({"url": url, 'header': header, 'text': "".join(text)}) - - return parts