Mirror of https://github.com/qurator-spk/eynollah.git, synced 2026-02-20 16:32:03 +01:00
integrating transformer ocr
parent 586077fbcd
commit 60f0fb541d
4 changed files with 552 additions and 13 deletions
@@ -1,7 +1,8 @@
 import os
 import sys
 import json
+import numpy as np
+import cv2
 import click

 from eynollah.training.metrics import (
@@ -27,7 +28,8 @@ from eynollah.training.utils import (
     generate_data_from_folder_training,
     get_one_hot,
     provide_patches,
-    return_number_of_total_training_data
+    return_number_of_total_training_data,
+    OCRDatasetYieldAugmentations
 )

 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -41,8 +43,13 @@ from sklearn.metrics import f1_score
 from tensorflow.keras.callbacks import Callback
 from tensorflow.keras.layers import StringLookup

-import numpy as np
-import cv2
+import torch
+from transformers import TrOCRProcessor
+import evaluate
+from transformers import default_data_collator
+from transformers import VisionEncoderDecoderModel
+from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

 class SaveWeightsAfterSteps(Callback):
     def __init__(self, save_interval, save_path, _config):
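For orientation: the TrOCRProcessor imported above bundles an image feature extractor with a tokenizer, so a single object prepares both sides of the encoder-decoder input. A hedged sketch of that round trip, not part of the commit — the file name and transcription are hypothetical placeholders:

# Hedged sketch (not in the commit): what TrOCRProcessor produces.
from PIL import Image
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")

image = Image.open("line.png").convert("RGB")  # hypothetical text-line image
pixel_values = processor(image, return_tensors="pt").pixel_values  # (1, 3, 384, 384) for the base checkpoint

label_ids = processor.tokenizer("a sample transcription",
                                padding="max_length",
                                max_length=192).input_ids
# Padding tokens are replaced by -100 so the loss ignores them,
# mirroring what the new dataset class later in this diff does.
label_ids = [t if t != processor.tokenizer.pad_token_id else -100 for t in label_ids]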
@@ -559,6 +566,121 @@ def run(
         with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
             json.dump(_config, fp)  # encode dict into JSON

+    elif task=="transformer-ocr":
+        dir_img, dir_lab = get_dirs_or_files(dir_train)
+
+        processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
+
+        ls_files_images = os.listdir(dir_img)
+
+        aug_multip = return_multiplier_based_on_augmnentations(
+            augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
+            brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
+            image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
+            textline_skewing, textline_skewing_bin,
+            textline_left_in_depth, textline_left_in_depth_bin,
+            textline_right_in_depth, textline_right_in_depth_bin,
+            textline_up_in_depth, textline_up_in_depth_bin,
+            textline_down_in_depth, textline_down_in_depth_bin,
+            pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image,
+            thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes,
+            skewing_amplitudes, blur_k, white_padds)
+
+        len_dataset = aug_multip*len(ls_files_images)
+
+        dataset = OCRDatasetYieldAugmentations(
+            dir_img=dir_img,
+            dir_img_bin=dir_img_bin,
+            dir_lab=dir_lab,
+            processor=processor,
+            max_target_length=max_len,
+            augmentation=augmentation,
+            binarization=binarization,
+            add_red_textlines=add_red_textlines,
+            white_noise_strap=white_noise_strap,
+            adding_rgb_foreground=adding_rgb_foreground,
+            adding_rgb_background=adding_rgb_background,
+            bin_deg=bin_deg,
+            blur_aug=blur_aug,
+            brightening=brightening,
+            padding_white=padding_white,
+            color_padding_rotation=color_padding_rotation,
+            rotation_not_90=rotation_not_90,
+            degrading=degrading,
+            channels_shuffling=channels_shuffling,
+            textline_skewing=textline_skewing,
+            textline_skewing_bin=textline_skewing_bin,
+            textline_right_in_depth=textline_right_in_depth,
+            textline_left_in_depth=textline_left_in_depth,
+            textline_up_in_depth=textline_up_in_depth,
+            textline_down_in_depth=textline_down_in_depth,
+            textline_right_in_depth_bin=textline_right_in_depth_bin,
+            textline_left_in_depth_bin=textline_left_in_depth_bin,
+            textline_up_in_depth_bin=textline_up_in_depth_bin,
+            textline_down_in_depth_bin=textline_down_in_depth_bin,
+            pepper_aug=pepper_aug,
+            pepper_bin_aug=pepper_bin_aug,
+            list_all_possible_background_images=list_all_possible_background_images,
+            list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs,
+            blur_k=blur_k,
+            degrade_scales=degrade_scales,
+            white_padds=white_padds,
+            thetha_padd=thetha_padd,
+            thetha=thetha,
+            brightness=brightness,
+            padd_colors=padd_colors,
+            number_of_backgrounds_per_image=number_of_backgrounds_per_image,
+            shuffle_indexes=shuffle_indexes,
+            pepper_indexes=pepper_indexes,
+            skewing_amplitudes=skewing_amplitudes,
+            dir_rgb_backgrounds=dir_rgb_backgrounds,
+            dir_rgb_foregrounds=dir_rgb_foregrounds,
+            len_data=len_dataset,
+        )
+
+        # Create a DataLoader
+        data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
+        train_dataset = data_loader.dataset
+
+        if continue_training:
+            model = VisionEncoderDecoderModel.from_pretrained(dir_of_start_model)
+        else:
+            model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
+
+        # set special tokens used for creating the decoder_input_ids from the labels
+        model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+        model.config.pad_token_id = processor.tokenizer.pad_token_id
+        # make sure vocab size is set correctly
+        model.config.vocab_size = model.config.decoder.vocab_size
+
+        # set beam search parameters
+        model.config.eos_token_id = processor.tokenizer.sep_token_id
+        model.config.max_length = max_len
+        model.config.early_stopping = True
+        model.config.no_repeat_ngram_size = 3
+        model.config.length_penalty = 2.0
+        model.config.num_beams = 4
+
+        training_args = Seq2SeqTrainingArguments(
+            predict_with_generate=True,
+            num_train_epochs=n_epochs,
+            learning_rate=learning_rate,
+            per_device_train_batch_size=n_batch,
+            fp16=True,
+            output_dir=dir_output,
+            logging_steps=2,
+            save_steps=save_interval,
+        )
+
+        cer_metric = evaluate.load("cer")
+
+        # instantiate trainer
+        trainer = Seq2SeqTrainer(
+            model=model,
+            tokenizer=processor.feature_extractor,
+            args=training_args,
+            train_dataset=train_dataset,
+            data_collator=default_data_collator,
+        )
+        trainer.train()
+
     elif task=='classification':
         configuration()
         model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
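Note that trainer.train() writes checkpoints under output_dir every save_steps steps, and the beam-search fields set on model.config are exactly what model.generate() picks up at inference time; the cer_metric loaded above is not yet passed to the trainer via compute_metrics, so CER is not reported during training. A hedged inference sketch against a trained checkpoint, not part of the commit — the checkpoint path and test image are hypothetical:

# Hedged sketch (not in the commit): decoding with a fine-tuned checkpoint.
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("output/checkpoint-1500")  # hypothetical path

pixel_values = processor(Image.open("line.png").convert("RGB"),  # hypothetical test line
                         return_tensors="pt").pixel_values
with torch.no_grad():
    generated_ids = model.generate(pixel_values)  # beam search per the model.config set above
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])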
@@ -9,12 +9,16 @@ from scipy.ndimage.interpolation import map_coordinates
 from scipy.ndimage.filters import gaussian_filter
 from tqdm import tqdm
 import imutils
-import tensorflow as tf
-from tensorflow.keras.utils import to_categorical
+##import tensorflow as tf
+##from tensorflow.keras.utils import to_categorical
 from PIL import Image, ImageFile, ImageEnhance

+import torch
+from torch.utils.data import IterableDataset
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True


 def vectorize_label(label, char_to_num, padding_token, max_len):
     label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
     length = tf.shape(label)[0]
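The switch to torch's IterableDataset (rather than a map-style Dataset) fits the streaming design of the class added below: each source image yields a variable number of augmented encodings, so there is no natural __getitem__ index, while __len__ can still report the precomputed total. A toy sketch of that contract — all names here are illustrative, not from the commit:

# Hedged sketch (not in the commit): the IterableDataset contract used below.
import torch
from torch.utils.data import IterableDataset, DataLoader

class ToyStream(IterableDataset):
    """Yields several variants per item, like OCRDatasetYieldAugmentations."""
    def __init__(self, items, variants=3):
        self.items = items
        self.variants = variants

    def __len__(self):
        # total samples per epoch = items x variants (the 'aug_multip' idea)
        return len(self.items) * self.variants

    def __iter__(self):
        for x in self.items:
            for v in range(self.variants):
                yield {"value": torch.tensor(x + 0.1 * v)}

loader = DataLoader(ToyStream([1.0, 2.0]), batch_size=2)
for batch in loader:
    print(batch["value"])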
@@ -76,6 +80,7 @@ def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
     return noisy_image


 def invert_image(img):
     img_inv = 255 - img
     return img_inv
@@ -1668,3 +1673,411 @@ def return_multiplier_based_on_augmnentations(
         aug_multip += len(pepper_indexes)

     return aug_multip
+
+
+class OCRDatasetYieldAugmentations(IterableDataset):
+    def __init__(
+        self,
+        dir_img,
+        dir_img_bin,
+        dir_lab,
+        processor,
+        max_target_length=128,
+        augmentation=None,
+        binarization=None,
+        add_red_textlines=None,
+        white_noise_strap=None,
+        adding_rgb_foreground=None,
+        adding_rgb_background=None,
+        bin_deg=None,
+        blur_aug=None,
+        brightening=None,
+        padding_white=None,
+        color_padding_rotation=None,
+        rotation_not_90=None,
+        degrading=None,
+        channels_shuffling=None,
+        textline_skewing=None,
+        textline_skewing_bin=None,
+        textline_right_in_depth=None,
+        textline_left_in_depth=None,
+        textline_up_in_depth=None,
+        textline_down_in_depth=None,
+        textline_right_in_depth_bin=None,
+        textline_left_in_depth_bin=None,
+        textline_up_in_depth_bin=None,
+        textline_down_in_depth_bin=None,
+        pepper_aug=None,
+        pepper_bin_aug=None,
+        list_all_possible_background_images=None,
+        list_all_possible_foreground_rgbs=None,
+        blur_k=None,
+        degrade_scales=None,
+        white_padds=None,
+        thetha_padd=None,
+        thetha=None,
+        brightness=None,
+        padd_colors=None,
+        number_of_backgrounds_per_image=None,
+        shuffle_indexes=None,
+        pepper_indexes=None,
+        skewing_amplitudes=None,
+        dir_rgb_backgrounds=None,
+        dir_rgb_foregrounds=None,
+        len_data=None,
+    ):
+        """
+        Args:
+            dir_img (str): Directory containing text-line images.
+            dir_img_bin (str or None): Directory containing binarized counterparts
+                (one .png per image), if available.
+            dir_lab (str): Directory containing label text files (one .txt per image).
+            processor: TrOCRProcessor used both to preprocess images and to tokenize labels.
+            max_target_length (int): Maximum sequence length for tokenized labels.
+            len_data (int): Total number of samples yielded per epoch
+                (number of images times the augmentation multiplier).
+            The remaining arguments are per-augmentation switches and their parameter lists.
+        """
+        self.dir_img = dir_img
+        self.dir_img_bin = dir_img_bin
+        self.dir_lab = dir_lab
+        self.processor = processor
+        self.max_target_length = max_target_length
+
+        self.augmentation = augmentation
+        self.binarization = binarization
+        self.add_red_textlines = add_red_textlines
+        self.white_noise_strap = white_noise_strap
+        self.adding_rgb_foreground = adding_rgb_foreground
+        self.adding_rgb_background = adding_rgb_background
+        self.bin_deg = bin_deg
+        self.blur_aug = blur_aug
+        self.brightening = brightening
+        self.padding_white = padding_white
+        self.color_padding_rotation = color_padding_rotation
+        self.rotation_not_90 = rotation_not_90
+        self.degrading = degrading
+        self.channels_shuffling = channels_shuffling
+        self.textline_skewing = textline_skewing
+        self.textline_skewing_bin = textline_skewing_bin
+        self.textline_right_in_depth = textline_right_in_depth
+        self.textline_left_in_depth = textline_left_in_depth
+        self.textline_up_in_depth = textline_up_in_depth
+        self.textline_down_in_depth = textline_down_in_depth
+        self.textline_right_in_depth_bin = textline_right_in_depth_bin
+        self.textline_left_in_depth_bin = textline_left_in_depth_bin
+        self.textline_up_in_depth_bin = textline_up_in_depth_bin
+        self.textline_down_in_depth_bin = textline_down_in_depth_bin
+        self.pepper_aug = pepper_aug
+        self.pepper_bin_aug = pepper_bin_aug
+        self.list_all_possible_background_images = list_all_possible_background_images
+        self.list_all_possible_foreground_rgbs = list_all_possible_foreground_rgbs
+        self.blur_k = blur_k
+        self.degrade_scales = degrade_scales
+        self.white_padds = white_padds
+        self.thetha_padd = thetha_padd
+        self.thetha = thetha
+        self.brightness = brightness
+        self.padd_colors = padd_colors
+        self.number_of_backgrounds_per_image = number_of_backgrounds_per_image
+        self.shuffle_indexes = shuffle_indexes
+        self.pepper_indexes = pepper_indexes
+        self.skewing_amplitudes = skewing_amplitudes
+        self.dir_rgb_backgrounds = dir_rgb_backgrounds
+        self.dir_rgb_foregrounds = dir_rgb_foregrounds
+        self.image_files = os.listdir(dir_img)
+        self.len_data = len_data
+
+    def __len__(self):
+        return self.len_data
+
+    def __iter__(self):
+        for img_file in self.image_files:
+            f_name = img_file.split('.')[0]
+
+            # ground-truth transcription: first line of the matching .txt file
+            txt_inp = open(os.path.join(self.dir_lab, f_name + '.txt'), 'r').read().split('\n')[0]
+
+            # load image
+            img = cv2.imread(os.path.join(self.dir_img, img_file))
+            img = img.astype(np.uint8)
+
+            if self.dir_img_bin:
+                img_bin_corr = cv2.imread(os.path.join(self.dir_img_bin, f_name + '.png'))
+                img_bin_corr = img_bin_corr.astype(np.uint8)
+            else:
+                img_bin_corr = None
+
+            labels = self.processor.tokenizer(txt_inp,
+                                              padding="max_length",
+                                              max_length=self.max_target_length).input_ids
+            # mask padding tokens so the loss ignores them
+            labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
+
+            if self.augmentation:
+                pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
+                encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                yield encoding
+
+                if self.color_padding_rotation:
+                    for index, thetha_ind in enumerate(self.thetha_padd):
+                        for padd_col in self.padd_colors:
+                            img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
+                            img_out = img_out.astype(np.uint8)
+                            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                            yield encoding
+
+                if self.rotation_not_90:
+                    for index, thetha_ind in enumerate(self.thetha):
+                        img_out = rotation_not_90_func_single_image(img, thetha_ind)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.blur_aug:
+                    for index, blur_type in enumerate(self.blur_k):
+                        img_out = bluring(img, blur_type)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.degrading:
+                    for index, deg_scale_ind in enumerate(self.degrade_scales):
+                        try:
+                            img_out = do_degrading(img, deg_scale_ind)
+                        except:
+                            img_out = np.copy(img)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.bin_deg:
+                    for index, deg_scale_ind in enumerate(self.degrade_scales):
+                        try:
+                            img_out = do_degrading(img_bin_corr, deg_scale_ind)
+                        except:
+                            img_out = np.copy(img_bin_corr)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.brightening:
+                    for index, bright_scale_ind in enumerate(self.brightness):
+                        try:
+                            img_out = do_brightening(os.path.join(self.dir_img, img_file), bright_scale_ind)
+                        except:
+                            img_out = np.copy(img)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.padding_white:
+                    for index, padding_size in enumerate(self.white_padds):
+                        for padd_col in self.padd_colors:
+                            img_out = do_padding_for_ocr(img, padding_size, padd_col)
+                            img_out = img_out.astype(np.uint8)
+                            pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                            encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                            yield encoding
+
+                if self.adding_rgb_foreground:
+                    for i_n in range(self.number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(self.list_all_possible_background_images)
+                        foreground_rgb_chosen_name = random.choice(self.list_all_possible_foreground_rgbs)
+
+                        img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        foreground_rgb_chosen = np.load(self.dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
+
+                        img_out = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.adding_rgb_background:
+                    for i_n in range(self.number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(self.list_all_possible_background_images)
+                        img_rgb_background_chosen = cv2.imread(self.dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        img_out = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.binarization:
+                    pixel_values = self.processor(Image.fromarray(img_bin_corr), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.channels_shuffling:
+                    for shuffle_index in self.shuffle_indexes:
+                        img_out = return_shuffled_channels(img, shuffle_index)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.add_red_textlines:
+                    img_out = return_image_with_red_elements(img, img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.white_noise_strap:
+                    img_out = return_image_with_strapped_white_noises(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_skewing:
+                    for index, des_scale_ind in enumerate(self.skewing_amplitudes):
+                        try:
+                            img_out = do_deskewing(img, des_scale_ind)
+                        except:
+                            img_out = np.copy(img)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.textline_skewing_bin:
+                    for index, des_scale_ind in enumerate(self.skewing_amplitudes):
+                        try:
+                            img_out = do_deskewing(img_bin_corr, des_scale_ind)
+                        except:
+                            img_out = np.copy(img_bin_corr)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.textline_left_in_depth:
+                    try:
+                        img_out = do_direction_in_depth(img, 'left')
+                    except:
+                        img_out = np.copy(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_left_in_depth_bin:
+                    try:
+                        img_out = do_direction_in_depth(img_bin_corr, 'left')
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_right_in_depth:
+                    try:
+                        img_out = do_direction_in_depth(img, 'right')
+                    except:
+                        img_out = np.copy(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_right_in_depth_bin:
+                    try:
+                        img_out = do_direction_in_depth(img_bin_corr, 'right')
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_up_in_depth:
+                    try:
+                        img_out = do_direction_in_depth(img, 'up')
+                    except:
+                        img_out = np.copy(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_up_in_depth_bin:
+                    try:
+                        img_out = do_direction_in_depth(img_bin_corr, 'up')
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_down_in_depth:
+                    try:
+                        img_out = do_direction_in_depth(img, 'down')
+                    except:
+                        img_out = np.copy(img)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.textline_down_in_depth_bin:
+                    try:
+                        img_out = do_direction_in_depth(img_bin_corr, 'down')
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = img_out.astype(np.uint8)
+                    pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                    encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                    yield encoding
+
+                if self.pepper_bin_aug:
+                    for index, pepper_ind in enumerate(self.pepper_indexes):
+                        img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+                if self.pepper_aug:
+                    for index, pepper_ind in enumerate(self.pepper_indexes):
+                        img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
+                        img_out = img_out.astype(np.uint8)
+                        pixel_values = self.processor(Image.fromarray(img_out), return_tensors="pt").pixel_values
+                        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                        yield encoding
+
+            else:
+                pixel_values = self.processor(Image.fromarray(img), return_tensors="pt").pixel_values
+                encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
+                yield encoding
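A hedged usage sketch for the new class, with augmentation off so each image yields exactly one encoding — the directory paths and counts are hypothetical:

# Hedged sketch (not in the commit): minimal use of the dataset class.
from torch.utils.data import DataLoader
from transformers import TrOCRProcessor
from eynollah.training.utils import OCRDatasetYieldAugmentations

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
dataset = OCRDatasetYieldAugmentations(
    dir_img="lines/images",   # hypothetical directory of text-line images
    dir_img_bin=None,
    dir_lab="lines/labels",   # one .txt transcription per image
    processor=processor,
    max_target_length=192,
    len_data=100,             # equals the image count when no augmentation is active
)
loader = DataLoader(dataset, batch_size=1)
batch = next(iter(loader))
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 384, 384]) for the base checkpoint
print(batch["labels"].shape)        # torch.Size([1, 192])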
@@ -1,17 +1,17 @@
 {
     "backbone_type" : "transformer",
-    "task": "cnn-rnn-ocr",
+    "task": "transformer-ocr",
     "n_classes" : 2,
-    "max_len": 280,
-    "n_epochs" : 3,
+    "max_len": 192,
+    "n_epochs" : 1,
     "input_height" : 32,
     "input_width" : 512,
     "weight_decay" : 1e-6,
-    "n_batch" : 4,
+    "n_batch" : 1,
    "learning_rate": 1e-5,
     "save_interval": 1500,
     "patches" : false,
-    "pretraining" : true,
+    "pretraining" : false,
     "augmentation" : true,
     "flip_aug" : false,
     "blur_aug" : true,
@@ -77,7 +77,6 @@
     "dir_output": "/home/vahid/extracted_lines/1919_bin/output",
     "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
     "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
-    "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin",
-    "characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
+    "dir_img_bin": "/home/vahid/extracted_lines/1919_bin/images_bin"
 }
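The keys changed above feed directly into the Seq2SeqTrainingArguments constructed in the transformer-ocr branch of the training script; a hedged sketch of that mapping, with a hypothetical config file name:

# Hedged sketch (not in the commit): how the config keys map onto the trainer arguments.
import json
from transformers import Seq2SeqTrainingArguments

with open("config_params.json") as fp:   # hypothetical file name
    config = json.load(fp)

args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    num_train_epochs=config["n_epochs"],            # 1
    learning_rate=config["learning_rate"],          # 1e-5
    per_device_train_batch_size=config["n_batch"],  # 1
    fp16=True,                                      # as in the commit; requires a GPU
    output_dir=config["dir_output"],
    logging_steps=2,
    save_steps=config["save_interval"],             # 1500
)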
@@ -4,3 +4,8 @@ numpy <1.24.0
 tqdm
 imutils
 scipy
+torch
+evaluate
+accelerate
+jiwer
+transformers <= 4.30.2