diff --git a/cli.py b/cli.py index 5999a2a..8c92e91 100644 --- a/cli.py +++ b/cli.py @@ -105,7 +105,7 @@ def extract_doc_links(tsv_file): def ner(tsv, ner_rest_endpoint): - resp = requests.post(url=ner_rest_endpoint, json={'text': " ".join(tsv.TOKEN.tolist())}) + resp = requests.post(url=ner_rest_endpoint, json={'text': " ".join(tsv.TOKEN.astype(str).tolist())}) resp.raise_for_status() @@ -126,7 +126,7 @@ def ner(tsv, ner_rest_endpoint): tsv_result = [] for idx, row in tsv.iterrows(): - row_token = unicodedata.normalize('NFC', row.TOKEN.replace(' ', '')) + row_token = unicodedata.normalize('NFC', str(row.TOKEN).replace(' ', '')) ner_token_concat = '' while row_token != ner_token_concat: @@ -146,7 +146,7 @@ def ner(tsv, ner_rest_endpoint): 'left', 'right', 'top', 'bottom']), ner_result -def ned(tsv, ner_result, ned_rest_endpoint): +def ned(tsv, ner_result, ned_rest_endpoint, return_full=False, threshold=None): resp = requests.post(url=ned_rest_endpoint + '/parse', json=ner_result) @@ -154,7 +154,11 @@ def ned(tsv, ner_result, ned_rest_endpoint): ner_parsed = json.loads(resp.content) - resp = requests.post(url=ned_rest_endpoint + '/ned', json=ner_parsed, timeout=3600000) + ned_rest_endpoint = ned_rest_endpoint + '/ned?return_full=' + str(return_full).lower() + + ned_rest_endpoint += '&threshold={}'.format(threshold) if threshold is not None else '' + + resp = requests.post(url=ned_rest_endpoint, json=ner_parsed, timeout=3600000) resp.raise_for_status() @@ -163,21 +167,28 @@ def ned(tsv, ner_result, ned_rest_endpoint): rids = [] entity = "" entity_type = None - for rid, row in tsv.iterrows(): - if (entity != "") and ((row['NE-TAG'] == 'O') or (row['NE-TAG'].startswith('B-'))): + def check_entity(tag): + nonlocal entity, entity_type, rids + + if (entity != "") and ((tag == 'O') or tag.startswith('B-') or (tag[2:] != entity_type)): eid = entity + "-" + entity_type if eid in ned_result: - candidates = ned_result[eid] + if 'ranking' in ned_result[eid]: + ranking = ned_result[eid]['ranking'] - tsv.loc[rids, 'ID'] = candidates[0][1]['wikidata'] + tsv.loc[rids, 'ID'] = ranking[0][1]['wikidata'] rids = [] entity = "" entity_type = None + for rid, row in tsv.iterrows(): + + check_entity(row['NE-TAG']) + if row['NE-TAG'] != 'O': entity_type = row['NE-TAG'][2:] @@ -188,7 +199,9 @@ def ned(tsv, ner_result, ned_rest_endpoint): rids.append(rid) - return tsv + check_entity('O') + + return tsv, ned_result @click.command() @@ -201,7 +214,9 @@ def ned(tsv, ner_result, ned_rest_endpoint): help="REST endpoint of sbb_ned service. See https://github.com/qurator-spk/sbb_ned for details.") @click.option('--noproxy', type=bool, is_flag=True, help='disable proxy. default: enabled.') @click.option('--scale-factor', type=float, default=0.5685, help='default: 0.5685') -def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest_endpoint, noproxy, scale_factor): +@click.option('--ned-threshold', type=float, default=None) +def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest_endpoint, noproxy, scale_factor, + ned_threshold): out_columns = ['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', 'left', 'right', 'top', 'bottom'] @@ -257,7 +272,8 @@ def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest vlinecenter = pd.DataFrame(tsv[['line', 'top']].groupby('line', sort=False).mean().top + (tsv[['line', 'bottom']].groupby('line', sort=False).mean().bottom - - tsv[['line', 'top']].groupby('line', sort=False).mean().top) / 2, columns=['vlinecenter']) + tsv[['line', 'top']].groupby('line', sort=False).mean().top) / 2, + columns=['vlinecenter']) tsv = tsv.merge(vlinecenter, left_on='line', right_index=True) @@ -274,7 +290,7 @@ def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest if ned_rest_endpoint is not None: - tsv = ned(tsv, ner_result, ned_rest_endpoint) + tsv, _ = ned(tsv, ner_result, ned_rest_endpoint, threshold=ned_threshold) tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False, mode='a', header=False) except requests.HTTPError as e: @@ -288,7 +304,8 @@ def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest help="REST endpoint of sbb_ner service. See https://github.com/qurator-spk/sbb_ner for details.") @click.option('--ned-rest-endpoint', type=str, default=None, help="REST endpoint of sbb_ned service. See https://github.com/qurator-spk/sbb_ned for details.") -def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint): +@click.option('--ned-json-file', type=str, default=None) +def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, ned_json_file): tsv = pd.read_csv(tsv_file, sep='\t', comment='#', quoting=3) @@ -299,8 +316,13 @@ def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint): if ned_rest_endpoint is not None: - tsv = ned(tsv, ner_result, ned_rest_endpoint) + tsv, ned_result = ned(tsv, ner_result, ned_rest_endpoint, return_full=ned_json_file is not None) - tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False, mode='a', header=False) + if ned_json_file is not None: + + with open(ned_json_file, "w") as fp_json: + json.dump(ned_result, fp_json, indent=2, separators=(',', ': ')) + + tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False) except requests.HTTPError as e: print(e)