From 9bf2e6f51b8d8127079289000ec327dba43f6782 Mon Sep 17 00:00:00 2001
From: Kai Labusch
Date: Fri, 22 Nov 2019 16:38:42 +0100
Subject: [PATCH] fix NER output; fix BERT Tokenizer

---
 .../sbb_ner/ground_truth/data_processor.py | 125 +++---
 qurator/sbb_ner/models/bert.py             |   4 +-
 qurator/sbb_ner/models/tokenization.py     | 419 ++++++++++++++++++
 qurator/sbb_ner/webapp/app.py              |  51 ++-
 4 files changed, 519 insertions(+), 80 deletions(-)
 create mode 100644 qurator/sbb_ner/models/tokenization.py

diff --git a/qurator/sbb_ner/ground_truth/data_processor.py b/qurator/sbb_ner/ground_truth/data_processor.py
index f1a3293..054848f 100644
--- a/qurator/sbb_ner/ground_truth/data_processor.py
+++ b/qurator/sbb_ner/ground_truth/data_processor.py
@@ -37,7 +37,8 @@ class InputExample(object):
 class InputFeatures(object):
     """A single set of features of data."""
 
-    def __init__(self, input_ids, input_mask, segment_ids, label_id, tokens):
+    def __init__(self, guid, input_ids, input_mask, segment_ids, label_id, tokens):
+        self.guid = guid
         self.input_ids = input_ids
         self.input_mask = input_mask
         self.segment_ids = segment_ids
@@ -74,6 +75,8 @@ class WikipediaDataset(Dataset):
         # noinspection PyUnresolvedReferences
         self._random_state = np.random.RandomState(seed=self._seed)
 
+        self._features = []
+
         self._reset()
 
         return
@@ -85,9 +88,8 @@ class WikipediaDataset(Dataset):
 
         return int(self._counter) % int(1.0 / self._no_entity_fraction) != 0
 
-    def __getitem__(self, index):
+    def _get_features(self):
 
-        del index
 
         if self._counter > self._data_epochs * self._epoch_size:
             self._reset()
@@ -113,14 +115,24 @@ class WikipediaDataset(Dataset):
         sample = InputExample(guid="%s-%s" % (self._set_file, self._counter),
                               text_a=sen_words, text_b=None, label=sen_tags)
 
-        features = convert_examples_to_features(sample, self._label_map, self._max_seq_length, self._tokenizer)
+        return [fe for fe in
+                convert_examples_to_features(sample, self._label_map, self._max_seq_length, self._tokenizer)]
+
+    def __getitem__(self, index):
+
+        del index
+
+        if len(self._features) == 0:
+            self._features = self._get_features()
+
+        fe = self._features.pop()
 
         self._counter += 1
 
-        return torch.tensor(features.input_ids, dtype=torch.long), \
-               torch.tensor(features.input_mask, dtype=torch.long), \
-               torch.tensor(features.segment_ids, dtype=torch.long), \
-               torch.tensor(features.label_id, dtype=torch.long)
+        return torch.tensor(fe.input_ids, dtype=torch.long), \
+               torch.tensor(fe.input_mask, dtype=torch.long), \
+               torch.tensor(fe.segment_ids, dtype=torch.long), \
+               torch.tensor(fe.label_id, dtype=torch.long)
 
     def __len__(self):
 
@@ -324,8 +336,8 @@ class NerProcessor(DataProcessor):
                          sequential=False):
 
         if features is None:
-            features = [convert_examples_to_features(ex, label_map, max_seq_length, tokenizer)
-                        for ex in examples]
+            features = [fe for ex in examples for fe in
+                        convert_examples_to_features(ex, label_map, max_seq_length, tokenizer)]
 
         all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
         all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
@@ -362,74 +374,59 @@ class NerProcessor(DataProcessor):
 
         return data
 
 
-def convert_examples_to_features(example, label_map, max_seq_length, tokenizer):
+def convert_examples_to_features(example, label_map, max_seq_len, tokenizer):
     """
     :param example: instance of InputExample
-    :param label_map:
-    :param max_seq_length:
-    :param tokenizer:
+    :param label_map: Maps labels like B-ORG ... to numbers (ids).
+    :param max_seq_len: Maximum length of sequences to be delivered to the model.
+    :param tokenizer: BERT-Tokenizer
     :return:
     """
-
-    words = example.text_a
-    word_labels = example.label
     tokens = []
     labels = []
 
-    for i, word in enumerate(words):
+    for i, word in enumerate(example.text_a):  # example.text_a is a sequence of words
 
         token = tokenizer.tokenize(word)
 
         tokens.extend(token)
 
-        label_1 = word_labels[i] if i < len(word_labels) else 'O'
+        label_1 = example.label[i] if i < len(example.label) else 'O'
 
-        for m in range(len(token)):
+        for m in range(len(token)):  # a word might have been split into several tokens
             if m == 0:
                 labels.append(label_1)
             else:
                 labels.append("X")
 
-    if len(tokens) >= max_seq_length - 1:
-        tokens = tokens[0:(max_seq_length - 2)]
-        labels = labels[0:(max_seq_length - 2)]
-
-    n_tokens = []
-    segment_ids = []
-    label_ids = []
-    n_tokens.append("[CLS]")
-    segment_ids.append(0)
-    label_ids.append(label_map["[CLS]"])
-    for i, token in enumerate(tokens):
-        n_tokens.append(token)
-        segment_ids.append(0)
-        label_ids.append(label_map[labels[i]])
-    n_tokens.append("[SEP]")
-    segment_ids.append(0)
-    label_ids.append(label_map["[SEP]"])
-    input_ids = tokenizer.convert_tokens_to_ids(n_tokens)
-    input_mask = [1] * len(input_ids)
-
-    while len(input_ids) < max_seq_length:
-        input_ids.append(0)
-        input_mask.append(0)
-        segment_ids.append(0)
-        label_ids.append(0)
-
-    assert len(input_ids) == max_seq_length
-    assert len(input_mask) == max_seq_length
-    assert len(segment_ids) == max_seq_length
-    assert len(label_ids) == max_seq_length
-
-    # if ex_index < 5:
-    #     logger.info("*** Example ***")
-    #     logger.info("guid: %s" % example.guid)
-    #     logger.info("tokens: %s" % " ".join(
-    #         [str(x) for x in tokens]))
-    #     logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-    #     logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-    #     logger.info(
-    #         "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
-    #     logger.info("label: %s (id = %d)" % (example.label, label_ids))
-
-    return InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_id=label_ids,
-                         tokens=n_tokens)
+    start_pos = 0
+    while start_pos < len(tokens):
+
+        window_len = min(max_seq_len - 2, len(tokens) - start_pos)  # -2 since we also need [CLS] and [SEP]
+
+        # Make sure that we do not split the sentence within a word.
+        while window_len > 1 and start_pos + window_len < len(tokens) and\
+                tokens[start_pos + window_len].startswith('##'):
+            window_len -= 1
+
+        token_window = tokens[start_pos:start_pos+window_len]
+        start_pos += window_len
+
+        augmented_tokens = ["[CLS]"] + token_window + ["[SEP]"]
+
+        input_ids = tokenizer.convert_tokens_to_ids(augmented_tokens) + max(0, max_seq_len - len(augmented_tokens))*[0]
+
+        input_mask = [1] * len(augmented_tokens) + max(0, max_seq_len - len(augmented_tokens))*[0]
+
+        segment_ids = [0] + len(token_window) * [0] + [0] + max(0, max_seq_len - len(augmented_tokens))*[0]
+
+        # use the labels that belong to this window (start_pos has already been moved past it)
+        label_ids = [label_map["[CLS]"]] + [label_map[lab] for lab in labels[start_pos - window_len:start_pos]] + \
+                    [label_map["[SEP]"]] + max(0, max_seq_len - len(augmented_tokens)) * [0]
+
+        assert len(input_ids) == max_seq_len
+        assert len(input_mask) == max_seq_len
+        assert len(segment_ids) == max_seq_len
+        assert len(label_ids) == max_seq_len
+
+        yield InputFeatures(guid=example.guid, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids,
+                            label_id=label_ids, tokens=augmented_tokens)
+
diff --git a/qurator/sbb_ner/models/bert.py b/qurator/sbb_ner/models/bert.py
index be18335..8e9602a 100644
--- a/qurator/sbb_ner/models/bert.py
+++ b/qurator/sbb_ner/models/bert.py
@@ -17,7 +17,8 @@ from pytorch_pretrained_bert.modeling import (CONFIG_NAME,  # WEIGHTS_NAME,
                                               BertConfig,
                                               BertForTokenClassification)
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+# from pytorch_pretrained_bert.tokenization import BertTokenizer
+from .tokenization import BertTokenizer
 
 from conlleval import evaluate as conll_eval
 
@@ -386,6 +387,7 @@ def model_predict(dataloader, device, label_map, model):
                     y_pred.append(temp_2)
                     break
                 else:
+                    temp_2.pop()  # skip the last token since it's [SEP]
                     y_pred.append(temp_2)
 
     return y_pred
diff --git a/qurator/sbb_ner/models/tokenization.py b/qurator/sbb_ner/models/tokenization.py
new file mode 100644
index 0000000..67fae26
--- /dev/null
+++ b/qurator/sbb_ner/models/tokenization.py
@@ -0,0 +1,419 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes.""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import collections +import logging +import os +import unicodedata +from io import open + +from pytorch_pretrained_bert.file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'bert-base-uncased': 512, + 'bert-large-uncased': 512, + 'bert-base-cased': 512, + 'bert-large-cased': 512, + 'bert-base-multilingual-uncased': 512, + 'bert-base-multilingual-cased': 512, + 'bert-base-chinese': 512, +} +VOCAB_NAME = 'vocab.txt' + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this BERT model ({} > {}). Running this" + " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary to a directory or file.""" + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!".format(vocab_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + return vocab_file + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): + logger.warning("The pre-trained model you are loading is a cased model but you have not set " + "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " + "you may want to check this behavior.") + kwargs['do_lower_case'] = False + elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): + logger.warning("The pre-trained model you are loading is an uncased model but you have set " + "`do_lower_case` to False. 
We are setting `do_lower_case=True` for you " + "but you may want to check this behavior.") + kwargs['do_lower_case'] = True + else: + vocab_file = pretrained_model_name_or_path + if os.path.isdir(vocab_file): + vocab_file = os.path.join(vocab_file, VOCAB_NAME) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + vocab_file)) + return None + if resolved_vocab_file == vocab_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) + return tokenizer + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
+ text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. 
+ """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + # is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + # is_bad = True + # break + sub_tokens.append(self.unk_token) + start += 1 + else: + sub_tokens.append(cur_substr) + start = end + + # if is_bad: + # output_tokens.append(self.unk_token) + # else: + output_tokens.extend(sub_tokens) + + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
+    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
+            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+        return True
+    cat = unicodedata.category(char)
+    if cat.startswith("P"):
+        return True
+    return False
diff --git a/qurator/sbb_ner/webapp/app.py b/qurator/sbb_ner/webapp/app.py
index 6de6d45..5d116d7 100644
--- a/qurator/sbb_ner/webapp/app.py
+++ b/qurator/sbb_ner/webapp/app.py
@@ -10,7 +10,7 @@ from somajo import Tokenizer, SentenceSplitter
 
 from qurator.sbb_ner.models.bert import get_device, model_predict
 from qurator.sbb_ner.ground_truth.data_processor import NerProcessor, convert_examples_to_features
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from qurator.sbb_ner.models.tokenization import BertTokenizer
 from pytorch_pretrained_bert.modeling import (CONFIG_NAME,
                                               BertConfig,
                                               BertForTokenClassification)
@@ -90,10 +90,8 @@ class NERPredictor:
 
         examples = NerProcessor.create_examples(sentences, 'test')
 
-        features = [convert_examples_to_features(ex, self._label_to_id, self._max_seq_length, self._bert_tokenizer)
-                    for ex in examples]
-
-        assert len(sentences) == len(features)
+        features = [fe for ex in examples for fe in
+                    convert_examples_to_features(ex, self._label_to_id, self._max_seq_length, self._bert_tokenizer)]
 
         data_loader = NerProcessor.make_data_loader(None, self._batch_size, self._local_rank, self._label_to_id,
                                                     self._max_seq_length, self._bert_tokenizer, features=features,
@@ -101,11 +99,22 @@ class NERPredictor:
 
         prediction_tmp = model_predict(data_loader, self._device, self._label_map, self._model)
 
-        assert len(sentences) == len(prediction_tmp)
+        assert len(prediction_tmp) == len(features)
 
         prediction = []
+        prev_guid = None
         for fe, pr in zip(features, prediction_tmp):
-            prediction.append((fe.tokens[1:-1], pr))
+            # longer sentences might have been processed in several steps,
+            # therefore we have to glue them back together; this can be done on the basis of the guid.
+ + if prev_guid != fe.guid: + prediction.append((fe.tokens[1:-1], pr)) + else: + prediction[-1] = (prediction[-1][0] + fe.tokens[1:-1], prediction[-1][1] + pr) + + prev_guid = fe.guid + + assert len(sentences) == len(prediction) return prediction @@ -243,23 +252,28 @@ def ner(model_id): output = [] - for tokens, word_predictions in prediction: + for (tokens, word_predictions), (input_sentence, _) in zip(prediction, sentences): - word = None + original_text = "".join(input_sentence) + + word = '' last_prediction = 'O' output_sentence = [] - for token, word_pred in zip(tokens, word_predictions): - - if token == '[UNK]': - continue + for pos, (token, word_pred) in enumerate(zip(tokens, word_predictions)): if not token.startswith('##'): - if word is not None: + if len(word) > 0: output_sentence.append({'word': word, 'prediction': last_prediction}) word = '' + if token == '[UNK]': + orig_pos = len("".join([pred['word'] for pred in output_sentence])) + + output_sentence.append({'word': original_text[orig_pos], 'prediction': 'O'}) + continue + token = token[2:] if token.startswith('##') else token word += token @@ -267,11 +281,18 @@ def ner(model_id): if word_pred != 'X': last_prediction = word_pred - if word is not None and len(word) > 0: + if len(word) > 0: output_sentence.append({'word': word, 'prediction': last_prediction}) output.append(output_sentence) + for output_sentence, (input_sentence, _) in zip(output, sentences): + + try: + assert "".join([pred['word'] for pred in output_sentence]) == "".join(input_sentence).replace(" ", "") + except AssertionError: + import ipdb;ipdb.set_trace() + return jsonify(output)
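For reference, the following is a minimal, self-contained sketch of the windowing strategy that the
reworked convert_examples_to_features applies to sentences longer than max_seq_len. It is illustrative
only and not part of the patch; the helper name split_into_windows and the example tokens are made up.

# Sketch (assumption: plain word-piece tokens as produced by the BERT tokenizer above).
def split_into_windows(tokens, max_seq_len):
    """Split a word-piece sequence into windows of at most max_seq_len - 2 tokens,
    never cutting between a word and its '##' continuation pieces."""
    windows = []
    start_pos = 0
    while start_pos < len(tokens):
        window_len = min(max_seq_len - 2, len(tokens) - start_pos)  # leave room for [CLS] and [SEP]

        # back off while the first token after the window is a '##' continuation piece
        while window_len > 1 and start_pos + window_len < len(tokens) \
                and tokens[start_pos + window_len].startswith('##'):
            window_len -= 1

        windows.append(tokens[start_pos:start_pos + window_len])
        start_pos += window_len

    return windows


if __name__ == '__main__':
    # 'unaffable' has been word-pieced into 'un ##aff ##able'; with max_seq_len=5 (at most three
    # word pieces per window) the split is moved back so the whole word stays in one window:
    print(split_into_windows(['Berlin', 'is', 'un', '##aff', '##able'], 5))
    # -> [['Berlin', 'is'], ['un', '##aff', '##able']]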