integrating transformer ocr

This commit is contained in:
vahidrezanezhad 2026-02-03 19:45:50 +01:00
parent 586077fbcd
commit 60f0fb541d
4 changed files with 552 additions and 13 deletions

View file

@ -1,7 +1,8 @@
import os
import sys
import json
import numpy as np
import cv2
import click
from eynollah.training.metrics import (
@ -27,7 +28,8 @@ from eynollah.training.utils import (
generate_data_from_folder_training,
get_one_hot,
provide_patches,
return_number_of_total_training_data
return_number_of_total_training_data,
OCRDatasetYieldAugmentations
)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@ -41,8 +43,13 @@ from sklearn.metrics import f1_score
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import StringLookup
import numpy as np
import cv2
import torch
from transformers import TrOCRProcessor
import evaluate
from transformers import default_data_collator
from transformers import VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
class SaveWeightsAfterSteps(Callback):
def __init__(self, save_interval, save_path, _config):
@ -559,6 +566,121 @@ def run(
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
elif task=="transformer-ocr":
dir_img, dir_lab = get_dirs_or_files(dir_train)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
ls_files_images = os.listdir(dir_img)
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)
len_dataset = aug_multip*len(ls_files_images)
dataset = OCRDatasetYieldAugmentations(
dir_img=dir_img,
dir_img_bin=dir_img_bin,
dir_lab=dir_lab,
processor=processor,
max_target_length=max_len,
augmentation = augmentation,
binarization = binarization,
add_red_textlines = add_red_textlines,
white_noise_strap = white_noise_strap,
adding_rgb_foreground = adding_rgb_foreground,
adding_rgb_background = adding_rgb_background,
bin_deg = bin_deg,
blur_aug = blur_aug,
brightening = brightening,
padding_white = padding_white,
color_padding_rotation = color_padding_rotation,
rotation_not_90 = rotation_not_90,
degrading = degrading,
channels_shuffling = channels_shuffling,
textline_skewing = textline_skewing,
textline_skewing_bin = textline_skewing_bin,
textline_right_in_depth = textline_right_in_depth,
textline_left_in_depth = textline_left_in_depth,
textline_up_in_depth = textline_up_in_depth,
textline_down_in_depth = textline_down_in_depth,
textline_right_in_depth_bin = textline_right_in_depth_bin,
textline_left_in_depth_bin = textline_left_in_depth_bin,
textline_up_in_depth_bin = textline_up_in_depth_bin,
textline_down_in_depth_bin = textline_down_in_depth_bin,
pepper_aug = pepper_aug,
pepper_bin_aug = pepper_bin_aug,
list_all_possible_background_images=list_all_possible_background_images,
list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
blur_k = blur_k,
degrade_scales = degrade_scales,
white_padds = white_padds,
thetha_padd = thetha_padd,
thetha = thetha,
brightness = brightness,
padd_colors = padd_colors,
number_of_backgrounds_per_image = number_of_backgrounds_per_image,
shuffle_indexes = shuffle_indexes,
pepper_indexes = pepper_indexes,
skewing_amplitudes = skewing_amplitudes,
dir_rgb_backgrounds = dir_rgb_backgrounds,
dir_rgb_foregrounds = dir_rgb_foregrounds,
len_data=len_dataset,
)
# Create a DataLoader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
train_dataset = data_loader.dataset
if continue_training:
model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
else:
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = max_len
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
num_train_epochs=n_epochs,
learning_rate=learning_rate,
per_device_train_batch_size=n_batch,
fp16=True,
output_dir=dir_output,
logging_steps=2,
save_steps=save_interval,
)
cer_metric = evaluate.load("cer")
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
train_dataset=train_dataset,
data_collator=default_data_collator,
)
trainer.train()
elif task=='classification':
configuration()
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)

View file

@ -9,12 +9,16 @@ from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
from tqdm import tqdm
import imutils
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
##import tensorflow as tf
##from tensorflow.keras.utils import to_categorical
from PIL import Image, ImageFile, ImageEnhance
import torch
from torch.utils.data import IterableDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True
def vectorize_label(label, char_to_num, padding_token, max_len):
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
length = tf.shape(label)[0]
@ -76,6 +80,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
return noisy_image
def invert_image(img):
img_inv = 255 - img
return img_inv
@ -1668,3 +1673,411 @@ def return_multiplier_based_on_augmnentations(
aug_multip += len(pepper_indexes)
return aug_multip
class OCRDatasetYieldAugmentations(IterableDataset):
def __init__(
self,
dir_img,
dir_img_bin,
dir_lab,
processor,
max_target_length=128,
augmentation = None,
binarization = None,
add_red_textlines = None,
white_noise_strap = None,
adding_rgb_foreground = None,
adding_rgb_background = None,
bin_deg = None,
blur_aug = None,
brightening = None,
padding_white = None,
color_padding_rotation = None,
rotation_not_90 = None,
degrading = None,
channels_shuffling = None,
textline_skewing = None,
textline_skewing_bin = None,
textline_right_in_depth = None,
textline_left_in_depth = None,
textline_up_in_depth = None,
textline_down_in_depth = None,
textline_right_in_depth_bin = None,
textline_left_in_depth_bin = None,
textline_up_in_depth_bin = None,
textline_down_in_depth_bin = None,
pepper_aug = None,
pepper_bin_aug = None,
list_all_possible_background_images=None,
list_all_possible_foreground_rgbs=None,
blur_k = None,
degrade_scales = None,
white_padds = None,
thetha_padd = None,
thetha = None,
brightness = None,
padd_colors = None,
number_of_backgrounds_per_image = None,
shuffle_indexes = None,
pepper_indexes = None,
skewing_amplitudes = None,
dir_rgb_backgrounds = None,
dir_rgb_foregrounds = None,
len_data=None,
):
"""
Args:
images_dir (str): Path to the directory containing images.
labels_dir (str): Path to the directory containing label text files.
tokenizer: Tokenizer for processing labels.
transform: Transformations applied after augmentation (e.g., ToTensor, normalization).
image_size (tuple): Size to resize images to.
max_seq_len (int): Maximum sequence length for tokenized labels.
scales (list or None): List of scale factors to apply.
"""
self.dir_img = dir_img
self.dir_img_bin = dir_img_bin
self.dir_lab = dir_lab
self.processor = processor
self.max_target_length = max_target_length
#self.scales = scales if scales else []
self.augmentation = augmentation
self.binarization = binarization
self.add_red_textlines = add_red_textlines
self.white_noise_strap = white_noise_strap
self.adding_rgb_foreground = adding_rgb_foreground
self.adding_rgb_background = adding_rgb_background
self.bin_deg = bin_deg
self.blur_aug = blur_aug
self.brightening = brightening
self.padding_white = padding_white
self.color_padding_rotation = color_padding_rotation
self.rotation_not_90 = rotation_not_90
self.degrading = degrading
self.channels_shuffling = channels_shuffling
self.textline_skewing = textline_skewing
self.textline_skewing_bin = textline_skewing_bin
self.textline_right_in_depth = textline_right_in_depth
self.textline_left_in_depth = textline_left_in_depth
self.textline_up_in_depth = textline_up_in_depth
self.textline_down_in_depth = textline_down_in_depth
self.textline_right_in_depth_bin = textline_right_in_depth_bin
self.textline_left_in_depth_bin = textline_left_in_depth_bin
self.textline_up_in_depth_bin = textline_up_in_depth_bin
self.textline_down_in_depth_bin = textline_down_in_depth_bin
self.pepper_aug = pepper_aug
self.pepper_bin_aug = pepper_bin_aug
self.list_all_possible_background_images=list_all_possible_background_images
self.list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs
self.blur_k = blur_k
self.degrade_scales = degrade_scales
self.white_padds = white_padds
self.thetha_padd = thetha_padd
self.thetha = thetha
self.brightness = brightness
self.padd_colors = padd_colors
self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
self.shuffle_indexes = shuffle_indexes
self.pepper_indexes = pepper_indexes
self.skewing_amplitudes = skewing_amplitudes
self.dir_rgb_backgrounds = dir_rgb_backgrounds
self.dir_rgb_foregrounds = dir_rgb_foregrounds
self.image_files = os.listdir(dir_img)#sorted([f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
self.len_data = len_data
#assert len(self.image_files) == len(self.label_files), "Number of images and labels must match!"
def __len__(self):
return self.len_data
def __iter__(self):
for img_file in self.image_files:
# Load image
f_name = img_file.split('.')[0]
txt_inp = open(os.path.join(self.dir_lab, f_name+'.txt'),'r').read().split('\n')[0]
img = cv2.imread(os.path.join(self.dir_img, img_file))
img = img.astype(np.uint8)
if self.dir_img_bin:
img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name+'.png') )
img_bin_corr = img_bin_corr.astype(np.uint8)
else:
img_bin_corr = None
labels = self.processor.tokenizer(txt_inp,
padding="max_length",
max_length=self.max_target_length).input_ids
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
if self.augmentation:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.color_padding_rotation:
for index, thetha_ind in enumerate(self.thetha_padd):
for padd_col in self.padd_colors:
img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.rotation_not_90:
for index, thetha_ind in enumerate(self.thetha):
img_out = rotation_not_90_func_single_image(img, thetha_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.blur_aug:
for index, blur_type in enumerate(self.blur_k):
img_out = bluring(img, blur_type)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.degrading:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = do_degrading(img, deg_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.bin_deg:
for index, deg_scale_ind in enumerate(self.degrade_scales):
try:
img_out = self.do_degrading(img_bin_corr, deg_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.brightening:
for index, bright_scale_ind in enumerate(self.brightness):
try:
img_out = do_brightening(dir_img, bright_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.padding_white:
for index, padding_size in enumerate(self.white_padds):
for padd_col in self.padd_colors:
img_out = do_padding_for_ocr(img, padding_size, padd_col)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_foreground:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.adding_rgb_background:
for i_n in range(self.number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(self.list_all_possible_background_images)
img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.binarization:
pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.channels_shuffling:
for shuffle_index in self.shuffle_indexes:
img_out = return_shuffled_channels(img, shuffle_index)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.add_red_textlines:
img_out = return_image_with_red_elements(img, img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.white_noise_strap:
img_out = return_image_with_strapped_white_noises(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img, des_scale_ind)
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_skewing_bin:
for index, des_scale_ind in enumerate(self.skewing_amplitudes):
try:
img_out = do_deskewing(img_bin_corr, des_scale_ind)
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth:
try:
img_out = do_direction_in_depth(img, 'left')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_left_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'left')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth:
try:
img_out = do_direction_in_depth(img, 'right')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_right_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'right')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth:
try:
img_out = do_direction_in_depth(img, 'up')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_up_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'up')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth:
try:
img_out = do_direction_in_depth(img, 'down')
except:
img_out = np.copy(img)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.textline_down_in_depth_bin:
try:
img_out = do_direction_in_depth(img_bin_corr, 'down')
except:
img_out = np.copy(img_bin_corr)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_bin_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
if self.pepper_aug:
for index, pepper_ind in enumerate(self.pepper_indexes):
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
img_out = img_out.astype(np.uint8)
pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding
else:
pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
yield encoding

View file

@ -1,17 +1,17 @@
{
"backbone_type" : "transformer",
"task": "cnn-rnn-ocr",
"task": "transformer-ocr",
"n_classes" : 2,
"max_len": 280,
"n_epochs" : 3,
"max_len": 192,
"n_epochs" : 1,
"input_height" : 32,
"input_width" : 512,
"weight_decay" : 1e-6,
"n_batch" : 4,
"n_batch" : 1,
"learning_rate": 1e-5,
"save_interval": 1500,
"patches" : false,
"pretraining" : true,
"pretraining" : false,
"augmentation" : true,
"flip_aug" : false,
"blur_aug" : true,
@ -77,7 +77,6 @@
"dir_output": "/home/vahid/extracted_lines/1919_bin/output",
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin",
"characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
"dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
}

View file

@ -4,3 +4,8 @@ numpy <1.24.0
tqdm
imutils
scipy
torch
evaluate
accelerate
jiwer
transformers <= 4.30.2