|
|
|
@ -219,57 +219,88 @@ def ner(model_id=None):
|
|
|
|
|
|
|
|
|
|
output = []
|
|
|
|
|
|
|
|
|
|
for (tokens, word_predictions), (input_sentence, _) in zip(prediction, sentences):
|
|
|
|
|
for (tokens, token_predictions), (input_sentence, _) in zip(prediction, sentences):
|
|
|
|
|
|
|
|
|
|
output_text = ""
|
|
|
|
|
original_text = "".join(input_sentence)
|
|
|
|
|
original_word_positions = \
|
|
|
|
|
[pos for positions in [[idx] * len(word) for idx, word in enumerate(input_sentence)] for pos in positions]
|
|
|
|
|
|
|
|
|
|
word = ''
|
|
|
|
|
last_prediction = 'O'
|
|
|
|
|
word_prediction = 'O'
|
|
|
|
|
output_sentence = []
|
|
|
|
|
|
|
|
|
|
for pos, (token, word_pred) in enumerate(zip(tokens, word_predictions)):
|
|
|
|
|
for pos, (token, token_prediction) in enumerate(zip(tokens, token_predictions)):
|
|
|
|
|
|
|
|
|
|
if word_pred == '[SEP]':
|
|
|
|
|
word_pred = 'O'
|
|
|
|
|
if not token.startswith('##') and token_prediction == 'X' or token_prediction == '[SEP]':
|
|
|
|
|
token_prediction = 'O'
|
|
|
|
|
|
|
|
|
|
if not token.startswith('##') and token != '[UNK]':
|
|
|
|
|
if len(word) > 0:
|
|
|
|
|
output_sentence.append({'word': word, 'prediction': last_prediction})
|
|
|
|
|
orig_pos = len(output_text + word)
|
|
|
|
|
|
|
|
|
|
# if the current word length is greater than 0
|
|
|
|
|
# and it is either a word-start token (does not start with ##) that is not an unknown token, or the original text
|
|
|
|
|
# positions indicate a word break
|
|
|
|
|
if len(word) > 0 and ((not token.startswith('##') and token != '[UNK]') or
|
|
|
|
|
(orig_pos > 0 and
|
|
|
|
|
original_word_positions[orig_pos-1] != original_word_positions[orig_pos])):
|
|
|
|
|
output_sentence.append({'word': word, 'prediction': word_prediction})
|
|
|
|
|
output_text += word
|
|
|
|
|
word = ''
|
|
|
|
|
word_prediction = 'O'
|
|
|
|
|
|
|
|
|
|
if token == '[UNK]':
|
|
|
|
|
|
|
|
|
|
orig_pos = len("".join([pred['word'] for pred in output_sentence]) + word)
|
|
|
|
|
orig_pos = len(output_text + word)
|
|
|
|
|
|
|
|
|
|
# are we on a word boundary?
|
|
|
|
|
if orig_pos > 0 and original_word_positions[orig_pos-1] != original_word_positions[orig_pos]:
|
|
|
|
|
output_sentence.append({'word': word, 'prediction': last_prediction})
|
|
|
|
|
|
|
|
|
|
# we are on a word boundary - start a new word ...
|
|
|
|
|
output_sentence.append({'word': word, 'prediction': word_prediction})
|
|
|
|
|
output_text += word
|
|
|
|
|
word = ''
|
|
|
|
|
word_prediction = 'O'
|
|
|
|
|
|
|
|
|
|
word += original_text[orig_pos]
|
|
|
|
|
# get character that corresponds to [UNK] token from original text
|
|
|
|
|
token = original_text[orig_pos]
|
|
|
|
|
|
|
|
|
|
if word_pred != 'X':
|
|
|
|
|
last_prediction = word_pred
|
|
|
|
|
else:
|
|
|
|
|
token = token[2:] if token.startswith('##') else token
|
|
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
# if output_text plus the current word and token is not a prefix of the original text, it means that
|
|
|
|
|
# we would miss characters. Therefore we take the missing characters from the original text at the current
|
|
|
|
|
# word position
|
|
|
|
|
while not original_text.startswith(output_text + word + token) \
|
|
|
|
|
and len(output_text + word) < len(original_text):
|
|
|
|
|
|
|
|
|
|
if not token.startswith('##') and word_pred == 'X':
|
|
|
|
|
word_pred = 'O'
|
|
|
|
|
word += original_text[len(output_text + word)]
|
|
|
|
|
|
|
|
|
|
token = token[2:] if token.startswith('##') else token
|
|
|
|
|
orig_pos = len(output_text + word)
|
|
|
|
|
|
|
|
|
|
# are we on a word boundary?
|
|
|
|
|
if orig_pos > 0 and original_word_positions[orig_pos - 1] != original_word_positions[orig_pos]:
|
|
|
|
|
# we are on a word boundary - start a new word ...
|
|
|
|
|
output_sentence.append({'word': word, 'prediction': word_prediction})
|
|
|
|
|
output_text += word
|
|
|
|
|
word = ''
|
|
|
|
|
word_prediction = 'O'
|
|
|
|
|
|
|
|
|
|
word += token
|
|
|
|
|
|
|
|
|
|
if word_pred != 'X':
|
|
|
|
|
last_prediction = word_pred
|
|
|
|
|
if token_prediction != 'X':
|
|
|
|
|
word_prediction = token_prediction
|
|
|
|
|
|
|
|
|
|
if len(word) > 0:
|
|
|
|
|
output_sentence.append({'word': word, 'prediction': last_prediction})
|
|
|
|
|
output_text += word
|
|
|
|
|
output_sentence.append({'word': word, 'prediction': word_prediction})
|
|
|
|
|
|
|
|
|
|
output.append(output_sentence)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
assert output_text == original_text
|
|
|
|
|
except AssertionError:
|
|
|
|
|
import ipdb;ipdb.set_trace()
|
|
|
|
|
|
|
|
|
|
for output_sentence, (input_sentence, _) in zip(output, sentences):
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
@ -278,6 +309,7 @@ def ner(model_id=None):
|
|
|
|
|
logger.warning('Input and output different!!! \n\n\nInput: {}\n\nOutput: {}\n'.
|
|
|
|
|
format("".join(input_sentence).replace(" ", ""),
|
|
|
|
|
"".join([pred['word'] for pred in output_sentence])))
|
|
|
|
|
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
return jsonify(output)
|
|
|
|
|