diff --git a/qurator/sbb_ner/webapp/app.py b/qurator/sbb_ner/webapp/app.py index fb86304..a627e91 100644 --- a/qurator/sbb_ner/webapp/app.py +++ b/qurator/sbb_ner/webapp/app.py @@ -1,6 +1,6 @@ import os import logging -from flask import Flask, send_from_directory, redirect, jsonify, request +from flask import Flask, send_from_directory, redirect, jsonify, request, send_file import pandas as pd from sqlite3 import Error import sqlite3 @@ -16,6 +16,11 @@ from pytorch_pretrained_bert.modeling import (CONFIG_NAME, BertConfig, BertForTokenClassification) +from qurator.sbb.xml import get_entity_coordinates + +import io +from PIL import Image, ImageDraw + app = Flask(__name__) app.config.from_json('config.json' if not os.environ.get('CONFIG') else os.environ.get('CONFIG')) @@ -207,7 +212,17 @@ def fulltext(ppn): df = digisam.get(ppn) if len(df) == 0: - return 'bad request!', 400 + + df = digisam.get('PPN' + ppn) + + if len(df) == 0: + + if ppn.startswith('PPN'): + df = digisam.get(ppn[3:]) + + if len(df) == 0: + + return 'bad request!', 400 text = '' for row_index, row_data in df.iterrows(): @@ -323,6 +338,53 @@ def ner(model_id): return jsonify(output) +def find_file(path, ppn, page, ending): + + file = (8 - len(str(page))) * '0' + page + + if os.path.exists("{}/{}/{}{}".format(path, ppn, file, ending)): + return "{}/{}/{}{}".format(path, ppn, file, ending) + elif os.path.exists("{}/PPN{}/{}{}".format(path, ppn, file, ending)): + return "{}/PPN{}/{}{}".format(path, ppn, file, ending) + elif ppn.startswith('PPN') and os.path.exists("{}/{}/{}{}".format(path, ppn[3:], file, ending)): + return "{}/{}/{}{}".format(path, ppn[3:], file, ending) + else: + return None + + +@app.route('/image//') +def get_image(ppn, page): + + image_file = find_file(app.config['IMAGE_PATH'], ppn, page, '.tif') + + if image_file is None: + return 'bad request!', 400 + + img = Image.open(image_file) + + img = img.convert('RGB') + + alto_file = find_file(app.config['ALTO_PATH'], ppn, page, '.xml') + + if alto_file is not None: + + ner_coordinates, entity_map = get_entity_coordinates(alto_file, img) + + draw = ImageDraw.Draw(img, 'RGBA') + + for idx, row in ner_coordinates.iterrows(): + + draw.rectangle(xy=((row.x0, row.y0), (row.x1, row.y1)), + fill=(255 if row.ner_id.startswith('PER') else 0, + 255 if row.ner_id.startswith('LOC') else 0, + 255 if row.ner_id.startswith('ORG') else 0, 50)) + buffer = io.BytesIO() + img.save(buffer, "JPEG") + buffer.seek(0) + + return send_file(buffer, mimetype='image/jpeg') + + @app.route('/') def send_js(path): return send_from_directory('static', path)