From 2dc385777091468833645d781d51f6f222fca5a2 Mon Sep 17 00:00:00 2001 From: Kai Labusch Date: Thu, 2 Jul 2020 11:37:54 +0200 Subject: [PATCH] make tools more robust against glitches within the input files --- cli.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/cli.py b/cli.py index d8923bd..b467944 100644 --- a/cli.py +++ b/cli.py @@ -174,6 +174,7 @@ def ned(tsv, ner_result, ned_rest_endpoint, json_file=None, threshold=None): rids = [] entity = "" entity_type = None + tsv['ID'] = '-' def check_entity(tag): nonlocal entity, entity_type, rids @@ -195,7 +196,10 @@ def ned(tsv, ner_result, ned_rest_endpoint, json_file=None, threshold=None): entity = "" entity_type = None - for rid, row in tsv.iterrows(): + ner_tmp = tsv.copy() + ner_tmp.loc[~ner_tmp['NE-TAG'].isin(['O', 'B-PER', 'B-LOC','B-ORG', 'I-PER', 'I-LOC', 'I-ORG']), 'NE-TAG'] = 'O' + + for rid, row in ner_tmp.iterrows(): check_entity(row['NE-TAG']) @@ -205,7 +209,7 @@ def ned(tsv, ner_result, ned_rest_endpoint, json_file=None, threshold=None): entity += " " if entity != "" else "" - entity += row['TOKEN'] + entity += str(row['TOKEN']) rids.append(rid) @@ -321,12 +325,10 @@ def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, out_columns = ['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', 'left', 'right', 'top', 'bottom'] - supported_ent = {'PER', 'LOC', 'ORG'} - if noproxy: os.environ['no_proxy'] = '*' - tsv = pd.read_csv(tsv_file, sep='\t', comment='#', quoting=3) + tsv = pd.read_csv(tsv_file, sep='\t', comment='#', quoting=3).rename(columns={'GND-ID': 'ID'}) parts = extract_doc_links(tsv_file) @@ -343,8 +345,9 @@ def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, tmp = tsv.copy() tmp['sen'] = (tmp['No.'] == 0).cumsum() + tmp.loc[~tmp['NE-TAG'].isin(['O', 'B-PER', 'B-LOC','B-ORG', 'I-PER', 'I-LOC', 'I-ORG']), 'NE-TAG'] = 'O' - ner_result = [[{'word': str(row.TOKEN), 'prediction': row['NE-TAG'] if len(row['NE-TAG']) == 1 or row['NE-TAG'][-3:] in supported_ent else 'O' } for _, row in sen.iterrows()] + ner_result = [[{'word': str(row.TOKEN), 'prediction': row['NE-TAG']} for _, row in sen.iterrows()] for _, sen in tmp.groupby('sen')] else: raise RuntimeError("Either NER rest endpoint or NER-TAG information within tsv_file required.")