diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 62d8e51..da2cbdb 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -3,6 +3,7 @@ import sys import io import json import click +from typing import Optional from tqdm import tqdm import requests @@ -397,7 +398,7 @@ def run(_config, f1_threshold_classification=None, classification_classes_name=None, ## if task=cnn-rnn-ocr - characters_txt_file=None, + characters_txt_file: Optional[str]=None, color_padding_rotation=False, thetha_padd=None, bin_deg=False, @@ -698,6 +699,89 @@ def run(_config, callbacks=callbacks, initial_epoch=index_start) + elif task=="transformer-ocr": + import torch + from torch.utils.data import Dataset as TorchDataset + from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator + dir_img, dir_lab = get_dirs_or_files(dir_train) + + if continue_training: + model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model) + else: + model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed") + + processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") + + + # Create a DataLoader + class TransformerOCRTorchDataset(TorchDataset): + """ + Wraps preprocess_imgs in a format consumable by torch + """ + def __init__(self, config, dir_img, dir_lab, char_to_num): + self.samples = list( + preprocess_imgs( + config, + dir_img, + dir_lab, + char_to_num=char_to_num + ) + ) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + image, label = self.samples[idx] + + return { + "image": torch.as_tensor(image, dtype=torch.float32), + "label": torch.as_tensor(label, dtype=torch.long), + } + assert characters_txt_file + with open(characters_txt_file, 'r') as char_txt_f: + characters = json.load(char_txt_f) + char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) + dataset = TransformerOCRTorchDataset(_config, dir_img, dir_lab, char_to_num) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=1) + train_dataset = data_loader.dataset + + # set special tokens used for creating the decoder_input_ids from the labels + model.config.decoder_start_token_id = processor.tokenizer.cls_token_id + model.config.pad_token_id = processor.tokenizer.pad_token_id + # make sure vocab size is set correctly + model.config.vocab_size = model.config.decoder.vocab_size + + # set beam search parameters + model.config.eos_token_id = processor.tokenizer.sep_token_id + model.config.max_length = max_len + model.config.early_stopping = True + model.config.no_repeat_ngram_size = 3 + model.config.length_penalty = 2.0 + model.config.num_beams = 4 + + + training_args = Seq2SeqTrainingArguments( + predict_with_generate=True, + num_train_epochs=n_epochs, + learning_rate=learning_rate, + per_device_train_batch_size=n_batch, + fp16=True, + output_dir=dir_output, + logging_steps=2, + save_steps=save_interval, + ) + + # instantiate trainer + trainer = Seq2SeqTrainer( + model=model, + tokenizer=processor.feature_extractor, + args=training_args, + train_dataset=train_dataset, + data_collator=default_data_collator, + ) + trainer.train() + elif task=='classification': if continue_training: model = load_model(dir_of_start_model, compile=False) diff --git a/train/config_params_trocr.json b/train/config_params_trocr.json new file mode 100644 index 0000000..34c6376 --- /dev/null +++ b/train/config_params_trocr.json @@ -0,0 +1,82 @@ +{ + "backbone_type" : "transformer", + "task": "transformer-ocr", + "n_classes" : 2, + "max_len": 192, + "n_epochs" : 1, + "input_height" : 32, + "input_width" : 512, + "weight_decay" : 1e-6, + "n_batch" : 1, + "learning_rate": 1e-5, + "save_interval": 1500, + "patches" : false, + "pretraining" : false, + "augmentation" : true, + "flip_aug" : false, + "blur_aug" : true, + "scaling" : false, + "adding_rgb_background": true, + "adding_rgb_foreground": true, + "add_red_textlines": true, + "white_noise_strap": true, + "textline_right_in_depth": true, + "textline_left_in_depth": true, + "textline_up_in_depth": true, + "textline_down_in_depth": true, + "textline_right_in_depth_bin": true, + "textline_left_in_depth_bin": true, + "textline_up_in_depth_bin": true, + "textline_down_in_depth_bin": true, + "bin_deg": true, + "textline_skewing": true, + "textline_skewing_bin": true, + "channels_shuffling": true, + "degrading": true, + "brightening": true, + "binarization" : true, + "pepper_aug": true, + "pepper_bin_aug": true, + "image_inversion": true, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "color_padding_rotation": true, + "padding_white": true, + "rotation_not_90": true, + "transformer_num_patches_xy": [56, 56], + "transformer_patchsize_x": 4, + "transformer_patchsize_y": 4, + "transformer_projection_dim": 64, + "transformer_mlp_head_units": [128, 64], + "transformer_layers": 1, + "transformer_num_heads": 1, + "transformer_cnn_first": false, + "blur_k" : ["blur","gauss","median"], + "padd_colors" : ["white", "black"], + "scales" : [0.6, 0.7, 0.8, 0.9], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "pepper_indexes": [0.01, 0.005], + "skewing_amplitudes" : [5, 8], + "flip_index" : [0, 1, -1], + "shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]], + "thetha" : [0.1, 0.2, -0.1, -0.2], + "thetha_padd": [-0.6, -1, -1.4, -1.8, 0.6, 1, 1.4, 1.8], + "white_padds" : [0.1, 0.3, 0.5, 0.7, 0.9], + "number_of_backgrounds_per_image": 2, + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, + "dir_train": "/home/vahid/extracted_lines/1919_bin/train", + "dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new", + "dir_output": "/home/vahid/extracted_lines/1919_bin/output", + "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", + "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground", + "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin" + +} diff --git a/train/requirements.txt b/train/requirements.txt index 090bc50..03994f8 100644 --- a/train/requirements.txt +++ b/train/requirements.txt @@ -8,3 +8,6 @@ tensorflow-addons # for connected_components, depublished and only compatible wi tensorflow < 2.16 # for tensorflow-addons, so only needed in training tf_data < 2.16 # for tensorflow-addons, so only needed in training protobuf < 5 # for tensorflow-addons, so only needed in training +torch +transformers <= 4.30.2 ; python_version < '3.10' +transformers >= 5 ; python_version >= '3.10'