From c2ff011010dfd7a43db755613eef0bf1e404a1fc Mon Sep 17 00:00:00 2001 From: Kai Date: Mon, 21 Feb 2022 15:41:27 +0100 Subject: [PATCH] move bert pre-training code to sbb_ner --- qurator/sbb_ner/models/corpus.py | 163 ++++++++ .../models/finetune_on_pregenerated.py | 363 ++++++++++++++++++ .../models/pregenerate_training_data.py | 302 +++++++++++++++ requirements.txt | 1 + setup.py | 6 +- 5 files changed, 834 insertions(+), 1 deletion(-) create mode 100644 qurator/sbb_ner/models/corpus.py create mode 100644 qurator/sbb_ner/models/finetune_on_pregenerated.py create mode 100644 qurator/sbb_ner/models/pregenerate_training_data.py diff --git a/qurator/sbb_ner/models/corpus.py b/qurator/sbb_ner/models/corpus.py new file mode 100644 index 0000000..64f62ac --- /dev/null +++ b/qurator/sbb_ner/models/corpus.py @@ -0,0 +1,163 @@ +import re +import pandas as pd +from tqdm import tqdm as tqdm +import click +import codecs +import os +import sqlite3 + +from qurator.utils.parallel import run as prun + + +class ChunkTask: + + selection = None + + def __init__(self, chunk, min_line_len): + + self._chunk = chunk + self._min_line_len = min_line_len + + def __call__(self, *args, **kwargs): + + return ChunkTask.reformat_chunk(self._chunk, self._min_line_len) + + @staticmethod + def reformat_chunk(chunk, min_line_len): + """ + Process a chunk of documents. + + :param chunk: pandas DataFrame that contains one document per row. + :param min_line_len: Break the document text up in lines that have this minimum length. + :return: One big text where the documents are separated by an empty line. + """ + + text = '' + + for i, r in chunk.iterrows(): + + if type(r.text) != str: + continue + + ppn = r.ppn if str(r.ppn).startswith('PPN') else 'PPN' + r.ppn + + filename = str(r['file name']) + + if not ChunkTask.selection.loc[(ppn, filename)].selected.iloc[0]: + continue + + for se in sentence_split(str(r.text), min_line_len): + + text += se + + text += '\n\n' + + return text + + @staticmethod + def initialize(selection_file): + + ChunkTask.selection = \ + pd.read_pickle(selection_file).\ + reset_index().\ + set_index(['ppn', 'filename']).\ + sort_index() + + +def get_csv_chunks(alto_csv_file, chunksize): + + for ch in tqdm(pd.read_csv(alto_csv_file, chunksize=chunksize)): + + yield ch + + +def get_sqlite_chunks(alto_sqlite_file, chunksize): + + yield pd.DataFrame() + + with sqlite3.connect(alto_sqlite_file) as conn: + + conn.execute('pragma journal_mode=wal') + + total = int(conn.execute('select count(*) from text;').fetchone()[0] / chunksize) + + for ch in tqdm(pd.read_sql('select * from text', conn, chunksize=chunksize), total=total): + + yield ch + + +def get_chunk_tasks(chunks, min_len_len): + + for chunk in chunks: + + if len(chunk) == 0: + continue + + yield ChunkTask(chunk, min_len_len) + + +def sentence_split(s, min_len): + """ + Reformat text of an entire document such that each line has at least length min_len + :param s: str + :param min_len: minimum line length + :return: reformatted text + """ + + parts = s.split(' ') + + se = '' + for p in parts: + + se += ' ' + p + + if len(se) > min_len and len(p) > 2 and re.match(r'.*([^0-9])[.]$', p): + yield se + '\n' + se = '' + + yield se + '\n' + + +@click.command() +@click.argument('fulltext-file', type=click.Path(exists=True), required=True, nargs=1) +@click.argument('selection-file', type=click.Path(exists=True), required=True, nargs=1) +@click.argument('corpus-file', type=click.Path(), required=True, nargs=1) +@click.option('--chunksize', default=10**4, 
help="Process the corpus in chunks of . default:10**4") +@click.option('--processes', default=6, help="Number of parallel processes. default: 6") +@click.option('--min-line-len', default=80, help="Lower bound of line length in output file. default:80") +def collect(fulltext_file, selection_file, corpus_file, chunksize, processes, min_line_len): + """ + Reads the fulltext from a CSV or SQLITE3 file (see also altotool) and write it to one big text file. + + FULLTEXT_FILE: The CSV or SQLITE3 file to read from. + + SELECTION_FILE: Consider only a subset of all pages that is defined by the DataFrame + that is stored in . + + CORPUS_FILE: The output file that can be used by bert-pregenerate-trainingdata. + """ + os.makedirs(os.path.dirname(corpus_file), exist_ok=True) + + print('Open {}.'.format(corpus_file)) + corpus_fh = codecs.open(corpus_file, 'w+', 'utf-8') + corpus_fh.write(u'\ufeff') + + if fulltext_file.endswith('.csv'): + chunks = get_csv_chunks(fulltext_file, chunksize) + elif fulltext_file.endswith('.sqlite3'): + chunks = get_sqlite_chunks(fulltext_file, chunksize) + else: + raise RuntimeError('Unsupported input file format.') + + for text in prun(get_chunk_tasks(chunks, min_line_len), processes=processes, initializer=ChunkTask.initialize, + initargs=(selection_file,)): + + corpus_fh.write(text) + + corpus_fh.close() + + return + + +if __name__ == '__main__': + main() diff --git a/qurator/sbb_ner/models/finetune_on_pregenerated.py b/qurator/sbb_ner/models/finetune_on_pregenerated.py new file mode 100644 index 0000000..d271a3c --- /dev/null +++ b/qurator/sbb_ner/models/finetune_on_pregenerated.py @@ -0,0 +1,363 @@ +from argparse import ArgumentParser +from pathlib import Path +import torch +import logging +import json +import random +import numpy as np +from collections import namedtuple +from tempfile import TemporaryDirectory + +from torch.utils.data import DataLoader, Dataset, RandomSampler +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm + +from pytorch_pretrained_bert.modeling import BertForPreTraining +from pytorch_pretrained_bert.tokenization import BertTokenizer +from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule + +InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next") + +log_format = '%(asctime)-10s: %(message)s' +logging.basicConfig(level=logging.INFO, format=log_format) + + +def convert_example_to_features(example, tokenizer, max_seq_length): + tokens = example["tokens"] + segment_ids = example["segment_ids"] + is_random_next = example["is_random_next"] + masked_lm_positions = example["masked_lm_positions"] + masked_lm_labels = example["masked_lm_labels"] + + assert len(tokens) == len(segment_ids) <= max_seq_length # The preprocessed data should be already truncated + input_ids = tokenizer.convert_tokens_to_ids(tokens) + masked_label_ids = tokenizer.convert_tokens_to_ids(masked_lm_labels) + + input_array = np.zeros(max_seq_length, dtype=np.int) + input_array[:len(input_ids)] = input_ids + + mask_array = np.zeros(max_seq_length, dtype=np.bool) + mask_array[:len(input_ids)] = 1 + + segment_array = np.zeros(max_seq_length, dtype=np.bool) + segment_array[:len(segment_ids)] = segment_ids + + lm_label_array = np.full(max_seq_length, dtype=np.int, fill_value=-1) + lm_label_array[masked_lm_positions] = masked_label_ids + + features = InputFeatures(input_ids=input_array, + input_mask=mask_array, + segment_ids=segment_array, + lm_label_ids=lm_label_array, + 
is_next=is_random_next) + return features + + +class PregeneratedDataset(Dataset): + def __init__(self, training_path, epoch, tokenizer, num_data_epochs, reduce_memory=False, prefix=None): + self.vocab = tokenizer.vocab + self.tokenizer = tokenizer + self.epoch = epoch + self.data_epoch = epoch % num_data_epochs + data_file = training_path / f"epoch_{self.data_epoch}.json" + metrics_file = training_path / f"epoch_{self.data_epoch}_metrics.json" + assert data_file.is_file() and metrics_file.is_file() + metrics = json.loads(metrics_file.read_text()) + num_samples = metrics['num_training_examples'] + seq_len = metrics['max_seq_len'] + self.temp_dir = None + self.working_dir = None + if reduce_memory: + self.temp_dir = TemporaryDirectory(prefix=prefix) + self.working_dir = Path(self.temp_dir.name) + input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap', + mode='w+', dtype=np.int32, shape=(num_samples, seq_len)) + input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap', + shape=(num_samples, seq_len), mode='w+', dtype=np.bool) + segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap', + shape=(num_samples, seq_len), mode='w+', dtype=np.bool) + lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap', + shape=(num_samples, seq_len), mode='w+', dtype=np.int32) + lm_label_ids[:] = -1 + is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap', + shape=(num_samples,), mode='w+', dtype=np.bool) + else: + input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32) + input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool) + segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool) + lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1) + is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool) + logging.info(f"Loading training examples for epoch {epoch}") + with data_file.open() as f: + for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")): + line = line.strip() + example = json.loads(line) + features = convert_example_to_features(example, tokenizer, seq_len) + input_ids[i] = features.input_ids + segment_ids[i] = features.segment_ids + input_masks[i] = features.input_mask + lm_label_ids[i] = features.lm_label_ids + is_nexts[i] = features.is_next + assert i == num_samples - 1 # Assert that the sample count metric was true + logging.info("Loading complete!") + self.num_samples = num_samples + self.seq_len = seq_len + self.input_ids = input_ids + self.input_masks = input_masks + self.segment_ids = segment_ids + self.lm_label_ids = lm_label_ids + self.is_nexts = is_nexts + + def __len__(self): + return self.num_samples + + def __getitem__(self, item): + return (torch.tensor(self.input_ids[item].astype(np.int64)), + torch.tensor(self.input_masks[item].astype(np.int64)), + torch.tensor(self.segment_ids[item].astype(np.int64)), + torch.tensor(self.lm_label_ids[item].astype(np.int64)), + torch.tensor(self.is_nexts[item].astype(np.int64))) + + +def main(): + parser = ArgumentParser() + parser.add_argument('--pregenerated_data', type=Path, required=True) + parser.add_argument('--output_dir', type=Path, required=True) + parser.add_argument("--bert_model", type=str, required=True, help="Directory where the Bert pre-trained model can be found " + "or Bert pre-trained model selected in the list: bert-base-uncased, " + "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") + parser.add_argument("--do_lower_case", action="store_true") + 
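+    # Illustrative note: --pregenerated_data must point to the directory written by
+    # pregenerate_training_data.py, i.e. pairs of epoch_{N}.json and epoch_{N}_metrics.json.
+    # Every line of epoch_{N}.json holds one instance (token values below are made up):
+    # {"tokens": ["[CLS]", "die", "[MASK]", "[SEP]", "ein", "satz", "[SEP]"],
+    #  "segment_ids": [0, 0, 0, 0, 1, 1, 1], "is_random_next": false,
+    #  "masked_lm_positions": [2], "masked_lm_labels": ["alte"]}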
parser.add_argument("--reduce_memory", action="store_true", + help="Store training data as on-disc memmaps to massively reduce memory usage") + + parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for") + parser.add_argument("--local_rank", + type=int, + default=-1, + help="local_rank for distributed training on gpus") + parser.add_argument("--no_cuda", + action='store_true', + help="Whether not to use CUDA when available") + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument("--train_batch_size", + default=32, + type=int, + help="Total batch size for training.") + parser.add_argument("--save_interval", + default=20000, + type=int, + help="Save model every save_interval training steps.") + parser.add_argument('--fp16', + action='store_true', + help="Whether to use 16-bit float precision instead of 32-bit") + parser.add_argument('--loss_scale', + type=float, default=0, + help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" + "0 (default value): dynamic loss scaling.\n" + "Positive power of 2: static loss scaling value.\n") + parser.add_argument("--warmup_proportion", + default=0.1, + type=float, + help="Proportion of training to perform linear learning rate warmup for. " + "E.g., 0.1 = 10%% of training.") + parser.add_argument("--learning_rate", + default=3e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument('--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument('--temp_prefix', + type=str, + default=None, + help="where to store temporary data") + + args = parser.parse_args() + + assert args.pregenerated_data.is_dir(), \ + "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!" + + samples_per_epoch = [] + for i in range(args.epochs): + epoch_file = args.pregenerated_data / f"epoch_{i}.json" + metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json" + if epoch_file.is_file() and metrics_file.is_file(): + metrics = json.loads(metrics_file.read_text()) + samples_per_epoch.append(metrics['num_training_examples']) + else: + if i == 0: + exit("No training data was found!") + print(f"Warning! 
There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs}).") + print("This script will loop over the available data, but training diversity may be negatively impacted.") + num_data_epochs = i + break + else: + num_data_epochs = args.epochs + + if args.local_rank == -1 or args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + n_gpu = torch.cuda.device_count() + else: + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + n_gpu = 1 + # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.distributed.init_process_group(backend='nccl') + logging.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( + device, n_gpu, bool(args.local_rank != -1), args.fp16)) + + if args.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( + args.gradient_accumulation_steps)) + + args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + if args.output_dir.is_dir() and list(args.output_dir.iterdir()): + logging.warning(f"Output directory ({args.output_dir}) already exists and is not empty!") + args.output_dir.mkdir(parents=True, exist_ok=True) + + tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) + + total_train_examples = 0 + for i in range(args.epochs): + # The modulo takes into account the fact that we may loop over limited epochs of data + total_train_examples += samples_per_epoch[i % len(samples_per_epoch)] + + num_train_optimization_steps = int( + total_train_examples / args.train_batch_size / args.gradient_accumulation_steps) + + if args.local_rank != -1: + num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() + + # Prepare model + model = BertForPreTraining.from_pretrained(args.bert_model) + if args.fp16: + model.half() + model.to(device) + if args.local_rank != -1: + try: + from apex.parallel import DistributedDataParallel as DDP + except ImportError: + raise ImportError( + "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + model = DDP(model) + elif n_gpu > 1: + model = torch.nn.DataParallel(model) + + # Prepare optimizer + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], + 'weight_decay': 0.01}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + + if args.fp16: + try: + from apex.optimizers import FP16_Optimizer + from apex.optimizers import FusedAdam + except ImportError: + raise ImportError( + "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + + optimizer = FusedAdam(optimizer_grouped_parameters, + lr=args.learning_rate, + bias_correction=False, + max_grad_norm=1.0) + if args.loss_scale == 0: + optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) + else: + optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) + warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, + t_total=num_train_optimization_steps) + else: + 
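+        # No explicit warmup schedule is needed in this branch: BertAdam applies the
+        # linear warmup internally (see the lr_this_step comment in the training loop below).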
optimizer = BertAdam(optimizer_grouped_parameters, + lr=args.learning_rate, + warmup=args.warmup_proportion, + t_total=num_train_optimization_steps) + + global_step = 0 + logging.info("***** Running training *****") + logging.info(f" Num examples = {total_train_examples}") + logging.info(" Batch size = %d", args.train_batch_size) + logging.info(" Num steps = %d", num_train_optimization_steps) + model.train() + + def save_model(): + + logging.info("** ** * Saving fine-tuned model ** ** * ") + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + output_model_file = args.output_dir / "pytorch_model.bin" + torch.save(model_to_save.state_dict(), str(output_model_file)) + + for epoch in range(args.epochs): + epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer, + num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory, + prefix=args.temp_prefix) + if args.local_rank == -1: + train_sampler = RandomSampler(epoch_dataset) + else: + train_sampler = DistributedSampler(epoch_dataset) + train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size) + tr_loss = 0 + nb_tr_examples, nb_tr_steps = 0, 0 + with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar: + for step, batch in enumerate(train_dataloader): + + batch = tuple(t.to(device) for t in batch) + input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch + loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) + if n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu. + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + if args.fp16: + optimizer.backward(loss) + else: + loss.backward() + tr_loss += loss.item() + nb_tr_examples += input_ids.size(0) + nb_tr_steps += 1 + pbar.update(1) + mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps + pbar.set_postfix_str(f"Loss: {mean_loss:.5f}") + + if step % args.save_interval == 0: + save_model() + + if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + # modify learning rate with special warm up BERT uses + # if args.fp16 is False, BertAdam is used that handles this automatically + lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion) + + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + + optimizer.step() + optimizer.zero_grad() + + global_step += 1 + + # Save a trained model + # logging.info("** ** * Saving fine-tuned model ** ** * ") + # model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + # output_model_file = args.output_dir / "pytorch_model.bin" + # torch.save(model_to_save.state_dict(), str(output_model_file)) + + save_model() + + +if __name__ == '__main__': + main() diff --git a/qurator/sbb_ner/models/pregenerate_training_data.py b/qurator/sbb_ner/models/pregenerate_training_data.py new file mode 100644 index 0000000..c806fce --- /dev/null +++ b/qurator/sbb_ner/models/pregenerate_training_data.py @@ -0,0 +1,302 @@ +from argparse import ArgumentParser +from pathlib import Path +from tqdm import tqdm, trange +from tempfile import TemporaryDirectory +import shelve + +from random import random, randrange, randint, shuffle, choice, sample +from pytorch_pretrained_bert.tokenization import BertTokenizer +import numpy as np +import json + + +class DocumentDatabase: + def __init__(self, reduce_memory=False): + if 
reduce_memory: + self.temp_dir = TemporaryDirectory() + self.working_dir = Path(self.temp_dir.name) + self.document_shelf_filepath = self.working_dir / 'shelf.db' + self.document_shelf = shelve.open(str(self.document_shelf_filepath), + flag='n', protocol=-1) + self.documents = None + else: + self.documents = [] + self.document_shelf = None + self.document_shelf_filepath = None + self.temp_dir = None + self.doc_lengths = [] + self.doc_cumsum = None + self.cumsum_max = None + self.reduce_memory = reduce_memory + + def add_document(self, document): + if not document: + return + if self.reduce_memory: + current_idx = len(self.doc_lengths) + self.document_shelf[str(current_idx)] = document + else: + self.documents.append(document) + self.doc_lengths.append(len(document)) + + def _precalculate_doc_weights(self): + self.doc_cumsum = np.cumsum(self.doc_lengths) + self.cumsum_max = self.doc_cumsum[-1] + + def sample_doc(self, current_idx, sentence_weighted=True): + # Uses the current iteration counter to ensure we don't sample the same doc twice + if sentence_weighted: + # With sentence weighting, we sample docs proportionally to their sentence length + if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths): + self._precalculate_doc_weights() + rand_start = self.doc_cumsum[current_idx] + rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx] + sentence_index = randrange(rand_start, rand_end) % self.cumsum_max + sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right') + else: + # If we don't use sentence weighting, then every doc has an equal chance to be chosen + sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths) + assert sampled_doc_index != current_idx + if self.reduce_memory: + return self.document_shelf[str(sampled_doc_index)] + else: + return self.documents[sampled_doc_index] + + def __len__(self): + return len(self.doc_lengths) + + def __getitem__(self, item): + if self.reduce_memory: + return self.document_shelf[str(item)] + else: + return self.documents[item] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + if self.document_shelf is not None: + self.document_shelf.close() + if self.temp_dir is not None: + self.temp_dir.cleanup() + + +def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): + """Truncates a pair of sequences to a maximum sequence length. Lifted from Google's BERT repo.""" + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_num_tokens: + break + + trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b + assert len(trunc_tokens) >= 1 + + # We want to sometimes truncate from the front and sometimes from the + # back to add more randomness and avoid biases. + if random() < 0.5: + del trunc_tokens[0] + else: + trunc_tokens.pop() + + +def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list): + """Creates the predictions for the masked LM objective. 
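+    Each position selected for masking is replaced by [MASK] 80% of the time, kept
+    unchanged 10% of the time, and replaced by a random vocabulary token 10% of the time.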
This is mostly copied from the Google BERT repo, but + with several refactors to clean it up and remove a lot of unnecessary variables.""" + cand_indices = [] + for (i, token) in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + cand_indices.append(i) + + num_to_mask = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + shuffle(cand_indices) + mask_indices = sorted(sample(cand_indices, num_to_mask)) + masked_token_labels = [] + for index in mask_indices: + # 80% of the time, replace with [MASK] + if random() < 0.8: + masked_token = "[MASK]" + else: + # 10% of the time, keep original + if random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = choice(vocab_list) + masked_token_labels.append(tokens[index]) + # Once we've saved the true label for that token, we can overwrite it with the masked version + tokens[index] = masked_token + + return tokens, mask_indices, masked_token_labels + + +def create_instances_from_document( + doc_database, doc_idx, max_seq_length, short_seq_prob, + masked_lm_prob, max_predictions_per_seq, vocab_list): + """This code is mostly a duplicate of the equivalent function from Google BERT's repo. + However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function. + Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence + (rather than each document) has an equal chance of being sampled as a false example for the NextSentence task.""" + document = doc_database[doc_idx] + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 3 + + # We *usually* want to fill up the entire sequence since we are padding + # to `max_seq_length` anyways, so short sequences are generally wasted + # computation. However, we *sometimes* + # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter + # sequences to minimize the mismatch between pre-training and fine-tuning. + # The `target_seq_length` is just a rough target however, whereas + # `max_seq_length` is a hard limit. + target_seq_length = max_num_tokens + if random() < short_seq_prob: + target_seq_length = randint(2, max_num_tokens) + + # We DON'T just concatenate all of the tokens from a document into a long + # sequence and choose an arbitrary split point because this would make the + # next sentence prediction task too easy. Instead, we split the input into + # segments "A" and "B" based on the actual "sentences" provided by the user + # input. + instances = [] + current_chunk = [] + current_length = 0 + i = 0 + while i < len(document): + segment = document[i] + current_chunk.append(segment) + current_length += len(segment) + if i == len(document) - 1 or current_length >= target_seq_length: + if current_chunk: + # `a_end` is how many segments from `current_chunk` go into the `A` + # (first) sentence. 
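+                # With two or more collected segments the split point is drawn uniformly
+                # at random; with a single segment everything becomes the `A` sentence.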
+ a_end = 1 + if len(current_chunk) >= 2: + a_end = randrange(1, len(current_chunk)) + + tokens_a = [] + for j in range(a_end): + tokens_a.extend(current_chunk[j]) + + tokens_b = [] + + # Random next + if len(current_chunk) == 1 or random() < 0.5: + is_random_next = True + target_b_length = target_seq_length - len(tokens_a) + + # Sample a random document, with longer docs being sampled more frequently + random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True) + + random_start = randrange(0, len(random_document)) + for j in range(random_start, len(random_document)): + tokens_b.extend(random_document[j]) + if len(tokens_b) >= target_b_length: + break + # We didn't actually use these segments so we "put them back" so + # they don't go to waste. + num_unused_segments = len(current_chunk) - a_end + i -= num_unused_segments + # Actual next + else: + is_random_next = False + for j in range(a_end, len(current_chunk)): + tokens_b.extend(current_chunk[j]) + truncate_seq_pair(tokens_a, tokens_b, max_num_tokens) + + assert len(tokens_a) >= 1 + assert len(tokens_b) >= 1 + + tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] + # The segment IDs are 0 for the [CLS] token, the A tokens and the first [SEP] + # They are 1 for the B tokens and the final [SEP] + segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)] + + tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions( + tokens, masked_lm_prob, max_predictions_per_seq, vocab_list) + + instance = { + "tokens": tokens, + "segment_ids": segment_ids, + "is_random_next": is_random_next, + "masked_lm_positions": masked_lm_positions, + "masked_lm_labels": masked_lm_labels} + instances.append(instance) + current_chunk = [] + current_length = 0 + i += 1 + + return instances + + +def main(): + parser = ArgumentParser() + parser.add_argument('--train_corpus', type=Path, required=True) + parser.add_argument("--output_dir", type=Path, required=True) + parser.add_argument("--bert_model", type=str, required=True) # , +# choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", +# "bert-base-multilingual", "bert-base-chinese"]) + parser.add_argument("--do_lower_case", action="store_true") + + parser.add_argument("--reduce_memory", action="store_true", + help="Reduce memory usage for large datasets by keeping data on disc rather than in memory") + + parser.add_argument("--epochs_to_generate", type=int, default=3, + help="Number of epochs of data to pregenerate") + parser.add_argument("--max_seq_len", type=int, default=128) + parser.add_argument("--short_seq_prob", type=float, default=0.1, + help="Probability of making a short sentence as a training example") + parser.add_argument("--masked_lm_prob", type=float, default=0.15, + help="Probability of masking each token for the LM task") + parser.add_argument("--max_predictions_per_seq", type=int, default=20, + help="Maximum number of tokens to mask in each sequence") + + args = parser.parse_args() + + tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) + vocab_list = list(tokenizer.vocab.keys()) + with DocumentDatabase(reduce_memory=args.reduce_memory) as docs: + with args.train_corpus.open() as f: + doc = [] + for line in tqdm(f, desc="Loading Dataset", unit=" lines"): + line = line.strip() + if line == "": + docs.add_document(doc) + doc = [] + else: + tokens = tokenizer.tokenize(line) + doc.append(tokens) + if doc: + docs.add_document(doc) # If the last doc didn't 
end on a newline, make sure it still gets added + if len(docs) <= 1: + exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to " + "ensure that random NextSentences are not sampled from the same document. Please add blank lines to " + "indicate breaks between documents in your input file. If your dataset does not contain multiple " + "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, " + "sections or paragraphs.") + + args.output_dir.mkdir(exist_ok=True) + for epoch in trange(args.epochs_to_generate, desc="Epoch"): + epoch_filename = args.output_dir / f"epoch_{epoch}.json" + num_instances = 0 + with epoch_filename.open('w') as epoch_file: + for doc_idx in trange(len(docs), desc="Document"): + doc_instances = create_instances_from_document( + docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, + masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, + vocab_list=vocab_list) + doc_instances = [json.dumps(instance) for instance in doc_instances] + for instance in doc_instances: + epoch_file.write(instance + '\n') + num_instances += 1 + metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json" + with metrics_file.open('w') as metrics_file: + metrics = { + "num_training_examples": num_instances, + "max_seq_len": args.max_seq_len + } + metrics_file.write(json.dumps(metrics)) + + +if __name__ == '__main__': + main() diff --git a/requirements.txt b/requirements.txt index 2bd4239..ce56929 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ flask Flask-Caching gunicorn somajo +qurator-sbb-tools \ No newline at end of file diff --git a/setup.py b/setup.py index d4c4ef0..63c24a2 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,11 @@ setup( "compile_conll=qurator.sbb_ner.ground_truth.conll:main", "compile_wikiner=qurator.sbb_ner.ground_truth.wikiner:main", "join-gt=qurator.sbb_ner.ground_truth.join_gt:main", - "bert-ner=qurator.sbb_ner.models.bert:main" + "bert-ner=qurator.sbb_ner.models.bert:main", + + "collectcorpus=qurator.sbb_ner.models.corpus:collect", + "bert-pregenerate-trainingdata=qurator.sbb_ner.models.pregenerate_training_data:main", + "bert-finetune=qurator.sbb_ner.models.finetune_on_pregenerated:main" ] }, python_requires='>=3.6.0',