1
0
Fork 0
mirror of https://github.com/qurator-spk/sbb_ner.git synced 2025-06-09 20:30:01 +02:00

re-structure repo

This commit is contained in:
Kai Labusch 2019-08-16 15:22:13 +02:00
commit 16e63b4673
24 changed files with 13072 additions and 0 deletions

View file

@ -0,0 +1 @@
__import__('pkg_resources').declare_namespace(__name__)

View file

@ -0,0 +1 @@
__import__('pkg_resources').declare_namespace(__name__)

View file

@ -0,0 +1,77 @@
import pandas as pd
import click
import codecs
import os
def read_gt(files, datasets):
sentence_number = 300000
gt_data = list()
for filename, dataset in zip(files, datasets):
gt_lines = [l.strip() for l in codecs.open(filename, 'r', 'latin-1')]
word_number = 0
for li in gt_lines:
if li == '':
if word_number > 0:
sentence_number += 1
word_number = 0
continue
if li.startswith('-DOCSTART-'):
continue
parts = li.split()
if len(parts) == 5:
word, _, _, _, tag = li.split()
else:
word, _, _, tag = li.split()
tag = tag.upper()
tag = tag.replace('_', '-')
tag = tag.replace('.', '-')
if tag not in {'B-LOC', 'B-PER', 'I-PER', 'I-ORG', 'B-ORG', 'I-LOC'}:
tag = 'O'
gt_data.append((sentence_number, word_number, word, tag, dataset))
word_number += 1
return pd.DataFrame(gt_data, columns=['nsentence', 'nword', 'word', 'tag', 'dataset'])
@click.command()
@click.argument('path-to-conll', type=click.Path(exists=True), required=True, nargs=1)
@click.argument('conll-ground-truth-file', type=click.Path(), required=True, nargs=1)
def main(path_to_conll, conll_ground_truth_file):
"""
Read CONLL 2003 ner ground truth files from directory <path-to-conll> and
write the outcome of the data parsing to some pandas DataFrame
that is stored as pickle in file <conll-ground-truth-file>.
"""
os.makedirs(os.path.dirname(conll_ground_truth_file), exist_ok=True)
gt_all = read_gt(['{}/deu.dev'.format(path_to_conll),
'{}/deu.testa'.format(path_to_conll),
'{}/deu.testb'.format(path_to_conll),
'{}/deu.train'.format(path_to_conll),
'{}/eng.testa'.format(path_to_conll),
'{}/eng.testb'.format(path_to_conll),
'{}/eng.train'.format(path_to_conll)],
['DE-CONLL-DEV', 'DE-CONLL-TESTA', 'DE-CONLL-TESTB', 'DE-CONLL-TRAIN',
'EN-CONLL-TESTA', 'EN-CONLL-TESTB', 'EN-CONLL-TRAIN'])
gt_all.to_pickle(conll_ground_truth_file)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,435 @@
from __future__ import absolute_import, division, print_function
import os
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset, Dataset)
from torch.utils.data.distributed import DistributedSampler
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_mask, segment_ids, label_id, tokens):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.tokens = tokens
class WikipediaDataset(Dataset):
"""
"""
def __init__(self, set_file, gt_file, data_epochs, epoch_size,
label_map, tokenizer, max_seq_length,
queue_size=1000, no_entity_fraction=0.0, seed=23,
min_sen_len=10, min_article_len=20):
self._set_file = set_file
self._subset = pd.read_pickle(set_file)
self._gt_file = gt_file
self._data_epochs = data_epochs
self._epoch_size = epoch_size
self._label_map = label_map
self._tokenizer = tokenizer
self._max_seq_length = max_seq_length
self._queue_size = queue_size
self._no_entity_fraction = no_entity_fraction
self._seed = seed
self._min_sen_len = min_sen_len
self._min_article_len = min_article_len
self._queue = None
self._data_sequence = None
self._counter = None
# noinspection PyUnresolvedReferences
self._random_state = np.random.RandomState(seed=self._seed)
self._reset()
return
def _next_sample_should_have_entities(self):
if self._no_entity_fraction <= 0.0:
return True
return int(self._counter) % int(1.0 / self._no_entity_fraction) != 0
def __getitem__(self, index):
del index
if self._counter > self._data_epochs * self._epoch_size:
self._reset()
while True:
# get next random sentence
sen_words, sen_tags = self._queue_next()
if len(sen_words) < self._min_sen_len: # Skip all sentences that are to short.
continue
if self._has_entities(sen_tags):
if not self._next_sample_should_have_entities(): # Skip sample if next sample is supposed to
# be a no-entity sample
continue
else:
if self._next_sample_should_have_entities(): # Skip sample if next sample is supposed to be a entity
# sample
continue
break
sample = InputExample(guid="%s-%s" % (self._set_file, self._counter),
text_a=sen_words, text_b=None, label=sen_tags)
features = convert_examples_to_features(sample, self._label_map, self._max_seq_length, self._tokenizer)
self._counter += 1
return torch.tensor(features.input_ids, dtype=torch.long), \
torch.tensor(features.input_mask, dtype=torch.long), \
torch.tensor(features.segment_ids, dtype=torch.long), \
torch.tensor(features.label_id, dtype=torch.long)
def __len__(self):
return int(self._epoch_size)
def _reset(self):
# print('================= WikipediaDataset:_reset ====================== ')
self._queue = list()
self._data_sequence = self._sequence()
self._counter = 0
# noinspection PyUnresolvedReferences
# self._random_state = np.random.RandomState(seed=self._seed)
for _ in range(0, self._queue_size):
self._queue.append(list())
def _sequence(self):
while True:
for row in pd.read_csv(self._gt_file, chunksize=1, sep=';'):
page_id = row.page_id.iloc[0]
text = row.text.iloc[0]
tags = row.tags.iloc[0]
if page_id not in self._subset.index:
continue
sentences = [(sen_text, sen_tag) for sen_text, sen_tag in zip(json.loads(text), json.loads(tags))]
if len(sentences) < self._min_article_len: # Skip very short articles.
continue
print(page_id)
yield sentences
def _queue_next(self):
nqueue = self._random_state.randint(len(self._queue))
while len(self._queue[nqueue]) <= 0:
self._queue[nqueue] = next(self._data_sequence)
return self._queue[nqueue].pop()
@staticmethod
def _has_entities(sen_tags):
for t in sen_tags:
if t != 'O':
return True
return False
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, batch_size, local_rank):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, batch_size, local_rank):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
def get_evaluation_file(self):
raise NotImplementedError()
class WikipediaNerProcessor(DataProcessor):
def __init__(self, train_sets, dev_sets, test_sets, gt_file, max_seq_length, tokenizer,
data_epochs, epoch_size, **kwargs):
del kwargs
self._max_seq_length = max_seq_length
self._tokenizer = tokenizer
self._train_set_file = train_sets
self._dev_set_file = dev_sets
self._test_set_file = test_sets
self._gt_file = gt_file
self._data_epochs = data_epochs
self._epoch_size = epoch_size
def get_train_examples(self, batch_size, local_rank):
"""See base class."""
return self._make_data_loader(self._train_set_file, batch_size, local_rank)
def get_dev_examples(self, batch_size, local_rank):
"""See base class."""
return self._make_data_loader(self._dev_set_file, batch_size, local_rank)
def get_labels(self):
"""See base class."""
labels = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG", "X", "[CLS]", "[SEP]"]
return {label: i for i, label in enumerate(labels)}
def get_evaluation_file(self):
dev_set_name = os.path.splitext(os.path.basename(self._dev_set_file))[0]
return "eval_results-{}.pkl".format(dev_set_name)
def _make_data_loader(self, set_file, batch_size, local_rank):
del local_rank
data = WikipediaDataset(set_file=set_file, gt_file=self._gt_file,
data_epochs=self._data_epochs, epoch_size=self._epoch_size,
label_map=self.get_labels(), tokenizer=self._tokenizer,
max_seq_length=self._max_seq_length)
sampler = SequentialSampler(data)
return DataLoader(data, sampler=sampler, batch_size=batch_size)
class NerProcessor(DataProcessor):
def __init__(self, train_sets, dev_sets, test_sets, max_seq_length, tokenizer,
label_map=None, gt=None, gt_file=None, **kwargs):
del kwargs
self._max_seg_length = max_seq_length
self._tokenizer = tokenizer
self._train_sets = set(train_sets.split('|')) if train_sets is not None else set()
self._dev_sets = set(dev_sets.split('|')) if dev_sets is not None else set()
self._test_sets = set(test_sets.split('|')) if test_sets is not None else set()
self._gt = gt
if self._gt is None:
self._gt = pd.read_pickle(gt_file)
self._label_map = label_map
print('TRAIN SETS: ', train_sets)
print('DEV SETS: ', dev_sets)
print('TEST SETS: ', test_sets)
def get_train_examples(self, batch_size, local_rank):
"""See base class."""
return self.make_data_loader(
self.create_examples(self._read_lines(self._train_sets), "train"), batch_size, local_rank,
self.get_labels(), self._max_seg_length, self._tokenizer)
def get_dev_examples(self, batch_size, local_rank):
"""See base class."""
return self.make_data_loader(
self.create_examples(self._read_lines(self._dev_sets), "dev"), batch_size, local_rank,
self.get_labels(), self._max_seg_length, self._tokenizer)
def get_labels(self):
"""See base class."""
if self._label_map is not None:
return self._label_map
gt = self._gt
gt = gt.loc[gt.dataset.isin(self._train_sets.union(self._dev_sets).union(self._test_sets))]
labels = sorted(gt.tag.unique().tolist()) + ["X", "[CLS]", "[SEP]"]
self._label_map = {label: i for i, label in enumerate(labels, 1)}
self._label_map['UNK'] = 0
return self._label_map
def get_evaluation_file(self):
return "eval_results-{}.pkl".format("-".join(sorted(self._dev_sets)))
@staticmethod
def create_examples(lines, set_type):
for i, (sentence, label) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = sentence
text_b = None
label = label
yield InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
@staticmethod
def make_data_loader(examples, batch_size, local_rank, label_map, max_seq_length, tokenizer, features=None,
sequential=False):
if features is None:
features = [convert_examples_to_features(ex, label_map, max_seq_length, tokenizer)
for ex in examples]
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if local_rank == -1:
if sequential:
train_sampler = SequentialSampler(data)
else:
train_sampler = RandomSampler(data)
else:
if sequential:
train_sampler = SequentialSampler(data)
else:
train_sampler = DistributedSampler(data)
return DataLoader(data, sampler=train_sampler, batch_size=batch_size)
def _read_lines(self, sets):
gt = self._gt
gt = gt.loc[gt.dataset.isin(sets)]
data = list()
for i, sent in gt.groupby('nsentence'):
sent = sent.sort_values('nword', ascending=True)
data.append((sent.word.tolist(), sent.tag.tolist()))
return data
def convert_examples_to_features(example, label_map, max_seq_length, tokenizer):
"""
:param example: instance of InputExample
:param label_map:
:param max_seq_length:
:param tokenizer:
:return:
"""
words = example.text_a
word_labels = example.label
tokens = []
labels = []
for i, word in enumerate(words):
token = tokenizer.tokenize(word)
tokens.extend(token)
label_1 = word_labels[i] if i < len(word_labels) else 'O'
for m in range(len(token)):
if m == 0:
labels.append(label_1)
else:
labels.append("X")
if len(tokens) >= max_seq_length - 1:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
n_tokens = []
segment_ids = []
label_ids = []
n_tokens.append("[CLS]")
segment_ids.append(0)
label_ids.append(label_map["[CLS]"])
for i, token in enumerate(tokens):
n_tokens.append(token)
segment_ids.append(0)
label_ids.append(label_map[labels[i]])
n_tokens.append("[SEP]")
segment_ids.append(0)
label_ids.append(label_map["[SEP]"])
input_ids = tokenizer.convert_tokens_to_ids(n_tokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(label_ids) == max_seq_length
# if ex_index < 5:
# logger.info("*** Example ***")
# logger.info("guid: %s" % example.guid)
# logger.info("tokens: %s" % " ".join(
# [str(x) for x in tokens]))
# logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
# logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
# logger.info(
# "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
# logger.info("label: %s (id = %d)" % (example.label, label_ids))
return InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_id=label_ids,
tokens=n_tokens)

View file

@ -0,0 +1,70 @@
import pandas as pd
import re
import click
import os
def read_gt(files, datasets):
sentence_number = 100000
sentence = ''
gt_data = list()
for filename, dataset in zip(files, datasets):
gt_lines = [l.strip() for l in open(filename) if not l.startswith('<--')]
word_number = 0
for l in gt_lines:
try:
word, tag = l.split(' ')
except ValueError:
word = l.replace(' ', '_')
tag = 'O'
tag = tag.upper()
tag = tag.replace('_', '-')
tag = tag.replace('.', '-')
if tag not in {'B-LOC', 'B-PER', 'I-PER', 'I-ORG', 'B-ORG', 'I-LOC'}:
tag = 'O'
gt_data.append((sentence_number, word_number, word, tag, dataset))
if re.match(r'.*[.|?|!]$', word) \
and not re.match(r'[0-9]+[.]$', word) \
and not re.match(r'.*[0-9]+\s*$', sentence)\
and not re.match(r'.*\s+[\S]{1,2}$', sentence):
sentence_number += 1
sentence = ''
word_number = 0
else:
word_number += 1
sentence += ' ' + word
return pd.DataFrame(gt_data, columns=['nsentence', 'nword', 'word', 'tag', 'dataset'])
@click.command()
@click.argument('path-to-ner-corpora', type=click.Path(exists=True), required=True, nargs=1)
@click.argument('ner-ground-truth-file', type=click.Path(), required=True, nargs=1)
def main(path_to_ner_corpora, ner_ground_truth_file):
"""
Read europeana historic ner ground truth .bio files from directory <path-to-ner-corpora> and
write the outcome of the data parsing to some pandas DataFrame
that is stored as pickle in file <ner-ground-truth-file>.
"""
os.makedirs(os.path.dirname(ner_ground_truth_file), exist_ok=True)
gt_all = read_gt(['{}/enp_DE.sbb.bio/enp_DE.sbb.bio'.format(path_to_ner_corpora),
'{}/enp_DE.onb.bio/enp_DE.onb.bio'.format(path_to_ner_corpora),
'{}/enp_DE.lft.bio/enp_DE.lft.bio'.format(path_to_ner_corpora)], ['SBB', 'ONB', 'LFT'])
gt_all.to_pickle(ner_ground_truth_file)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,68 @@
import pandas as pd
import click
import os
def read_gt(files, datasets):
sentence_number = 200000
gt_data = list()
for filename, dataset in zip(files, datasets):
gt_lines = [l.strip() for l in open(filename)]
word_number = 0
for li in gt_lines:
if li == '':
if word_number > 0:
sentence_number += 1
word_number = 0
continue
if li.startswith('#'):
continue
_, word, tag, _ = li.split()
tag = tag.upper()
tag = tag.replace('_', '-')
tag = tag.replace('.', '-')
if len(tag) > 5:
tag = tag[0:5]
if tag not in {'B-LOC', 'B-PER', 'I-PER', 'I-ORG', 'B-ORG', 'I-LOC'}:
tag = 'O'
gt_data.append((sentence_number, word_number, word, tag, dataset))
word_number += 1
return pd.DataFrame(gt_data, columns=['nsentence', 'nword', 'word', 'tag', 'dataset'])
@click.command()
@click.argument('path-to-germ-eval', type=click.Path(exists=True), required=True, nargs=1)
@click.argument('germ-eval-ground-truth-file', type=click.Path(), required=True, nargs=1)
def main(path_to_germ_eval, germ_eval_ground_truth_file):
"""
Read germ eval .tsv files from directory <path-to-germ-eval> and
write the outcome of the data parsing to some pandas DataFrame
that is stored as pickle in file <germ-eval-ground-truth-file>.
"""
os.makedirs(os.path.dirname(germ_eval_ground_truth_file), exist_ok=True)
gt_all = read_gt(['{}/NER-de-dev.tsv'.format(path_to_germ_eval),
'{}/NER-de-test.tsv'.format(path_to_germ_eval),
'{}/NER-de-train.tsv'.format(path_to_germ_eval)],
['GERM-EVAL-DEV', 'GERM-EVAL-TEST', 'GERM-EVAL-TRAIN'])
gt_all.to_pickle(germ_eval_ground_truth_file)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,29 @@
import pandas as pd
import click
import os
@click.command()
@click.argument('files', nargs=-1, type=click.Path())
def main(files):
"""
Join multiple pandas DataFrame pickles of NER ground-truth into one big file.
"""
assert(len(files) > 1)
gt = list()
for filename in files[:-1]:
gt.append(pd.read_pickle(filename))
gt = pd.concat(gt, axis=0)
os.makedirs(os.path.dirname(files[-1]), exist_ok=True)
gt.to_pickle(files[-1])
if __name__ == '__main__':
main()

View file

@ -0,0 +1,68 @@
import pandas as pd
import click
import os
def read_gt(files, datasets):
sentence_number = 1000000
gt_data = list()
for filename, dataset in zip(files, datasets):
for li in open(filename, encoding='iso-8859-1'):
li = li.strip()
parts = li.split(' ')
prev_tag = 'O'
for word_number, pa in enumerate(parts):
if len(pa) == 0:
continue
word, pos, tag = pa.split('|')
tag = tag.upper()
tag = tag.replace('_', '-')
tag = tag.replace('.', '-')
if len(tag) > 5:
tag = tag[0:5]
if tag not in {'B-LOC', 'B-PER', 'I-PER', 'I-ORG', 'B-ORG', 'I-LOC'}:
tag = 'O'
if tag.startswith('I') and prev_tag == 'O':
tag = 'B' + tag[1:]
prev_tag = tag
gt_data.append((sentence_number, word_number, word, tag, dataset))
sentence_number += 1
return pd.DataFrame(gt_data, columns=['nsentence', 'nword', 'word', 'tag', 'dataset'])
@click.command()
@click.argument('path-to-wikiner', type=click.Path(exists=True), required=True, nargs=1)
@click.argument('wikiner-ground-truth-file', type=click.Path(), required=True, nargs=1)
def main(path_to_wikiner, wikiner_ground_truth_file):
"""
Read wikiner files from directory <path-to-wikiner> and
write the outcome of the data parsing to some pandas DataFrame
that is stored as pickle in file <wikiner-ground-truth-file>.
"""
os.makedirs(os.path.dirname(wikiner_ground_truth_file), exist_ok=True)
gt_all = read_gt(['{}/aij-wikiner-de-wp2'.format(path_to_wikiner),
'{}/aij-wikiner-de-wp3'.format(path_to_wikiner)],
['WIKINER-WP2', 'WIKINER-WP3'])
gt_all.to_pickle(wikiner_ground_truth_file)
if __name__ == '__main__':
main()

View file

@ -0,0 +1 @@
__import__('pkg_resources').declare_namespace(__name__)

View file

@ -0,0 +1,693 @@
from __future__ import absolute_import, division, print_function
# from inspect import currentframe
import argparse
import logging
import os
import random
import json
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import (CONFIG_NAME, # WEIGHTS_NAME,
BertConfig,
BertForTokenClassification)
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
from pytorch_pretrained_bert.tokenization import BertTokenizer
from conlleval import evaluate as conll_eval
from tqdm import tqdm, trange
from qurator.sbb_ner.ground_truth.data_processor import NerProcessor, WikipediaNerProcessor
from sklearn.model_selection import GroupKFold
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
def model_train(bert_model, max_seq_length, do_lower_case,
num_train_epochs, train_batch_size, gradient_accumulation_steps,
learning_rate, weight_decay, loss_scale, warmup_proportion,
processor, device, n_gpu, fp16, cache_dir, local_rank,
dry_run, no_cuda, output_dir=None):
label_map = processor.get_labels()
if gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
gradient_accumulation_steps))
train_batch_size = train_batch_size // gradient_accumulation_steps
train_dataloader = processor.get_train_examples(train_batch_size, local_rank)
# Batch sampler divides by batch_size!
num_train_optimization_steps = int(len(train_dataloader)*num_train_epochs/gradient_accumulation_steps)
if local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model
cache_dir = cache_dir if cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
'distributed_{}'.format(local_rank))
model = BertForTokenClassification.from_pretrained(bert_model, cache_dir=cache_dir, num_labels=len(label_map))
if fp16:
model.half()
model.to(device)
if local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
model = DDP(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay': weight_decay},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
if fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters,
lr=learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=loss_scale)
warmup_linear = WarmupLinearSchedule(warmup=warmup_proportion, t_total=num_train_optimization_steps)
else:
optimizer = BertAdam(optimizer_grouped_parameters, lr=learning_rate, warmup=warmup_proportion,
t_total=num_train_optimization_steps)
warmup_linear = None
global_step = 0
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataloader))
logger.info(" Batch size = %d", train_batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps)
logger.info(" Num epochs = %d", num_train_epochs)
model_config = {"bert_model": bert_model, "do_lower": do_lower_case,
"max_seq_length": max_seq_length, "label_map": label_map}
def save_model(lh):
if output_dir is None:
return
output_model_file = os.path.join(output_dir, "pytorch_model_ep{}.bin".format(ep))
# Save a trained model and the associated configuration
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
torch.save(model_to_save.state_dict(), output_model_file)
output_config_file = os.path.join(output_dir, CONFIG_NAME)
with open(output_config_file, 'w') as f:
f.write(model_to_save.config.to_json_string())
json.dump(model_config, open(os.path.join(output_dir, "model_config.json"), "w"))
lh = pd.DataFrame(lh, columns=['global_step', 'loss'])
loss_history_file = os.path.join(output_dir, "loss_ep{}.pkl".format(ep))
lh.to_pickle(loss_history_file)
def load_model(epoch):
if output_dir is None:
return False
output_model_file = os.path.join(output_dir, "pytorch_model_ep{}.bin".format(epoch))
if not os.path.exists(output_model_file):
return False
logger.info("Loading epoch {} from disk...".format(epoch))
model.load_state_dict(torch.load(output_model_file,
map_location=lambda storage, loc: storage if no_cuda else None))
return True
model.train()
for ep in trange(1, int(num_train_epochs) + 1, desc="Epoch"):
if dry_run and ep > 1:
logger.info("Dry run. Stop.")
break
if load_model(ep):
global_step += len(train_dataloader) // gradient_accumulation_steps
continue
loss_history = list()
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
with tqdm(total=len(train_dataloader), desc=f"Epoch {ep}") as pbar:
for step, batch in enumerate(train_dataloader):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
loss = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if gradient_accumulation_steps > 1:
loss = loss / gradient_accumulation_steps
if fp16:
optimizer.backward(loss)
else:
loss.backward()
loss_history.append((global_step, loss.item()))
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
pbar.update(1)
mean_loss = tr_loss * gradient_accumulation_steps / nb_tr_steps
pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
if dry_run and len(loss_history) > 2:
logger.info("Dry run. Stop.")
break
if (step + 1) % gradient_accumulation_steps == 0:
if fp16:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = learning_rate * warmup_linear.get_lr(global_step, warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
global_step += 1
save_model(loss_history)
return model, model_config
def model_eval(batch_size, label_map, processor, device, num_train_epochs=1, output_dir=None, model=None,
local_rank=-1, no_cuda=False, dry_run=False):
output_eval_file = None
if output_dir is not None:
output_eval_file = os.path.join(output_dir, processor.get_evaluation_file())
logger.info('Write evaluation results to: {}'.format(output_eval_file))
dataloader = processor.get_dev_examples(batch_size, local_rank)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(dataloader))
logger.info(" Batch size = %d", batch_size)
results = list()
output_config_file = None
if output_dir is not None:
output_config_file = os.path.join(output_dir, CONFIG_NAME)
for ep in trange(1, int(num_train_epochs) + 1, desc="Epoch"):
if dry_run and ep > 1:
logger.info("Dry run. Stop.")
break
if output_config_file is not None:
# Load a trained model and config that you have fine-tuned
output_model_file = os.path.join(output_dir, "pytorch_model_ep{}.bin".format(ep))
if not os.path.exists(output_model_file):
logger.info("Stopping at epoch {} since model file is missing.".format(ep))
break
config = BertConfig(output_config_file)
model = BertForTokenClassification(config, num_labels=len(label_map))
model.load_state_dict(torch.load(output_model_file,
map_location=lambda storage, loc: storage if no_cuda else None))
model.to(device)
if model is None:
raise ValueError('Model required for evaluation.')
model.eval()
y_pred, y_true = model_predict_compare(dataloader, device, label_map, model, dry_run)
lines = ['empty ' + 'XXX ' + v + ' ' + p for yt, yp in zip(y_true, y_pred) for v, p in zip(yt, yp)]
res = conll_eval(lines)
# print(res)
evals = \
pd.concat([pd.DataFrame.from_dict(res['overall']['evals'], orient='index', columns=['ALL']),
pd.DataFrame.from_dict(res['slots']['LOC']['evals'], orient='index', columns=['LOC']),
pd.DataFrame.from_dict(res['slots']['PER']['evals'], orient='index', columns=['PER']),
pd.DataFrame.from_dict(res['slots']['ORG']['evals'], orient='index', columns=['ORG']),
], axis=1).T
stats = \
pd.concat(
[pd.DataFrame.from_dict(res['overall']['stats'], orient='index', columns=['ALL']),
pd.DataFrame.from_dict(res['slots']['LOC']['stats'], orient='index', columns=['LOC']),
pd.DataFrame.from_dict(res['slots']['PER']['stats'], orient='index', columns=['PER']),
pd.DataFrame.from_dict(res['slots']['ORG']['stats'], orient='index', columns=['ORG'])],
axis=1, sort=True).T
evals['epoch'] = ep
stats['epoch'] = ep
results.append(pd.concat([evals.reset_index().set_index(['index', 'epoch']),
stats.reset_index().set_index(['index', 'epoch'])], axis=1))
if output_eval_file is not None:
pd.concat(results).to_pickle(output_eval_file)
results = pd.concat(results)
print(results)
return results
def model_predict_compare(dataloader, device, label_map, model, dry_run=False):
y_true = []
y_pred = []
covered = set()
for input_ids, input_mask, segment_ids, label_ids in tqdm(dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
with torch.no_grad():
logits = model(input_ids, segment_ids, input_mask)
logits = torch.argmax(F.log_softmax(logits, dim=2), dim=2)
logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()
input_mask = input_mask.to('cpu').numpy()
for i, mask in enumerate(input_mask):
temp_1 = []
temp_2 = []
for j, m in enumerate(mask):
if j == 0:
continue
if m:
if label_map[label_ids[i][j]] != "X":
temp_1.append(label_map[label_ids[i][j]])
temp_2.append(label_map[logits[i][j]])
else:
temp_1.pop()
temp_2.pop()
y_true.append(temp_1)
y_pred.append(temp_2)
covered = covered.union(set(temp_1))
break
if dry_run:
if 'I-LOC' not in covered:
continue
if 'I-ORG' not in covered:
continue
if 'I-PER' not in covered:
continue
break
return y_pred, y_true
def model_predict(dataloader, device, label_map, model):
y_pred = []
for input_ids, input_mask, segment_ids, label_ids in dataloader:
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
with torch.no_grad():
logits = model(input_ids, segment_ids, input_mask)
logits = torch.argmax(F.log_softmax(logits, dim=2), dim=2)
logits = logits.detach().cpu().numpy()
input_mask = input_mask.to('cpu').numpy()
for i, mask in enumerate(input_mask):
temp_2 = []
for j, m in enumerate(mask):
if j == 0: # skip first token since its [CLS]
continue
if m:
temp_2.append(label_map[logits[i][j]])
else:
temp_2.pop() # skip last token since its [SEP]
y_pred.append(temp_2)
break
return y_pred
def get_device(local_rank=-1, no_cuda=False):
if local_rank == -1 or no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
return device, n_gpu
def main():
parser = get_arg_parser()
args = parser.parse_args()
do_eval = len(args.dev_sets) > 0 and not args.do_cross_validation
do_train = len(args.train_sets) > 0 and not args.do_cross_validation
device, n_gpu = get_device(args.local_rank, args.no_cuda)
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if not do_train and not do_eval and not args.do_cross_validation:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
task_name = args.task_name.lower()
processors = {"ner": NerProcessor, "wikipedia-ner": WikipediaNerProcessor}
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
if args.do_cross_validation:
cross_val_result_file = "cross_validation_results.pkl"
cross_val_result_file = os.path.join(args.output_dir, cross_val_result_file)
sets = set(args.train_sets.split('|')) if args.train_sets is not None else set()
gt = pd.read_pickle(args.gt_file)
gt = gt.loc[gt.dataset.isin(sets)]
k_fold = GroupKFold(n_splits=args.n_splits)
eval_results = list()
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
for ep in range(1, int(args.num_train_epochs) + 1):
for sp, (train, test) in enumerate(k_fold.split(X=gt, groups=gt.nsentence)):
tr = gt.iloc[train].copy()
te = gt.iloc[test].copy()
tr['dataset'] = 'TRAIN'
te['dataset'] = 'TEST'
gt_tmp = pd.concat([tr, te])
processor = \
processors[task_name](train_sets='TRAIN', dev_sets='TEST', test_sets='TEST',
gt=gt_tmp, max_seq_length=args.max_seq_length,
tokenizer=tokenizer, data_epochs=args.num_data_epochs,
epoch_size=args.epoch_size)
model, model_config = \
model_train(bert_model=args.bert_model, max_seq_length=args.max_seq_length,
do_lower_case=args.do_lower_case, num_train_epochs=ep,
train_batch_size=args.train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate, weight_decay=args.weight_decay,
loss_scale=args.loss_scale, warmup_proportion=args.warmup_proportion,
processor=processor, device=device, n_gpu=n_gpu, fp16=args.fp16,
cache_dir=args.cache_dir, local_rank=args.local_rank, dry_run=args.dry_run,
no_cuda=args.no_cuda)
label_map = {v: k for k, v in model_config['label_map'].items()}
eval_result =\
model_eval(model=model, label_map=label_map, processor=processor, device=device,
batch_size=args.eval_batch_size, local_rank=args.local_rank,
no_cuda=args.no_cuda, dry_run=args.dry_run).reset_index()
eval_result['split'] = sp
eval_result['epoch'] = ep
eval_results.append(eval_result)
del model # release CUDA memory
pd.concat(eval_results).to_pickle(cross_val_result_file)
if do_train:
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
processor = \
processors[task_name](train_sets=args.train_sets, dev_sets=args.dev_sets, test_sets=args.test_sets,
gt_file=args.gt_file, max_seq_length=args.max_seq_length,
tokenizer=tokenizer, data_epochs=args.num_data_epochs,
epoch_size=args.epoch_size)
model_train(bert_model=args.bert_model, output_dir=args.output_dir, max_seq_length=args.max_seq_length,
do_lower_case=args.do_lower_case, num_train_epochs=args.num_train_epochs,
train_batch_size=args.train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate, weight_decay=args.weight_decay, loss_scale=args.loss_scale,
warmup_proportion=args.warmup_proportion, processor=processor, device=device, n_gpu=n_gpu,
fp16=args.fp16, cache_dir=args.cache_dir, local_rank=args.local_rank, dry_run=args.dry_run,
no_cuda=args.no_cuda)
if do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
model_config = json.load(open(os.path.join(args.output_dir, "model_config.json"), "r"))
label_to_id = model_config['label_map']
label_map = {v: k for k, v in model_config['label_map'].items()}
tokenizer = BertTokenizer.from_pretrained(model_config['bert_model'],
do_lower_case=model_config['do_lower'])
processor = \
processors[task_name](train_sets=None, dev_sets=args.dev_sets, test_sets=args.test_sets,
gt_file=args.gt_file, max_seq_length=model_config['max_seq_length'],
tokenizer=tokenizer, data_epochs=args.num_data_epochs,
epoch_size=args.epoch_size, label_map=label_to_id)
model_eval(label_map=label_map, processor=processor, device=device, num_train_epochs=args.num_train_epochs,
output_dir=args.output_dir, batch_size=args.eval_batch_size, local_rank=args.local_rank,
no_cuda=args.no_cuda, dry_run=args.dry_run)
def get_arg_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--gt_file",
default=None,
type=str,
required=True,
help="The pickle file that contains all NER ground truth as pandas DataFrame."
" Required columns: ['nsentence', 'nword', 'word', 'tag', 'dataset]."
" The selection of training, test and dev set is performed on the 'dataset' column.")
parser.add_argument("--train_sets",
default='',
type=str,
required=False,
help="Specifiy one or more tags from the dataset column in order to mark samples"
" that belong to the training set. Example: 'GERM-EVAL-TRAIN|DE-CONLL-TRAIN'. ")
parser.add_argument("--dev_sets",
default='',
type=str,
required=False,
help="Specifiy one or more tags from the dataset column in order to mark samples"
" that belong to the dev set. Example: 'GERM-EVAL-DEV|DE-CONLL-TESTA'. ")
parser.add_argument("--test_sets",
default='',
type=str,
required=False,
help="Specifiy one or more tags from the dataset column in order to mark samples"
" that belong to the test set. Example: 'GERM-EVAL-TEST|DE-CONLL-TESTB'. ")
parser.add_argument("--bert_model", default=None, type=str, required=False,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
"bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train.")
parser.add_argument("--output_dir",
default=None,
type=str,
required=False,
help="The output directory where the model predictions and checkpoints will be written.")
# Other parameters
parser.add_argument("--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument("--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument("--do_lower_case",
action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--train_batch_size",
default=32,
type=int,
help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
default=8,
type=int,
help="Total batch size for eval.")
parser.add_argument("--learning_rate",
default=3e-5,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay",
default=0.01,
type=float,
help="Weight decay for Adam.")
parser.add_argument("--num_train_epochs",
default=3.0,
type=float,
help="Total number of training epochs to perform/evaluate.")
parser.add_argument("--num_data_epochs",
default=1.0,
type=float,
help="Re-cycle data after num_data_epochs.")
parser.add_argument("--epoch_size",
default=10000,
type=float,
help="Size of one epoch.")
parser.add_argument("--do_cross_validation",
action='store_true',
help="Do cross-validation.")
parser.add_argument("--n_splits",
default=5,
type=int,
help="Number of folds in cross_validation.")
parser.add_argument("--warmup_proportion",
default=0.1,
type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--no_cuda",
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--dry_run",
action='store_true',
help="Test mode.")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
return parser
if __name__ == "__main__":
main()

View file

@ -0,0 +1 @@
__import__('pkg_resources').declare_namespace(__name__)

View file

@ -0,0 +1,353 @@
import os
from flask import Flask, send_from_directory, redirect, jsonify, request
import pandas as pd
from sqlite3 import Error
import sqlite3
import html
import json
import torch
from somajo import Tokenizer, SentenceSplitter
from qurator.sbb_ner.models.bert import get_device, model_predict
from qurator.sbb_ner.ground_truth.data_processor import NerProcessor, convert_examples_to_features
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import (CONFIG_NAME,
BertConfig,
BertForTokenClassification)
app = Flask(__name__)
app.config.from_json('config.json')
class Digisam:
_conn = None
def __init__(self, data_path):
self._data_path = data_path
@staticmethod
def create_connection(db_file):
try:
conn = sqlite3.connect(db_file, check_same_thread=False)
conn.execute('pragma journal_mode=wal')
return conn
except Error as e:
print(e)
return None
def get(self, ppn):
if Digisam._conn is None:
Digisam._conn = self.create_connection(self._data_path)
df = pd.read_sql_query("select file_name, text from text where ppn=?;", Digisam._conn, params=(ppn,)). \
sort_values('file_name')
return df
class NERPredictor:
def __init__(self, model_dir, batch_size, epoch, max_seq_length=128, local_rank=-1, no_cuda=False):
self._batch_size = batch_size
self._local_rank = local_rank
self._max_seq_length = max_seq_length
self._device, self._n_gpu = get_device(no_cuda=no_cuda)
self._model_config = json.load(open(os.path.join(model_dir, "model_config.json"), "r"))
self._label_to_id = self._model_config['label_map']
self._label_map = {v: k for k, v in self._model_config['label_map'].items()}
self._bert_tokenizer = \
BertTokenizer.from_pretrained(model_dir,
do_lower_case=self._model_config['do_lower'])
output_config_file = os.path.join(model_dir, CONFIG_NAME)
output_model_file = os.path.join(model_dir, "pytorch_model_ep{}.bin".format(epoch))
config = BertConfig(output_config_file)
self._model = BertForTokenClassification(config, num_labels=len(self._label_map))
self._model.load_state_dict(torch.load(output_model_file,
map_location=lambda storage, loc: storage if no_cuda else None))
self._model.to(self._device)
self._model.eval()
return
def classify_text(self, sentences):
examples = NerProcessor.create_examples(sentences, 'test')
features = [convert_examples_to_features(ex, self._label_to_id, self._max_seq_length, self._bert_tokenizer)
for ex in examples]
data_loader = NerProcessor.make_data_loader(None, self._batch_size, self._local_rank, self._label_to_id,
self._max_seq_length, self._bert_tokenizer, features=features,
sequential=True)
prediction_tmp = model_predict(data_loader, self._device, self._label_map, self._model)
prediction = []
for fe, pr in zip(features, prediction_tmp):
prediction.append((fe.tokens[1:-1], pr))
return prediction
class NERTokenizer:
def __init__(self):
self._word_tokenizer = Tokenizer(split_camel_case=True, token_classes=False, extra_info=False)
self._sentence_splitter = SentenceSplitter()
def parse_text(self, text):
tokens = self._word_tokenizer.tokenize_paragraph(text)
sentences_tokenized = self._sentence_splitter.split(tokens)
sentences = []
for sen in sentences_tokenized:
sentences.append((sen, []))
return sentences
class PredictorStore:
def __init__(self):
self._predictor = None
self._model_id = None
def get(self, model_id):
model = next((m for m in app.config['MODELS'] if m['id'] == int(model_id)))
if self._model_id != model_id:
self._predictor = NERPredictor(model_dir=model['model_dir'],
epoch=app.config['EPOCH'],
batch_size=app.config['BATCH_SIZE'],
no_cuda=False if not os.environ.get('USE_CUDA') else
os.environ.get('USE_CUDA').lower() == 'false')
self._model_id = model_id
return self._predictor
digisam = Digisam(app.config['DATA_PATH'])
predictor_store = PredictorStore()
tokenizer = NERTokenizer()
@app.route('/')
def entry():
return redirect("/index.html", code=302)
@app.route('/models')
def get_models():
return jsonify(app.config['MODELS'])
@app.route('/ppnexamples')
def get_ppnexamples():
return jsonify(app.config['PPN_EXAMPLES'])
@app.route('/digisam-fulltext/<ppn>')
def fulltext(ppn):
df = digisam.get(ppn)
if len(df) == 0:
return 'bad request!', 400
text = ''
for row_index, row_data in df.iterrows():
if row_data.text is None:
continue
text += html.escape(str(row_data.text)) + '<br><br><br>'
ret = {'text': text, 'ppn': ppn}
return jsonify(ret)
@app.route('/digisam-tokenized/<ppn>')
def tokenized(ppn):
df = digisam.get(ppn)
if len(df) == 0:
return 'bad request!', 400
text = ''
for row_index, row_data in df.iterrows():
if row_data.text is None:
continue
sentences = tokenizer.parse_text(row_data.text)
for sen, _ in sentences:
text += html.escape(str(sen)) + '<br>'
text += '<br><br><br>'
ret = {'text': text, 'ppn': ppn}
return jsonify(ret)
@app.route('/ner-bert-tokens/<model_id>/<ppn>')
def ner_bert_tokens(model_id, ppn):
df = digisam.get(ppn)
if len(df) == 0:
return 'bad request!', 400
text = ''
for row_index, row_data in df.iterrows():
if row_data.text is None:
continue
sentences = tokenizer.parse_text(row_data.text)
prediction = predictor_store.get(model_id).classify_text(sentences)
for tokens, word_predictions in prediction:
for token, word_pred in zip(tokens, word_predictions):
text += html.escape("{}({})".format(token, word_pred))
text += '<br>'
text += '<br><br><br>'
ret = {'text': text, 'ppn': ppn}
return jsonify(ret)
@app.route('/digisam-ner/<model_id>/<ppn>')
def digisam_ner(model_id, ppn):
df = digisam.get(ppn)
if len(df) == 0:
return 'bad request!', 400
text = ''
for row_index, row_data in df.iterrows():
if row_data.text is None:
continue
sentences = tokenizer.parse_text(row_data.text)
prediction = predictor_store.get(model_id).classify_text(sentences)
for tokens, word_predictions in prediction:
last_prediction = 'O'
for token, word_pred in zip(tokens, word_predictions):
if token == '[UNK]':
continue
if not token.startswith('##'):
text += ' '
token = token[2:] if token.startswith('##') else token
if word_pred != 'X':
last_prediction = word_pred
if last_prediction == 'O':
text += html.escape(token)
elif last_prediction.endswith('PER'):
text += '<font color="red">' + html.escape(token) + '</font>'
elif last_prediction.endswith('LOC'):
text += '<font color="green">' + html.escape(token) + '</font>'
elif last_prediction.endswith('ORG'):
text += '<font color="blue">' + html.escape(token) + '</font>'
text += '<br>'
text += '<br><br><br>'
ret = {'text': text, 'ppn': ppn}
return jsonify(ret)
@app.route('/ner/<model_id>', methods=['GET', 'POST'])
def ner(model_id):
raw_text = request.json['text']
sentences = tokenizer.parse_text(raw_text)
prediction = predictor_store.get(model_id).classify_text(sentences)
output = []
word = None
last_prediction = 'O'
for tokens, word_predictions in prediction:
last_prediction = 'O'
for token, word_pred in zip(tokens, word_predictions):
if token == '[UNK]':
continue
if not token.startswith('##'):
if word is not None:
output.append({'word': word, 'prediction': last_prediction})
word = ''
token = token[2:] if token.startswith('##') else token
word += token
if word_pred != 'X':
last_prediction = word_pred
if word is not None and len(word) > 0:
output.append({'word': word, 'prediction': last_prediction})
return jsonify(output)
@app.route('/<path:path>')
def send_js(path):
return send_from_directory('static', path)

View file

@ -0,0 +1,77 @@
{
"DATA_PATH": "data/digisam/fulltext.sqlite3",
"EPOCH": 7,
"BATCH_SIZE": 256,
"MODELS": [
{
"name": "DC-SBB + CONLL + GERMEVAL",
"id": 1,
"model_dir": "data/konvens2019/build-wd_0.03/bert-all-german-de-finetuned",
"default": true
},
{
"name": "DC-SBB + CONLL + GERMEVAL + SBB",
"id": 2,
"model_dir": "data/konvens2019/build-on-all-german-de-finetuned/bert-sbb-de-finetuned",
"default": false
},
{
"name": "DC-SBB + SBB",
"id": 3,
"model_dir": "data/konvens2019/build-wd_0.03/bert-sbb-de-finetuned",
"default": false
},
{
"name": "CONLL + GERMEVAL",
"id": 4,
"model_dir": "data/konvens2019/build-wd_0.03/bert-all-german-baseline",
"default": false
}
],
"PPN_EXAMPLES": [
{
"ppn": "633609536",
"name": "Der achtzehnte Brumaire des Louis Bonaparte"
},
{
"ppn": "778819027",
"name": "Der zerbrochene Krug"
},
{
"ppn": "71807789X",
"name": "Praktischer Kommentar zu den Gebühren-Taxen für Notare und Rechtsanwälte"
},
{
"ppn": "719153085",
"name": "Der Weltkrieg im Rechenunterricht"
},
{
"ppn": "719961289",
"name": "Das Kriegs-Schaubuch des XVIII. A.K."
},
{
"ppn": "720942748",
"name": "Ein Gebot der Stunde"
},
{
"ppn": "819155217",
"name": "Der Zirkel, 1883"
},
{
"ppn": "847022595",
"name": "Mecklenburgisches Logenblatt"
},
{
"ppn": "756689090",
"name": "Das Buch wunderbarer Erfindungen"
},
{
"ppn": "865468370",
"name": "Carl Robert Lessings Bücher- und Handschriftensammlung"
},
{
"ppn": "818985976",
"name": "\nDie älteste Berliner Zeitung\nOCR\n\nDie älteste Berliner Zeitung : Fragmente der Berliner Wochenzeitung von 1626 aus dem Besitz der Preußischen Staatsbibliothek"
}
]
}

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,77 @@
<!doctype html>
<html lang="en">
<head>
<!-- Required meta tags -->
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<!-- Bootstrap CSS -->
<link rel="stylesheet" href="css/bootstrap.min.css"
integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">
<title>NER auf den digitalen Sammlungen</title>
<script src="js/jquery-3.4.1.js"></script>
</head>
<body>
<div class="container-fluid" style="height: 95vh;">
<div class="row" style="margin-top: 5vh">
<div class="col-2">
</div>
<div class="col-10">
<div class="row">
<div class="col-9 text-center">
<h1>NER auf den digitalen Sammlungen</h1>
</div>
<div class="col">
</div>
</div>
<div class="row" style="margin-top: 2vh">
<div class="col-9">
<div class="card">
<div class="card-block">
<form class="mt-3 mb-3" role="form" id="nerform">
<div class="form-group row ml-2">
<label for="task" class="col-sm-2 col-form-label">Task:</label>
<select id="task" class="selectpicker col-md-auto" onchange="task_select()">
<option value="1">OCR-Text aus ALTO Datei</option>
<option value="2">Wort- und Satztokenisierung</option>
<option value="3" selected>Named Entity Recognition</option>
<option value="4">BERT Tokens</option>
</select>
</div>
<div class="form-group row ml-2" id="model_select">
<label for="model" class="col-sm-2 col-form-label">Model:</label>
<select id="model" class="selectpicker col-md-auto">
</select>
</div>
<div class="form-group row ml-2">
<label for="ppn" class="col-sm-2 col-form-label">PPN:</label>
<input id="ppn" list="ppnexamples" class="col-sm-8" type="text"/>
<datalist id="ppnexamples">
</datalist>
<button class="btn btn-primary" type="submit">Go</button>
</div>
</form>
</div>
</div>
</div>
<div class="col">
</div>
</div>
<div class="row mt-5">
<div class="col-9" id="resultregion">
</div>
<div class="col" id="legende">
</div>
</div>
</div>
</div>
</div>
<script src="js/ner.js"></script>
</body>
</html>

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,155 @@
$(document).ready(function(){
$('#nerform').submit(
function(e){
e.preventDefault();
load_ppn();
}
);
$.get( "/models")
.done(
function( data ) {
var tmp="";
$.each(data,
function(index, item){
selected=""
if (item.default) {
selected = "selected"
}
tmp += '<option value="' + item.id + '" ' + selected + ' >' + item.name + '</option>'
});
$('#model').html(tmp);
}
);
$.get( "/ppnexamples")
.done(
function( data ) {
var tmp="";
$.each(data,
function(index, item){
tmp += '<option value="' + item.ppn + '">' + item.name + '</option>'
});
$('#ppnexamples').html(tmp);
}
);
task_select()
});
function task_select() {
var task = $('#task').val();
if (task < 3) {
$('#model_select').hide()
}
else {
$('#model_select').show()
}
$("#resultregion").html("");
$("#legende").html("");
}
function load_ppn() {
var ppn = $('#ppn').val()
var text_region_html =
`<div class="card">
<div class="card-header">
Ergebnis:
</div>
<div class="card-block">
<div id="textregion" style="overflow-y:scroll;height: 65vh;"></div>
</div>
</div>`;
var legende_html =
`<div class="card">
<div class="card-header">
Legende:
<div class="ml-2" >[<font color="red">Person</font>]</div>
<div class="ml-2" >[<font color="green">Ort</font>]</div>
<div class="ml-2" >[<font color="blue">Organisation</font>]</div>
<div class="ml-2" >[keine Named Entity]</div>
</div>
</div>`;
var spinner_html =
`<div class="d-flex justify-content-center">
<div class="spinner-border align-center" role="status">
<span class="sr-only">Loading...</span>
</div>
</div>`;
$("#legende").html("");
var task = $('#task').val();
var model_id = $('#model').val();
console.log("Task: " + task);
if (task == 1) {
$("#resultregion").html(spinner_html);
$.get( "/digisam-fulltext/" + ppn)
.done(function( data ) {
$("#resultregion").html(text_region_html)
$("#textregion").html(data.text)
})
.fail(
function() {
console.log('Failed.');
$("#resultregion").html('Failed.');
});
}
else if (task == 2) {
$("#resultregion").html(spinner_html);
$.get( "/digisam-tokenized/" + ppn,
function( data ) {
$("#resultregion").html(text_region_html)
$("#textregion").html(data.text)
}).fail(
function() {
console.log('Failed.')
$("#resultregion").html('Failed.')
});
}
else if (task == 3) {
$("#resultregion").html(spinner_html);
$.get( "/digisam-ner/" + model_id + "/" + ppn,
function( data ) {
$("#resultregion").html(text_region_html)
$("#textregion").html(data.text)
$("#legende").html(legende_html)
}).fail(
function(a,b,c) {
console.log('Failed.')
$("#resultregion").html('Failed.')
});
}
else if (task == 4) {
$("#resultregion").html(spinner_html);
$.get( "/digisam-ner-bert-tokens/" + model_id + "/" + ppn,
function( data ) {
$("#resultregion").html(text_region_html)
$("#textregion").html(data.text)
}).fail(
function(a,b,c) {
console.log('Failed.')
$("#resultregion").html('Failed.')
});
}
}