integrating transformer ocr

This commit is contained in:
vahidrezanezhad 2026-02-03 19:45:50 +01:00 committed by kba
parent e9839a8b54
commit 7f86a55ccb
3 changed files with 170 additions and 1 deletions

View file

@ -3,6 +3,7 @@ import sys
import io
import json
import click
from typing import Optional
from tqdm import tqdm
import requests
@ -397,7 +398,7 @@ def run(_config,
f1_threshold_classification=None,
classification_classes_name=None,
## if task=cnn-rnn-ocr
characters_txt_file=None,
characters_txt_file: Optional[str]=None,
color_padding_rotation=False,
thetha_padd=None,
bin_deg=False,
@ -698,6 +699,89 @@ def run(_config,
callbacks=callbacks,
initial_epoch=index_start)
elif task=="transformer-ocr":
import torch
from torch.utils.data import Dataset as TorchDataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
dir_img, dir_lab = get_dirs_or_files(dir_train)
if continue_training:
model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
else:
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
# Create a DataLoader
class TransformerOCRTorchDataset(TorchDataset):
"""
Wraps preprocess_imgs in a format consumable by torch
"""
def __init__(self, config, dir_img, dir_lab, char_to_num):
self.samples = list(
preprocess_imgs(
config,
dir_img,
dir_lab,
char_to_num=char_to_num
)
)
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
image, label = self.samples[idx]
return {
"image": torch.as_tensor(image, dtype=torch.float32),
"label": torch.as_tensor(label, dtype=torch.long),
}
assert characters_txt_file
with open(characters_txt_file, 'r') as char_txt_f:
characters = json.load(char_txt_f)
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
dataset = TransformerOCRTorchDataset(_config, dir_img, dir_lab, char_to_num)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
train_dataset = data_loader.dataset
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = max_len
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
num_train_epochs=n_epochs,
learning_rate=learning_rate,
per_device_train_batch_size=n_batch,
fp16=True,
output_dir=dir_output,
logging_steps=2,
save_steps=save_interval,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
train_dataset=train_dataset,
data_collator=default_data_collator,
)
trainer.train()
elif task=='classification':
if continue_training:
model = load_model(dir_of_start_model, compile=False)

View file

@ -0,0 +1,82 @@
{
"backbone_type" : "transformer",
"task": "transformer-ocr",
"n_classes" : 2,
"max_len": 192,
"n_epochs" : 1,
"input_height" : 32,
"input_width" : 512,
"weight_decay" : 1e-6,
"n_batch" : 1,
"learning_rate": 1e-5,
"save_interval": 1500,
"patches" : false,
"pretraining" : false,
"augmentation" : true,
"flip_aug" : false,
"blur_aug" : true,
"scaling" : false,
"adding_rgb_background": true,
"adding_rgb_foreground": true,
"add_red_textlines": true,
"white_noise_strap": true,
"textline_right_in_depth": true,
"textline_left_in_depth": true,
"textline_up_in_depth": true,
"textline_down_in_depth": true,
"textline_right_in_depth_bin": true,
"textline_left_in_depth_bin": true,
"textline_up_in_depth_bin": true,
"textline_down_in_depth_bin": true,
"bin_deg": true,
"textline_skewing": true,
"textline_skewing_bin": true,
"channels_shuffling": true,
"degrading": true,
"brightening": true,
"binarization" : true,
"pepper_aug": true,
"pepper_bin_aug": true,
"image_inversion": true,
"scaling_bluring" : false,
"scaling_binarization" : false,
"scaling_flip" : false,
"rotation": false,
"color_padding_rotation": true,
"padding_white": true,
"rotation_not_90": true,
"transformer_num_patches_xy": [56, 56],
"transformer_patchsize_x": 4,
"transformer_patchsize_y": 4,
"transformer_projection_dim": 64,
"transformer_mlp_head_units": [128, 64],
"transformer_layers": 1,
"transformer_num_heads": 1,
"transformer_cnn_first": false,
"blur_k" : ["blur","gauss","median"],
"padd_colors" : ["white", "black"],
"scales" : [0.6, 0.7, 0.8, 0.9],
"brightness" : [1.3, 1.5, 1.7, 2],
"degrade_scales" : [0.2, 0.4],
"pepper_indexes": [0.01, 0.005],
"skewing_amplitudes" : [5, 8],
"flip_index" : [0, 1, -1],
"shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]],
"thetha" : [0.1, 0.2, -0.1, -0.2],
"thetha_padd": [-0.6, -1, -1.4, -1.8, 0.6, 1, 1.4, 1.8],
"white_padds" : [0.1, 0.3, 0.5, 0.7, 0.9],
"number_of_backgrounds_per_image": 2,
"continue_training": false,
"index_start" : 0,
"dir_of_start_model" : " ",
"weighted_loss": false,
"is_loss_soft_dice": false,
"data_is_provided": false,
"dir_train": "/home/vahid/extracted_lines/1919_bin/train",
"dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new",
"dir_output": "/home/vahid/extracted_lines/1919_bin/output",
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
}

View file

@ -8,3 +8,6 @@ tensorflow-addons # for connected_components, depublished and only compatible wi
tensorflow < 2.16 # for tensorflow-addons, so only needed in training
tf_data < 2.16 # for tensorflow-addons, so only needed in training
protobuf < 5 # for tensorflow-addons, so only needed in training
torch
transformers <= 4.30.2 ; python_version < '3.10'
transformers >= 5 ; python_version >= '3.10'