diff --git a/qurator/sbb_ner/ground_truth/data_processor.py b/qurator/sbb_ner/ground_truth/data_processor.py index 577448b..85245f8 100644 --- a/qurator/sbb_ner/ground_truth/data_processor.py +++ b/qurator/sbb_ner/ground_truth/data_processor.py @@ -388,6 +388,9 @@ def convert_examples_to_features(example, label_map, max_seq_len, tokenizer): for i, word in enumerate(example.text_a): # example.text_a is a sequence of words token = tokenizer.tokenize(word) + + # import ipdb;ipdb.set_trace() + tokens.extend(token) label_1 = example.label[i] if i < len(example.label) else 'O' diff --git a/qurator/sbb_ner/webapp/app.py b/qurator/sbb_ner/webapp/app.py index b1d566f..cbeb25b 100644 --- a/qurator/sbb_ner/webapp/app.py +++ b/qurator/sbb_ner/webapp/app.py @@ -219,57 +219,88 @@ def ner(model_id=None): output = [] - for (tokens, word_predictions), (input_sentence, _) in zip(prediction, sentences): + for (tokens, token_predictions), (input_sentence, _) in zip(prediction, sentences): + output_text = "" original_text = "".join(input_sentence) original_word_positions = \ [pos for positions in [[idx] * len(word) for idx, word in enumerate(input_sentence)] for pos in positions] word = '' - last_prediction = 'O' + word_prediction = 'O' output_sentence = [] - for pos, (token, word_pred) in enumerate(zip(tokens, word_predictions)): + for pos, (token, token_prediction) in enumerate(zip(tokens, token_predictions)): - if word_pred == '[SEP]': - word_pred = 'O' + if not token.startswith('##') and token_prediction == 'X' or token_prediction == '[SEP]': + token_prediction = 'O' - if not token.startswith('##') and token != '[UNK]': - if len(word) > 0: - output_sentence.append({'word': word, 'prediction': last_prediction}) + orig_pos = len(output_text + word) + # if the current word length is greater than 0 + # and its either a word start token (does not start with ##) and not an unknown token or the original text + # positions indicate a word break + if len(word) > 0 and ((not token.startswith('##') and token != '[UNK]') or + (orig_pos > 0 and + original_word_positions[orig_pos-1] != original_word_positions[orig_pos])): + output_sentence.append({'word': word, 'prediction': word_prediction}) + output_text += word word = '' + word_prediction = 'O' if token == '[UNK]': - orig_pos = len("".join([pred['word'] for pred in output_sentence]) + word) + orig_pos = len(output_text + word) + # are we on a word boundary? if orig_pos > 0 and original_word_positions[orig_pos-1] != original_word_positions[orig_pos]: - output_sentence.append({'word': word, 'prediction': last_prediction}) + + # we are on a word boundary - start a new word ... + output_sentence.append({'word': word, 'prediction': word_prediction}) + output_text += word word = '' + word_prediction = 'O' - word += original_text[orig_pos] + # get character that corresponds to [UNK] token from original text + token = original_text[orig_pos] - if word_pred != 'X': - last_prediction = word_pred + else: + token = token[2:] if token.startswith('##') else token - continue + # if the output_text plus the current word and token is not a prefix of the original text, it means, + # we would miss characters. Therefore we take the missing characters from the original text at the current + # word position + while not original_text.startswith(output_text + word + token) \ + and len(output_text + word) < len(original_text): - if not token.startswith('##') and word_pred == 'X': - word_pred = 'O' + word += original_text[len(output_text + word)] - token = token[2:] if token.startswith('##') else token + orig_pos = len(output_text + word) + + # are we on a word boundary? + if orig_pos > 0 and original_word_positions[orig_pos - 1] != original_word_positions[orig_pos]: + # we are on a word boundary - start a new word ... + output_sentence.append({'word': word, 'prediction': word_prediction}) + output_text += word + word = '' + word_prediction = 'O' word += token - if word_pred != 'X': - last_prediction = word_pred + if token_prediction != 'X': + word_prediction = token_prediction if len(word) > 0: - output_sentence.append({'word': word, 'prediction': last_prediction}) + output_text += word + output_sentence.append({'word': word, 'prediction': word_prediction}) output.append(output_sentence) + try: + assert output_text == original_text + except AssertionError: + import ipdb;ipdb.set_trace() + for output_sentence, (input_sentence, _) in zip(output, sentences): try: @@ -278,6 +309,7 @@ def ner(model_id=None): logger.warning('Input and output different!!! \n\n\nInput: {}\n\nOutput: {}\n'. format("".join(input_sentence).replace(" ", ""), "".join([pred['word'] for pred in output_sentence]))) + torch.cuda.empty_cache() return jsonify(output)