from __future__ import absolute_import, division, print_function
import os
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset, Dataset)
from torch.utils.data.distributed import DistributedSampler


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
                sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
                Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id, tokens):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        self.tokens = tokens


class WikipediaDataset(Dataset):
    """Streams sentences from a Wikipedia NER ground-truth file.

    The ground-truth CSV (';'-separated, with page_id, text and tags columns that
    contain JSON-encoded lists of tokenized sentences and per-token tags) is read
    article by article; articles whose page_id is not in the pickled subset are
    skipped. Sentences are buffered in a pool of queues and drawn in randomized
    order, optionally enforcing a fraction of sentences without any entity, and
    are converted on the fly into BERT input features.
    """

    def __init__(self, set_file, gt_file, data_epochs, epoch_size,
                 label_map, tokenizer, max_seq_length,
                 queue_size=1000, no_entity_fraction=0.0, seed=23,
                 min_sen_len=10, min_article_len=20):
        self._set_file = set_file
        self._subset = pd.read_pickle(set_file)
        self._gt_file = gt_file
        self._data_epochs = data_epochs
        self._epoch_size = epoch_size
        self._label_map = label_map
        self._tokenizer = tokenizer
        self._max_seq_length = max_seq_length
        self._queue_size = queue_size
        self._no_entity_fraction = no_entity_fraction
        self._seed = seed
        self._min_sen_len = min_sen_len
        self._min_article_len = min_article_len
        self._queue = None
        self._data_sequence = None
        self._counter = None
        # noinspection PyUnresolvedReferences
        self._random_state = np.random.RandomState(seed=self._seed)
        self._reset()

    def _next_sample_should_have_entities(self):
        if self._no_entity_fraction <= 0.0:
            return True
        return int(self._counter) % int(1.0 / self._no_entity_fraction) != 0

    def __getitem__(self, index):
        del index

        if self._counter > self._data_epochs * self._epoch_size:
            self._reset()

        while True:
            # get next random sentence
            sen_words, sen_tags = self._queue_next()

            if len(sen_words) < self._min_sen_len:  # Skip all sentences that are too short.
                continue

            if self._has_entities(sen_tags):
                if not self._next_sample_should_have_entities():  # Skip sample if the next sample is supposed to
                    # be a no-entity sample.
                    continue
            else:
                if self._next_sample_should_have_entities():  # Skip sample if the next sample is supposed to be an
                    # entity sample.
                    continue
            break

        sample = InputExample(guid="%s-%s" % (self._set_file, self._counter),
                              text_a=sen_words, text_b=None, label=sen_tags)

        features = convert_examples_to_features(sample, self._label_map, self._max_seq_length, self._tokenizer)

        self._counter += 1

        return torch.tensor(features.input_ids, dtype=torch.long), \
               torch.tensor(features.input_mask, dtype=torch.long), \
               torch.tensor(features.segment_ids, dtype=torch.long), \
               torch.tensor(features.label_id, dtype=torch.long)

    def __len__(self):
        return int(self._epoch_size)

    def _reset(self):
        # print('================= WikipediaDataset:_reset ====================== ')
        self._queue = list()
        self._data_sequence = self._sequence()
        self._counter = 0
        # noinspection PyUnresolvedReferences
        # self._random_state = np.random.RandomState(seed=self._seed)

        for _ in range(0, self._queue_size):
            self._queue.append(list())

    def _sequence(self):
        while True:
            for row in pd.read_csv(self._gt_file, chunksize=1, sep=';'):
                page_id = row.page_id.iloc[0]
                text = row.text.iloc[0]
                tags = row.tags.iloc[0]

                if page_id not in self._subset.index:
                    continue

                sentences = [(sen_text, sen_tag) for sen_text, sen_tag in zip(json.loads(text), json.loads(tags))]

                if len(sentences) < self._min_article_len:  # Skip very short articles.
                    continue

                print(page_id)
                yield sentences

    def _queue_next(self):
        nqueue = self._random_state.randint(len(self._queue))

        while len(self._queue[nqueue]) <= 0:
            self._queue[nqueue] = next(self._data_sequence)

        return self._queue[nqueue].pop()

    @staticmethod
    def _has_entities(sen_tags):
        for t in sen_tags:
            if t != 'O':
                return True
        return False
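

# Illustrative sketch (not part of the original module): wiring up WikipediaDataset
# directly. The file paths and sizes are hypothetical placeholders, and `tokenizer`
# stands for any BERT-style tokenizer exposing tokenize() and convert_tokens_to_ids();
# in practice the dataset is constructed via WikipediaNerProcessor below.
#
#     label_map = {label: i for i, label in enumerate(
#         ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG", "X", "[CLS]", "[SEP]"])}
#
#     dataset = WikipediaDataset(set_file='train-subset.pkl', gt_file='wikipedia-gt.csv',
#                                data_epochs=10, epoch_size=1000, label_map=label_map,
#                                tokenizer=tokenizer, max_seq_length=128)
#     loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=32)
#     input_ids, input_mask, segment_ids, label_ids = next(iter(loader))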


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, batch_size, local_rank):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, batch_size, local_rank):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    def get_evaluation_file(self):
        raise NotImplementedError()


class WikipediaNerProcessor(DataProcessor):

    def __init__(self, train_sets, dev_sets, test_sets, gt_file, max_seq_length, tokenizer,
                 data_epochs, epoch_size, **kwargs):
        del kwargs

        self._max_seq_length = max_seq_length
        self._tokenizer = tokenizer
        self._train_set_file = train_sets
        self._dev_set_file = dev_sets
        self._test_set_file = test_sets
        self._gt_file = gt_file
        self._data_epochs = data_epochs
        self._epoch_size = epoch_size

    def get_train_examples(self, batch_size, local_rank):
        """See base class."""
        return self._make_data_loader(self._train_set_file, batch_size, local_rank)

    def get_dev_examples(self, batch_size, local_rank):
        """See base class."""
        return self._make_data_loader(self._dev_set_file, batch_size, local_rank)

    def get_labels(self):
        """See base class."""
        labels = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG", "X", "[CLS]", "[SEP]"]

        return {label: i for i, label in enumerate(labels)}

    def get_evaluation_file(self):
        dev_set_name = os.path.splitext(os.path.basename(self._dev_set_file))[0]

        return "eval_results-{}.pkl".format(dev_set_name)

    def _make_data_loader(self, set_file, batch_size, local_rank):
        del local_rank

        data = WikipediaDataset(set_file=set_file, gt_file=self._gt_file,
                                data_epochs=self._data_epochs, epoch_size=self._epoch_size,
                                label_map=self.get_labels(), tokenizer=self._tokenizer,
                                max_seq_length=self._max_seq_length)

        sampler = SequentialSampler(data)

        return DataLoader(data, sampler=sampler, batch_size=batch_size)


class NerProcessor(DataProcessor):

    def __init__(self, train_sets, dev_sets, test_sets, max_seq_length, tokenizer,
                 label_map=None, gt=None, gt_file=None, **kwargs):
        del kwargs

        self._max_seq_length = max_seq_length
        self._tokenizer = tokenizer
        self._train_sets = set(train_sets.split('|')) if train_sets is not None else set()
        self._dev_sets = set(dev_sets.split('|')) if dev_sets is not None else set()
        self._test_sets = set(test_sets.split('|')) if test_sets is not None else set()
        self._gt = gt
        if self._gt is None:
            self._gt = pd.read_pickle(gt_file)
        self._label_map = label_map

        print('TRAIN SETS: ', train_sets)
        print('DEV SETS: ', dev_sets)
        print('TEST SETS: ', test_sets)

    def get_train_examples(self, batch_size, local_rank):
        """See base class."""
        return self.make_data_loader(
            self.create_examples(self._read_lines(self._train_sets), "train"), batch_size, local_rank,
            self.get_labels(), self._max_seq_length, self._tokenizer)

    def get_dev_examples(self, batch_size, local_rank):
        """See base class."""
        return self.make_data_loader(
            self.create_examples(self._read_lines(self._dev_sets), "dev"), batch_size, local_rank,
            self.get_labels(), self._max_seq_length, self._tokenizer)

    def get_labels(self):
        """See base class."""
        if self._label_map is not None:
            return self._label_map

        gt = self._gt
        gt = gt.loc[gt.dataset.isin(self._train_sets.union(self._dev_sets).union(self._test_sets))]

        labels = sorted(gt.tag.unique().tolist()) + ["X", "[CLS]", "[SEP]"]

        self._label_map = {label: i for i, label in enumerate(labels, 1)}
        self._label_map['UNK'] = 0

        return self._label_map

    def get_evaluation_file(self):
        return "eval_results-{}.pkl".format("-".join(sorted(self._dev_sets)))

    @staticmethod
    def create_examples(lines, set_type):
        for i, (sentence, label) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = sentence
            text_b = None

            yield InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)

    @staticmethod
    def make_data_loader(examples, batch_size, local_rank, label_map, max_seq_length, tokenizer, features=None,
                         sequential=False):
        if features is None:
            features = [convert_examples_to_features(ex, label_map, max_seq_length, tokenizer)
                        for ex in examples]

        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

        data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

        if local_rank == -1:
            if sequential:
                train_sampler = SequentialSampler(data)
            else:
                train_sampler = RandomSampler(data)
        else:
            if sequential:
                train_sampler = SequentialSampler(data)
            else:
                train_sampler = DistributedSampler(data)

        return DataLoader(data, sampler=train_sampler, batch_size=batch_size)

    def _read_lines(self, sets):
        gt = self._gt
        gt = gt.loc[gt.dataset.isin(sets)]

        data = list()
        for i, sent in gt.groupby('nsentence'):
            sent = sent.sort_values('nword', ascending=True)
            data.append((sent.word.tolist(), sent.tag.tolist()))

        return data
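

# Illustrative sketch (not part of the original module): using NerProcessor to build
# train/dev loaders from a pickled ground-truth DataFrame. The file name, dataset
# identifiers and tokenizer are hypothetical placeholders; the DataFrame is expected
# to provide the columns used above (dataset, nsentence, nword, word, tag).
#
#     processor = NerProcessor(train_sets='CONLL-TRAIN', dev_sets='CONLL-DEV', test_sets=None,
#                              max_seq_length=128, tokenizer=tokenizer, gt_file='ner-gt.pkl')
#     label_map = processor.get_labels()          # tag -> id, with 'UNK' mapped to 0
#     train_loader = processor.get_train_examples(batch_size=32, local_rank=-1)
#     dev_loader = processor.get_dev_examples(batch_size=32, local_rank=-1)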


def convert_examples_to_features(example, label_map, max_seq_length, tokenizer):
    """Converts a single InputExample into an InputFeatures object.

    :param example: instance of InputExample; text_a is a list of words, label a list of per-word tags.
    :param label_map: dict mapping tag strings (including "X", "[CLS]", "[SEP]") to integer ids.
    :param max_seq_length: maximum sequence length; longer examples are truncated, shorter ones zero-padded.
    :param tokenizer: (sub-)word tokenizer providing tokenize() and convert_tokens_to_ids().
    :return: InputFeatures with input_ids, input_mask, segment_ids, label_id (one id per token) and tokens.
    """
    words = example.text_a
    word_labels = example.label

    tokens = []
    labels = []

    # Tokenize word by word; the first sub-token keeps the word's tag,
    # all following sub-tokens are labelled "X".
    for i, word in enumerate(words):
        token = tokenizer.tokenize(word)
        tokens.extend(token)
        label_1 = word_labels[i] if i < len(word_labels) else 'O'
        for m in range(len(token)):
            if m == 0:
                labels.append(label_1)
            else:
                labels.append("X")

    # Truncate to leave room for the [CLS] and [SEP] special tokens.
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]
        labels = labels[0:(max_seq_length - 2)]

    n_tokens = []
    segment_ids = []
    label_ids = []
    n_tokens.append("[CLS]")
    segment_ids.append(0)
    label_ids.append(label_map["[CLS]"])
    for i, token in enumerate(tokens):
        n_tokens.append(token)
        segment_ids.append(0)
        label_ids.append(label_map[labels[i]])
    n_tokens.append("[SEP]")
    segment_ids.append(0)
    label_ids.append(label_map["[SEP]"])

    input_ids = tokenizer.convert_tokens_to_ids(n_tokens)

    # The mask has 1 for real tokens and 0 for padding tokens.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the maximum sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        label_ids.append(0)
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length

    # if ex_index < 5:
    #     logger.info("*** Example ***")
    #     logger.info("guid: %s" % example.guid)
    #     logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
    #     logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
    #     logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
    #     logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
    #     logger.info("label: %s (id = %d)" % (example.label, label_ids))

    return InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_id=label_ids,
                         tokens=n_tokens)
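

if __name__ == '__main__':
    # Minimal, self-contained smoke test (illustrative only, not part of the original
    # module). _ToyTokenizer is a hypothetical stand-in for a BERT WordPiece tokenizer;
    # it only implements the two methods convert_examples_to_features relies on.
    class _ToyTokenizer(object):
        def tokenize(self, word):
            # Pretend every word is a single sub-token.
            return [word.lower()]

        def convert_tokens_to_ids(self, tokens):
            # Deterministic fake vocabulary ids.
            return [sum(ord(c) for c in t) % 30000 for t in tokens]

    demo_label_map = {label: i for i, label in enumerate(
        ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG", "X", "[CLS]", "[SEP]"])}

    demo_example = InputExample(guid='demo-0',
                                text_a=['Angela', 'Merkel', 'visited', 'Paris'],
                                label=['B-PER', 'I-PER', 'O', 'B-LOC'])

    demo_features = convert_examples_to_features(demo_example, demo_label_map,
                                                 max_seq_length=16, tokenizer=_ToyTokenizer())

    print(demo_features.tokens)     # ['[CLS]', 'angela', 'merkel', 'visited', 'paris', '[SEP]']
    print(demo_features.label_id)   # ids for [CLS], B-PER, I-PER, O, B-LOC, [SEP] plus zero padding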