mirror of
https://github.com/qurator-spk/sbb_ner.git
synced 2025-07-27 19:59:53 +02:00
add default model behaviour
This commit is contained in:
parent
181cbb9f53
commit
ba188d1daa
1 changed files with 10 additions and 5 deletions
|
@ -127,16 +127,19 @@ class PredictorStore:
|
||||||
|
|
||||||
def get(self, model_id):
|
def get(self, model_id):
|
||||||
|
|
||||||
model = next((m for m in app.config['MODELS'] if m['id'] == int(model_id)))
|
if model_id is not None:
|
||||||
|
model = next((m for m in app.config['MODELS'] if m['id'] == int(model_id)))
|
||||||
|
else:
|
||||||
|
model = next((m for m in app.config['MODELS'] if m['default']))
|
||||||
|
|
||||||
if self._model_id != model_id:
|
if self._model_id != model['id']:
|
||||||
|
|
||||||
self._predictor = NERPredictor(model_dir=model['model_dir'],
|
self._predictor = NERPredictor(model_dir=model['model_dir'],
|
||||||
epoch=model['epoch'],
|
epoch=model['epoch'],
|
||||||
batch_size=app.config['BATCH_SIZE'],
|
batch_size=app.config['BATCH_SIZE'],
|
||||||
no_cuda=False if not os.environ.get('USE_CUDA') else
|
no_cuda=False if not os.environ.get('USE_CUDA') else
|
||||||
os.environ.get('USE_CUDA').lower() == 'false')
|
os.environ.get('USE_CUDA').lower() == 'false')
|
||||||
self._model_id = model_id
|
self._model_id = model['id']
|
||||||
|
|
||||||
return self._predictor
|
return self._predictor
|
||||||
|
|
||||||
|
@ -168,8 +171,9 @@ def tokenized():
|
||||||
return jsonify(result)
|
return jsonify(result)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/ner-bert-tokens', methods=['GET', 'POST'])
|
||||||
@app.route('/ner-bert-tokens/<model_id>', methods=['GET', 'POST'])
|
@app.route('/ner-bert-tokens/<model_id>', methods=['GET', 'POST'])
|
||||||
def ner_bert_tokens(model_id):
|
def ner_bert_tokens(model_id=None):
|
||||||
|
|
||||||
raw_text = request.json['text']
|
raw_text = request.json['text']
|
||||||
|
|
||||||
|
@ -192,8 +196,9 @@ def ner_bert_tokens(model_id):
|
||||||
return jsonify(output)
|
return jsonify(output)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/ner', methods=['GET', 'POST'])
|
||||||
@app.route('/ner/<model_id>', methods=['GET', 'POST'])
|
@app.route('/ner/<model_id>', methods=['GET', 'POST'])
|
||||||
def ner(model_id):
|
def ner(model_id=None):
|
||||||
|
|
||||||
raw_text = request.json['text']
|
raw_text = request.json['text']
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue