diff --git a/cli.py b/cli.py index c0cb314..ea7255c 100644 --- a/cli.py +++ b/cli.py @@ -319,11 +319,17 @@ def page2tsv(page_xml_file, tsv_out_file, image_url, ner_rest_endpoint, ned_rest @click.option('--ned-threshold', type=float, default=None) def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, ned_json_file, noproxy, ned_threshold): + out_columns = ['No.', 'TOKEN', 'NE-TAG', 'NE-EMB', 'ID', 'url_id', 'left', 'right', 'top', 'bottom'] + if noproxy: os.environ['no_proxy'] = '*' tsv = pd.read_csv(tsv_file, sep='\t', comment='#', quoting=3) + parts = extract_doc_links(tsv_out_file) + + urls = [part['url'] for part in parts] + try: if ner_rest_endpoint is not None: @@ -350,8 +356,20 @@ def find_entities(tsv_file, tsv_out_file, ner_rest_endpoint, ned_rest_endpoint, with open(ned_json_file, "w") as fp_json: json.dump(ned_result, fp_json, indent=2, separators=(',', ': ')) - print('Writing to {}...'.format(tsv_out_file)) - tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False) + if len(urls) == 0: + print('Writing to {}...'.format(tsv_out_file)) + + tsv.to_csv(tsv_out_file, sep="\t", quoting=3, index=False) + else: + pd.DataFrame([], columns=out_columns). to_csv(tsv_out_file, sep="\t", quoting=3, index=False) + + for url_id, part in tsv.groupby('url_id'): + + with open(tsv_out_file, 'a') as f: + f.write('# ' + urls[url_id] + '\n') + + part.to_csv(tsv_out_file, sep="\t", quoting=3, index=False, mode='a', header=False) + except requests.HTTPError as e: print(e)