diff --git a/cli.py b/cli.py index ea7255c..d8923bd 100644 --- a/cli.py +++ b/cli.py @@ -321,12 +321,14 @@ def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, out_columns = ['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', 'left', 'right', 'top', 'bottom'] + supported_ent = {'PER', 'LOC', 'ORG'} + if noproxy: os.environ['no_proxy'] = '*' tsv = pd.read_csv(tsv_file, sep='\t', comment='#', quoting=3) - parts = extract_doc_links(tsv_out_file) + parts = extract_doc_links(tsv_file) urls = [part['url'] for part in parts] @@ -342,7 +344,7 @@ def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, tmp = tsv.copy() tmp['sen'] = (tmp['No.'] == 0).cumsum() - ner_result = [[{'word': row.TOKEN, 'prediction': row['NE-TAG']} for _, row in sen.iterrows()] + ner_result = [[{'word': str(row.TOKEN), 'prediction': row['NE-TAG'] if len(row['NE-TAG']) == 1 or row['NE-TAG'][-3:] in supported_ent else 'O' } for _, row in sen.iterrows()] for _, sen in tmp.groupby('sen')] else: raise RuntimeError("Either NER rest endpoint or NER-TAG information within tsv_file required.") @@ -364,7 +366,6 @@ def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, pd.DataFrame([], columns=out_columns). to_csv(tsv_out_file, sep="\t", quoting=3, index=False) for url_id, part in tsv.groupby('url_id'): - with open(tsv_out_file, 'a') as f: f.write('# ' + urls[url_id] + '\n') @@ -383,4 +384,4 @@ def make_page2tsv_commands(xls_file): for _, row in df.iterrows(): print('page2tsv $(OPTIONS) {}.xml {}.tsv --image-url={} --scale-factor={}'. format(row.Filename, row.Filename, row.iiif_url.replace('/full/full', '/left,top,width,height/full'), - row.scale_factor)) \ No newline at end of file + row.scale_factor))