diff --git a/qurator/sbb_ner/webapp/app.py b/qurator/sbb_ner/webapp/app.py index 017c18f..fef1ea0 100644 --- a/qurator/sbb_ner/webapp/app.py +++ b/qurator/sbb_ner/webapp/app.py @@ -127,16 +127,19 @@ class PredictorStore: def get(self, model_id): - model = next((m for m in app.config['MODELS'] if m['id'] == int(model_id))) + if model_id is not None: + model = next((m for m in app.config['MODELS'] if m['id'] == int(model_id))) + else: + model = next((m for m in app.config['MODELS'] if m['default'])) - if self._model_id != model_id: + if self._model_id != model['id']: self._predictor = NERPredictor(model_dir=model['model_dir'], epoch=model['epoch'], batch_size=app.config['BATCH_SIZE'], no_cuda=False if not os.environ.get('USE_CUDA') else os.environ.get('USE_CUDA').lower() == 'false') - self._model_id = model_id + self._model_id = model['id'] return self._predictor @@ -168,8 +171,9 @@ def tokenized(): return jsonify(result) +@app.route('/ner-bert-tokens', methods=['GET', 'POST']) @app.route('/ner-bert-tokens/', methods=['GET', 'POST']) -def ner_bert_tokens(model_id): +def ner_bert_tokens(model_id=None): raw_text = request.json['text'] @@ -192,8 +196,9 @@ def ner_bert_tokens(model_id): return jsonify(output) +@app.route('/ner', methods=['GET', 'POST']) @app.route('/ner/', methods=['GET', 'POST']) -def ner(model_id): +def ner(model_id=None): raw_text = request.json['text']