diff --git a/qurator/sbb_ner/webapp/app.py b/qurator/sbb_ner/webapp/app.py index 13953f8..911795c 100644 --- a/qurator/sbb_ner/webapp/app.py +++ b/qurator/sbb_ner/webapp/app.py @@ -260,6 +260,8 @@ def ner(model_id): for (tokens, word_predictions), (input_sentence, _) in zip(prediction, sentences): original_text = "".join(input_sentence).replace(" ", "") + original_word_positions = \ + [pos for positions in [[idx] * len(word) for idx, word in enumerate(input_sentence)] for pos in positions] word = '' last_prediction = 'O' @@ -274,8 +276,13 @@ def ner(model_id): word = '' if token == '[UNK]': + orig_pos = len("".join([pred['word'] for pred in output_sentence]) + word) + if orig_pos > 0 and original_word_positions[orig_pos-1] != original_word_positions[orig_pos]: + output_sentence.append({'word': word, 'prediction': last_prediction}) + word = '' + word += original_text[orig_pos] continue