mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
integrating transformer ocr
This commit is contained in:
parent
e9839a8b54
commit
7f86a55ccb
3 changed files with 170 additions and 1 deletions
|
|
@ -3,6 +3,7 @@ import sys
|
|||
import io
|
||||
import json
|
||||
import click
|
||||
from typing import Optional
|
||||
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
|
|
@ -397,7 +398,7 @@ def run(_config,
|
|||
f1_threshold_classification=None,
|
||||
classification_classes_name=None,
|
||||
## if task=cnn-rnn-ocr
|
||||
characters_txt_file=None,
|
||||
characters_txt_file: Optional[str]=None,
|
||||
color_padding_rotation=False,
|
||||
thetha_padd=None,
|
||||
bin_deg=False,
|
||||
|
|
@ -698,6 +699,89 @@ def run(_config,
|
|||
callbacks=callbacks,
|
||||
initial_epoch=index_start)
|
||||
|
||||
elif task=="transformer-ocr":
|
||||
import torch
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
|
||||
dir_img, dir_lab = get_dirs_or_files(dir_train)
|
||||
|
||||
if continue_training:
|
||||
model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
|
||||
else:
|
||||
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
|
||||
|
||||
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
|
||||
|
||||
|
||||
# Create a DataLoader
|
||||
class TransformerOCRTorchDataset(TorchDataset):
|
||||
"""
|
||||
Wraps preprocess_imgs in a format consumable by torch
|
||||
"""
|
||||
def __init__(self, config, dir_img, dir_lab, char_to_num):
|
||||
self.samples = list(
|
||||
preprocess_imgs(
|
||||
config,
|
||||
dir_img,
|
||||
dir_lab,
|
||||
char_to_num=char_to_num
|
||||
)
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image, label = self.samples[idx]
|
||||
|
||||
return {
|
||||
"image": torch.as_tensor(image, dtype=torch.float32),
|
||||
"label": torch.as_tensor(label, dtype=torch.long),
|
||||
}
|
||||
assert characters_txt_file
|
||||
with open(characters_txt_file, 'r') as char_txt_f:
|
||||
characters = json.load(char_txt_f)
|
||||
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
|
||||
dataset = TransformerOCRTorchDataset(_config, dir_img, dir_lab, char_to_num)
|
||||
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
|
||||
train_dataset = data_loader.dataset
|
||||
|
||||
# set special tokens used for creating the decoder_input_ids from the labels
|
||||
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
||||
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
# make sure vocab size is set correctly
|
||||
model.config.vocab_size = model.config.decoder.vocab_size
|
||||
|
||||
# set beam search parameters
|
||||
model.config.eos_token_id = processor.tokenizer.sep_token_id
|
||||
model.config.max_length = max_len
|
||||
model.config.early_stopping = True
|
||||
model.config.no_repeat_ngram_size = 3
|
||||
model.config.length_penalty = 2.0
|
||||
model.config.num_beams = 4
|
||||
|
||||
|
||||
training_args = Seq2SeqTrainingArguments(
|
||||
predict_with_generate=True,
|
||||
num_train_epochs=n_epochs,
|
||||
learning_rate=learning_rate,
|
||||
per_device_train_batch_size=n_batch,
|
||||
fp16=True,
|
||||
output_dir=dir_output,
|
||||
logging_steps=2,
|
||||
save_steps=save_interval,
|
||||
)
|
||||
|
||||
# instantiate trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
tokenizer=processor.feature_extractor,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
data_collator=default_data_collator,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
elif task=='classification':
|
||||
if continue_training:
|
||||
model = load_model(dir_of_start_model, compile=False)
|
||||
|
|
|
|||
82
train/config_params_trocr.json
Normal file
82
train/config_params_trocr.json
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
{
|
||||
"backbone_type" : "transformer",
|
||||
"task": "transformer-ocr",
|
||||
"n_classes" : 2,
|
||||
"max_len": 192,
|
||||
"n_epochs" : 1,
|
||||
"input_height" : 32,
|
||||
"input_width" : 512,
|
||||
"weight_decay" : 1e-6,
|
||||
"n_batch" : 1,
|
||||
"learning_rate": 1e-5,
|
||||
"save_interval": 1500,
|
||||
"patches" : false,
|
||||
"pretraining" : false,
|
||||
"augmentation" : true,
|
||||
"flip_aug" : false,
|
||||
"blur_aug" : true,
|
||||
"scaling" : false,
|
||||
"adding_rgb_background": true,
|
||||
"adding_rgb_foreground": true,
|
||||
"add_red_textlines": true,
|
||||
"white_noise_strap": true,
|
||||
"textline_right_in_depth": true,
|
||||
"textline_left_in_depth": true,
|
||||
"textline_up_in_depth": true,
|
||||
"textline_down_in_depth": true,
|
||||
"textline_right_in_depth_bin": true,
|
||||
"textline_left_in_depth_bin": true,
|
||||
"textline_up_in_depth_bin": true,
|
||||
"textline_down_in_depth_bin": true,
|
||||
"bin_deg": true,
|
||||
"textline_skewing": true,
|
||||
"textline_skewing_bin": true,
|
||||
"channels_shuffling": true,
|
||||
"degrading": true,
|
||||
"brightening": true,
|
||||
"binarization" : true,
|
||||
"pepper_aug": true,
|
||||
"pepper_bin_aug": true,
|
||||
"image_inversion": true,
|
||||
"scaling_bluring" : false,
|
||||
"scaling_binarization" : false,
|
||||
"scaling_flip" : false,
|
||||
"rotation": false,
|
||||
"color_padding_rotation": true,
|
||||
"padding_white": true,
|
||||
"rotation_not_90": true,
|
||||
"transformer_num_patches_xy": [56, 56],
|
||||
"transformer_patchsize_x": 4,
|
||||
"transformer_patchsize_y": 4,
|
||||
"transformer_projection_dim": 64,
|
||||
"transformer_mlp_head_units": [128, 64],
|
||||
"transformer_layers": 1,
|
||||
"transformer_num_heads": 1,
|
||||
"transformer_cnn_first": false,
|
||||
"blur_k" : ["blur","gauss","median"],
|
||||
"padd_colors" : ["white", "black"],
|
||||
"scales" : [0.6, 0.7, 0.8, 0.9],
|
||||
"brightness" : [1.3, 1.5, 1.7, 2],
|
||||
"degrade_scales" : [0.2, 0.4],
|
||||
"pepper_indexes": [0.01, 0.005],
|
||||
"skewing_amplitudes" : [5, 8],
|
||||
"flip_index" : [0, 1, -1],
|
||||
"shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]],
|
||||
"thetha" : [0.1, 0.2, -0.1, -0.2],
|
||||
"thetha_padd": [-0.6, -1, -1.4, -1.8, 0.6, 1, 1.4, 1.8],
|
||||
"white_padds" : [0.1, 0.3, 0.5, 0.7, 0.9],
|
||||
"number_of_backgrounds_per_image": 2,
|
||||
"continue_training": false,
|
||||
"index_start" : 0,
|
||||
"dir_of_start_model" : " ",
|
||||
"weighted_loss": false,
|
||||
"is_loss_soft_dice": false,
|
||||
"data_is_provided": false,
|
||||
"dir_train": "/home/vahid/extracted_lines/1919_bin/train",
|
||||
"dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new",
|
||||
"dir_output": "/home/vahid/extracted_lines/1919_bin/output",
|
||||
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
|
||||
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
|
||||
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
|
||||
|
||||
}
|
||||
|
|
@ -8,3 +8,6 @@ tensorflow-addons # for connected_components, depublished and only compatible wi
|
|||
tensorflow < 2.16 # for tensorflow-addons, so only needed in training
|
||||
tf_data < 2.16 # for tensorflow-addons, so only needed in training
|
||||
protobuf < 5 # for tensorflow-addons, so only needed in training
|
||||
torch
|
||||
transformers <= 4.30.2 ; python_version < '3.10'
|
||||
transformers >= 5 ; python_version >= '3.10'
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue