diff --git a/qurator/sbb_ner/webapp/app.py b/qurator/sbb_ner/webapp/app.py index 8820a69..ea5f4c4 100644 --- a/qurator/sbb_ner/webapp/app.py +++ b/qurator/sbb_ner/webapp/app.py @@ -220,6 +220,9 @@ def ner(model_id=None): for pos, (token, word_pred) in enumerate(zip(tokens, word_predictions)): + if word_pred == '[SEP]': + word_pred = 'O' + if not token.startswith('##') and token != '[UNK]': if len(word) > 0: output_sentence.append({'word': word, 'prediction': last_prediction}) @@ -244,9 +247,6 @@ def ner(model_id=None): if not token.startswith('##') and word_pred == 'X': word_pred = 'O' - if word_pred == '[SEP]': - word_pred = 'O' - token = token[2:] if token.startswith('##') else token word += token @@ -267,6 +267,7 @@ def ner(model_id=None): logger.warning('Input and output different!!! \n\n\nInput: {}\n\nOutput: {}\n'. format("".join(input_sentence).replace(" ", ""), "".join([pred['word'] for pred in output_sentence]))) + torch.cuda.empty_cache() return jsonify(output)