add proper NED support

pull/2/head
Kai 5 years ago
parent 24fd7245f5
commit c7f4b6fe53

@ -105,7 +105,7 @@ def extract_doc_links(tsv_file):
def ner(tsv, ner_rest_endpoint): def ner(tsv, ner_rest_endpoint):
resp = requests.post(url=ner_rest_endpoint, json={'text': " ".join(tsv.TOKEN.tolist())}) resp = requests.post(url=ner_rest_endpoint, json={'text': " ".join(tsv.TOKEN.astype(str).tolist())})
resp.raise_for_status() resp.raise_for_status()
@ -126,7 +126,7 @@ def ner(tsv, ner_rest_endpoint):
tsv_result = [] tsv_result = []
for idx, row in tsv.iterrows(): for idx, row in tsv.iterrows():
row_token = unicodedata.normalize('NFC', row.TOKEN.replace(' ', '')) row_token = unicodedata.normalize('NFC', str(row.TOKEN).replace(' ', ''))
ner_token_concat = '' ner_token_concat = ''
while row_token != ner_token_concat: while row_token != ner_token_concat:
@ -146,7 +146,7 @@ def ner(tsv, ner_rest_endpoint):
'left', 'right', 'top', 'bottom']), ner_result 'left', 'right', 'top', 'bottom']), ner_result
def ned(tsv, ner_result, ned_rest_endpoint): def ned(tsv, ner_result, ned_rest_endpoint, return_full=False, threshold=None):
resp = requests.post(url=ned_rest_endpoint + '/parse', json=ner_result) resp = requests.post(url=ned_rest_endpoint + '/parse', json=ner_result)
@ -154,7 +154,11 @@ def ned(tsv, ner_result, ned_rest_endpoint):
ner_parsed = json.loads(resp.content) ner_parsed = json.loads(resp.content)
resp = requests.post(url=ned_rest_endpoint + '/ned', json=ner_parsed, timeout=3600000) ned_rest_endpoint = ned_rest_endpoint + '/ned?return_full=' + str(return_full).lower()
ned_rest_endpoint += '&threshold={}'.format(threshold) if threshold is not None else ''
resp = requests.post(url=ned_rest_endpoint, json=ner_parsed, timeout=3600000)
resp.raise_for_status() resp.raise_for_status()
@ -163,21 +167,28 @@ def ned(tsv, ner_result, ned_rest_endpoint):
rids = [] rids = []
entity = "" entity = ""
entity_type = None entity_type = None
for rid, row in tsv.iterrows():
if (entity != "") and ((row['NE-TAG'] == 'O') or (row['NE-TAG'].startswith('B-'))): def check_entity(tag):
nonlocal entity, entity_type, rids
if (entity != "") and ((tag == 'O') or tag.startswith('B-') or (tag[2:] != entity_type)):
eid = entity + "-" + entity_type eid = entity + "-" + entity_type
if eid in ned_result: if eid in ned_result:
candidates = ned_result[eid] if 'ranking' in ned_result[eid]:
ranking = ned_result[eid]['ranking']
tsv.loc[rids, 'ID'] = candidates[0][1]['wikidata'] tsv.loc[rids, 'ID'] = ranking[0][1]['wikidata']
rids = [] rids = []
entity = "" entity = ""
entity_type = None entity_type = None
for rid, row in tsv.iterrows():
check_entity(row['NE-TAG'])
if row['NE-TAG'] != 'O': if row['NE-TAG'] != 'O':
entity_type = row['NE-TAG'][2:] entity_type = row['NE-TAG'][2:]
@ -188,7 +199,9 @@ def ned(tsv, ner_result, ned_rest_endpoint):
rids.append(rid) rids.append(rid)
return tsv check_entity('O')
return tsv, ned_result
@click.command() @click.command()
@ -201,7 +214,9 @@ def ned(tsv, ner_result, ned_rest_endpoint):
help="REST endpoint of sbb_ned service. See https://github.com/qurator-spk/sbb_ned for details.") help="REST endpoint of sbb_ned service. See https://github.com/qurator-spk/sbb_ned for details.")
@click.option('--noproxy', type=bool, is_flag=True, help='disable proxy. default: enabled.') @click.option('--noproxy', type=bool, is_flag=True, help='disable proxy. default: enabled.')
@click.option('--scale-factor', type=float, default=0.5685, help='default: 0.5685') @click.option('--scale-factor', type=float, default=0.5685, help='default: 0.5685')
def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest_endpoint, noproxy, scale_factor): @click.option('--ned-threshold', type=float, default=None)
def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest_endpoint, noproxy, scale_factor,
ned_threshold):
out_columns = ['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', 'left', 'right', 'top', 'bottom'] out_columns = ['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', 'left', 'right', 'top', 'bottom']
@ -257,7 +272,8 @@ def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest
vlinecenter = pd.DataFrame(tsv[['line', 'top']].groupby('line', sort=False).mean().top + vlinecenter = pd.DataFrame(tsv[['line', 'top']].groupby('line', sort=False).mean().top +
(tsv[['line', 'bottom']].groupby('line', sort=False).mean().bottom - (tsv[['line', 'bottom']].groupby('line', sort=False).mean().bottom -
tsv[['line', 'top']].groupby('line', sort=False).mean().top) / 2, columns=['vlinecenter']) tsv[['line', 'top']].groupby('line', sort=False).mean().top) / 2,
columns=['vlinecenter'])
tsv = tsv.merge(vlinecenter, left_on='line', right_index=True) tsv = tsv.merge(vlinecenter, left_on='line', right_index=True)
@ -274,7 +290,7 @@ def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest
if ned_rest_endpoint is not None: if ned_rest_endpoint is not None:
tsv = ned(tsv, ner_result, ned_rest_endpoint) tsv, _ = ned(tsv, ner_result, ned_rest_endpoint, threshold=ned_threshold)
tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False, mode='a', header=False) tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False, mode='a', header=False)
except requests.HTTPError as e: except requests.HTTPError as e:
@ -288,7 +304,8 @@ def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest
help="REST endpoint of sbb_ner service. See https://github.com/qurator-spk/sbb_ner for details.") help="REST endpoint of sbb_ner service. See https://github.com/qurator-spk/sbb_ner for details.")
@click.option('--ned-rest-endpoint', type=str, default=None, @click.option('--ned-rest-endpoint', type=str, default=None,
help="REST endpoint of sbb_ned service. See https://github.com/qurator-spk/sbb_ned for details.") help="REST endpoint of sbb_ned service. See https://github.com/qurator-spk/sbb_ned for details.")
def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint): @click.option('--ned-json-file', type=str, default=None)
def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, ned_json_file):
tsv = pd.read_csv(tsv_file, sep='\t', comment='#', quoting=3) tsv = pd.read_csv(tsv_file, sep='\t', comment='#', quoting=3)
@ -299,8 +316,13 @@ def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint):
if ned_rest_endpoint is not None: if ned_rest_endpoint is not None:
tsv = ned(tsv, ner_result, ned_rest_endpoint) tsv, ned_result = ned(tsv, ner_result, ned_rest_endpoint, return_full=ned_json_file is not None)
tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False, mode='a', header=False) if ned_json_file is not None:
with open(ned_json_file, "w") as fp_json:
json.dump(ned_result, fp_json, indent=2, separators=(',', ': '))
tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False)
except requests.HTTPError as e: except requests.HTTPError as e:
print(e) print(e)

Loading…
Cancel
Save