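CNN-RNN OCR model can now be built and called from the training script - model input height and width are now configurable parameters instead of fixed values - the OCR data generator (data_gen_ocr) is now callable as well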

This commit is contained in:
vahidrezanezhad 2025-12-09 15:30:19 +01:00
parent 59e5a73654
commit 84a72a128b
4 changed files with 283 additions and 239 deletions

View file

@ -13,6 +13,25 @@ resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_
IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1
class CTCLayer(tf.keras.layers.Layer):
    """Keras layer that attaches a CTC loss to the model and passes predictions through.

    The layer is called with the ground-truth labels and the softmax output of
    the recognition head. It computes the batch CTC cost, registers it with
    ``add_loss`` and returns ``y_pred`` unchanged, so the model can be trained
    without passing a separate loss to ``compile``.
    """

    def __init__(self, name=None, **kwargs):
        """Create the layer.

        Args:
            name: Optional layer name.
            **kwargs: Extra base-``Layer`` arguments (e.g. ``trainable``,
                ``dtype``). Forwarding them lets Keras rebuild the layer from
                its saved config, which contains these keys in addition to
                ``name``.
        """
        super().__init__(name=name, **kwargs)
        self.loss_fn = tf.keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        """Register the CTC loss for this batch and return ``y_pred``.

        Args:
            y_true: Label tensor; axis 0 is the batch and axis 1 the (padded)
                label length.
            y_pred: Prediction tensor; axis 1 is the number of time steps.

        Returns:
            ``y_pred``, unchanged.
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        # ctc_batch_cost expects one length per sample with shape (batch, 1);
        # every sample is given the full (padded) length here.
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions.
        return y_pred
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=tf.nn.gelu)(x)
@ -759,85 +778,84 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
return model
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None):
    """Build the CNN-RNN text-line recognition model trained with a CTC loss.

    A stack of Conv2D/BatchNorm/ReLU blocks extracts features, which are
    flattened into a sequence and fed to bidirectional LSTMs at three
    horizontal resolutions (full, 1/2 and 1/4). The two coarser branches are
    upsampled back, summed with the full-resolution branch and passed through
    one more bidirectional LSTM and the classification head.

    Args:
        image_height: Height of the input image in pixels.
        image_width: Width of the input image in pixels. It is also used as
            the number of filters of the last two convolutions and as the
            number of LSTM units.
        n_classes: Size of the softmax output (number of character classes).
        max_seq: Number of output positions produced by the head.

    Returns:
        A ``tf.keras`` model with inputs ``image`` and ``label`` whose output
        is the softmax prediction passed through ``CTCLayer``.

    Raises:
        ValueError: If any argument is left at ``None``. The ``None`` defaults
            exist only so the function can be called with keywords; the layer
            sizes and reshape targets below need concrete integers.
    """
    missing = [
        arg_name
        for arg_name, value in (
            ("image_height", image_height),
            ("image_width", image_width),
            ("n_classes", n_classes),
            ("max_seq", max_seq),
        )
        if value is None
    ]
    if missing:
        raise ValueError(
            "cnn_rnn_ocr_model requires concrete values for: " + ", ".join(missing)
        )

    kl = tf.keras.layers

    input_img = tf.keras.Input(shape=(image_height, image_width, 3), name="image")
    labels = kl.Input(name="label", shape=(None,))

    # Block 1: two 64-filter convolutions, then halve the width only.
    x = kl.Conv2D(64, kernel_size=(3, 3), padding="same")(input_img)
    x = kl.BatchNormalization(name="bn1")(x)
    x = kl.Activation("relu", name="relu1")(x)
    x = kl.Conv2D(64, kernel_size=(3, 3), padding="same")(x)
    x = kl.BatchNormalization(name="bn2")(x)
    x = kl.Activation("relu", name="relu2")(x)
    x = kl.MaxPool2D(pool_size=(1, 2), strides=(1, 2))(x)

    # Block 2: two 128-filter convolutions, then halve the width only.
    x = kl.Conv2D(128, kernel_size=(3, 3), padding="same")(x)
    x = kl.BatchNormalization(name="bn3")(x)
    x = kl.Activation("relu", name="relu3")(x)
    x = kl.Conv2D(128, kernel_size=(3, 3), padding="same")(x)
    x = kl.BatchNormalization(name="bn4")(x)
    x = kl.Activation("relu", name="relu4")(x)
    x = kl.MaxPool2D(pool_size=(1, 2), strides=(1, 2))(x)

    # Block 3: two 256-filter convolutions, then halve height and width.
    x = kl.Conv2D(256, kernel_size=(3, 3), padding="same")(x)
    x = kl.BatchNormalization(name="bn5")(x)
    x = kl.Activation("relu", name="relu5")(x)
    x = kl.Conv2D(256, kernel_size=(3, 3), padding="same")(x)
    x = kl.BatchNormalization(name="bn6")(x)
    x = kl.Activation("relu", name="relu6")(x)
    x = kl.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)

    # Block 4: the filter count follows image_width. The (16, 1) convolution
    # uses the default "valid" padding, so it shrinks the height by 15.
    x = kl.Conv2D(image_width, kernel_size=(3, 3), padding="same")(x)
    x = kl.BatchNormalization(name="bn7")(x)
    x = kl.Activation("relu", name="relu7")(x)
    x = kl.Conv2D(image_width, kernel_size=(16, 1))(x)
    x = kl.BatchNormalization(name="bn8")(x)
    x = kl.Activation("relu", name="relu8")(x)

    # Two coarser copies of the feature map: 1/2 and 1/4 of the width.
    x2d = kl.MaxPool2D(pool_size=(1, 2), strides=(1, 2))(x)
    x4d = kl.MaxPool2D(pool_size=(1, 2), strides=(1, 2))(x2d)

    # Fold the remaining height into the sequence axis: (H * W, channels).
    new_shape = (x.shape[1] * x.shape[2], x.shape[3])
    new_shape2 = (x2d.shape[1] * x2d.shape[2], x2d.shape[3])
    new_shape4 = (x4d.shape[1] * x4d.shape[2], x4d.shape[3])

    x = kl.Reshape(target_shape=new_shape, name="reshape")(x)
    x2d = kl.Reshape(target_shape=new_shape2, name="reshape2")(x2d)
    x4d = kl.Reshape(target_shape=new_shape4, name="reshape4")(x4d)

    xrnnorg = kl.Bidirectional(kl.LSTM(image_width, return_sequences=True, dropout=0.25))(x)
    xrnn2d = kl.Bidirectional(kl.LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
    xrnn4d = kl.Bidirectional(kl.LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)

    # UpSampling2D needs a 4-D tensor, so add a dummy height axis of 1.
    xrnn2d = kl.Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
    xrnn4d = kl.Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)

    xrnn2dup = kl.UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
    xrnn4dup = kl.UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)

    # Drop the dummy axis again so the branches can be summed.
    xrnn2dup = kl.Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
    xrnn4dup = kl.Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)

    # NOTE(review): Add() needs all three branches to have the same sequence
    # length, which presumably constrains the usable image sizes - confirm
    # against the values used in the training config.
    addition = kl.Add()([xrnnorg, xrnn2dup, xrnn4dup])

    addition_rnn = kl.Bidirectional(kl.LSTM(image_width, return_sequences=True, dropout=0.25))(addition)

    # With data_format="channels_first" the sequence axis is treated as the
    # channel axis, so this 1x1 convolution maps it to max_seq positions.
    out = kl.Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
    out = kl.BatchNormalization(name="bn9")(out)
    out = kl.Activation("relu", name="relu9")(out)
    #out = tf.keras.layers.Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)

    out = kl.Dense(
        n_classes, activation="softmax", name="dense2"
    )(out)

    # Add CTC layer for calculating CTC loss at each step.
    output = CTCLayer(name="ctc_loss")(labels, out)

    model = tf.keras.models.Model(inputs=[input_img, labels], outputs=output, name="handwriting_recognizer")

    return model

View file

@ -15,10 +15,13 @@ from eynollah.training.models import (
resnet50_classifier,
resnet50_unet,
vit_resnet50_unet,
vit_resnet50_unet_transformer_before_cnn
vit_resnet50_unet_transformer_before_cnn,
cnn_rnn_ocr_model
)
from eynollah.training.utils import (
data_gen,
data_gen_ocr,
return_multiplier_based_on_augmnentations,
generate_arrays_from_folder_reading_order,
generate_data_from_folder_evaluation,
generate_data_from_folder_training,
@ -36,6 +39,7 @@ from tensorflow.keras.models import load_model
from tqdm import tqdm
from sklearn.metrics import f1_score
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import StringLookup
import numpy as np
import cv2
@ -62,6 +66,7 @@ class SaveWeightsAfterSteps(Callback):
print(f"saved model as steps {self.step_count} to {save_file}")
def configuration():
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
@ -89,6 +94,7 @@ def config_params():
input_width = 224 * 1 # Width of model's input in pixels.
weight_decay = 1e-6 # Weight decay of l2 regularization of model layers.
n_batch = 1 # Number of batches at each iteration.
max_len = None # max len for ocr output.
learning_rate = 1e-4 # Set the learning rate.
patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
augmentation = False # To apply any kind of augmentation, this parameter must be set to true.
@ -132,7 +138,9 @@ def config_params():
scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
thetha = None # Rotate image by these angles for augmentation.
shuffle_indexes = None
shuffle_indexes = None # List of shuffling indexes like [[0,2,1], [1,2,0], [1,0,2]]
pepper_indexes = None # List of pepper noise indexes like [0.01, 0.005]
skewing_amplitudes = None # List of skewing augmentation amplitudes like [5, 8]
blur_k = None # Blur image for augmentation.
scales = None # Scale patches for augmentation.
padd_colors = None # padding colors. A list elements can be only white and black. like ["white", "black"] or only one of them ["white"]
@ -180,7 +188,7 @@ def run(_config, n_classes, n_epochs, input_height,
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds,
dir_rgb_foregrounds, characters_txt_file, color_padding_rotation, bin_deg, image_inversion, white_noise_strap, textline_skewing, textline_skewing_bin,
textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin,
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors):
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors, pepper_indexes, skewing_amplitudes, max_len):
if dir_rgb_backgrounds:
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@ -409,20 +417,35 @@ def run(_config, n_classes, n_epochs, input_height,
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
# Mapping integers back to original characters.
num_to_char = StringLookup(
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
##num_to_char = StringLookup(
##vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
##)
padding_token = len(characters) + 5
ls_files_images = os.listdir(dir_img)
train_ds = data_gen_ocr(padding_token, batchsize=n_batch, height=input_height, width=input_width, max_len=max_len, dir_ins=dir_train, ls_files_images,
augmentation, color_padding_rotation, rotation=rotation_not_90, bluring_aug=blurring, degrading, bin_deg, brightening, w_padding=padding_white,
rgb_fging=adding_rgb_foreground, rgb_bkding=adding_rgb_background, binarization, image_inversion, channel_shuffling=channels_shuffling, add_red_textline=add_red_textlines, white_noise_strap,
n_classes = len(char_to_num.get_vocabulary()) + 2
model = cnn_rnn_ocr_model(image_height=input_height, image_width=input_width, n_classes=n_classes, max_seq=max_len)
print(model.summary())
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes)
len_dataset = aug_multip*len(ls_files_images)
train_ds = data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir_train, ls_files_images,
augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg, brightening, padding_white,
adding_rgb_foreground, adding_rgb_background, binarization, image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth,
textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin,
pepper_bin_aug, pepper_aug, deg_scales=degrade_scales, number_of_backgrounds_per_image, thethas=thetha, brightness, padd_colors,
shuffle_indexes, pepper_indexes, skewing_amplitudes)
pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors,
shuffle_indexes, pepper_indexes, skewing_amplitudes, dir_img_bin)
print(len_dataset, 'len_dataset')
elif task=='classification':
configuration()

View file

@ -414,7 +414,7 @@ def generate_data_from_folder_evaluation(path_classes, height, width, n_classes,
return ret_x/255., ret_y
def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes, list_classes):
def generate_data_from_folder_training(path_classes, n_batch, height, width, n_classes, list_classes):
#sub_classes = os.listdir(path_classes)
#n_classes = len(sub_classes)
@ -440,8 +440,8 @@ def generate_data_from_folder_training(path_classes, batchsize, height, width, n
shuffled_labels = np.array(labels)[ids]
shuffled_files = np.array(all_imgs)[ids]
categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ]
ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16)
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
ret_x= np.zeros((n_batch, height,width, 3)).astype(np.int16)
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
batchcount = 0
while True:
for i in range(len(shuffled_files)):
@ -465,11 +465,11 @@ def generate_data_from_folder_training(path_classes, batchsize, height, width, n
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield ret_x, ret_y
ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16)
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
ret_x= np.zeros((n_batch, height,width, 3)).astype(np.int16)
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
batchcount = 0
def do_brightening(img_in_dir, factor):
@ -634,10 +634,10 @@ def IoU(Yi, y_predi):
#print("Mean IoU: {:4.3f}".format(mIoU))
return mIoU
def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batchsize, height, width, n_classes, thetha, augmentation=False):
def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, n_batch, height, width, n_classes, thetha, augmentation=False):
all_labels_files = os.listdir(classes_file_dir)
ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16)
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
ret_x= np.zeros((n_batch, height, width, 3))#.astype(np.int16)
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
batchcount = 0
while True:
for i in all_labels_files:
@ -652,10 +652,10 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batch
ret_y[batchcount, :] = label_class
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
yield ret_x, ret_y
ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16)
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
ret_x= np.zeros((n_batch, height, width, 3))#.astype(np.int16)
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
batchcount = 0
if augmentation:
@ -670,10 +670,10 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batch
ret_y[batchcount, :] = label_class
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
yield ret_x, ret_y
ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16)
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
ret_x= np.zeros((n_batch, height, width, 3))#.astype(np.int16)
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
batchcount = 0
def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
@ -1264,42 +1264,45 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow
def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len=None, dir_ins=None, ls_files_images=None,
augmentation=False, color_padding_rotation=False, rotation=False, bluring_aug=False, degrading=False, bin_deg=False, brightening=False, w_padding=False,
rgb_fging=False, rgb_bkding=False, binarization=False, image_inversion=False, channel_shuffling=False, add_red_textline=False, white_noise_strap=False,
textline_skewing=False, textline_skewing_bin=False, textline_left_in_depth=False, textline_left_in_depth_bin=False, textline_right_in_depth=False,
textline_right_in_depth_bin=False, textline_up_in_depth=False, textline_up_in_depth_bin=False, textline_down_in_depth=False, textline_down_in_depth_bin=False,
pepper_bin_aug=False, pepper_aug=False, deg_scales=None, number_of_backgrounds_per_image=None, thethas=None, brightness=None, padd_colors=None,
shuffle_indexes=None, ):
def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir_train, ls_files_images,
augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg, brightening, padding_white,
adding_rgb_foreground, adding_rgb_background, binarization, image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth,
textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin,
pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors,
shuffle_indexes, pepper_indexes, skewing_amplitudes, dir_img_bin=None):
random.shuffle(ls_files_images)
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
while True:
for i in ls_files_images:
f_name = i.split('.')[0]
txt_inp = open(os.path.join(dir_ins, "labels/"+f_name+'.txt'),'r').read().split('\n')[0]
txt_inp = open(os.path.join(dir_train, "labels/"+f_name+'.txt'),'r').read().split('\n')[0]
img = cv2.imread(os.path.join(dir_ins, "images/"+i) )
img_bin_corr = cv2.imread(os.path.join(dir_ins, "images_bin/"+f_name+'.png') )
img = cv2.imread(os.path.join(dir_train, "images/"+i) )
if dir_img_bin:
img_bin_corr = cv2.imread(os.path.join(dir_img_bin, f_name+'.png') )
else:
img_bin_corr = None
if augmentation:
img_out = scale_padd_image_for_ocr(img, height, width)
img_out = scale_padd_image_for_ocr(img, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if color_padding_rotation:
@ -1307,59 +1310,59 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
for padd_col in padd_colors:
img_out = rotation_not_90_func(do_padding(img, 1.2, padd_col), thetha)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if rotation:
for index, thetha in enumerate(thethas):
if rotation_not_90:
for index, thetha in enumerate(thetha):
img_out = rotation_not_90_func(img, thetha)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if bluring_aug:
if blur_aug:
for index, blur_type in enumerate(blurs):
img_out = bluring(img, blur_type)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if degrading:
for index, deg_scale_ind in enumerate(deg_scales):
for index, deg_scale_ind in enumerate(degrade_scales):
try:
img_out = do_degrading(img, deg_scale_ind)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img)
@ -1368,32 +1371,32 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if bin_deg:
for index, deg_scale_ind in enumerate(deg_scales):
for index, deg_scale_ind in enumerate(degrade_scales):
try:
img_out = do_degrading(img_bin_corr, deg_scale_ind)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img_bin_corr)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
@ -1403,39 +1406,39 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
img_out = do_brightening(dir_img, bright_scale_ind)
except:
img_out = np.copy(img)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if w_padding:
if padding_white:
for index, padding_size in enumerate(white_padds):
for padd_col in padd_colors:
img_out = do_padding(img, padding_size, padd_col)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if rgb_fging:
if adding_rgb_foreground:
for i_n in range(number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(list_all_possible_background_images)
foreground_rgb_chosen_name = random.choice(list_all_possible_foreground_rgbs)
@ -1445,130 +1448,130 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
img_with_overlayed_background = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, height, width)
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if rgb_bkding:
if adding_rgb_background:
for i_n in range(number_of_backgrounds_per_image):
background_image_chosen_name = random.choice(list_all_possible_background_images)
img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name)
img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, height, width)
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if binarization:
img_out = scale_padd_image_for_ocr(img_bin_corr, height, width)
img_out = scale_padd_image_for_ocr(img_bin_corr, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if image_inversion:
img_out = invert_image(img_bin_corr)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :, :, :] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_x = np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y = np.zeros((batch_size, max_len)).astype(np.int16)+padding_token
batchcount = 0
if channel_shuffling:
if channels_shuffling:
for shuffle_index in shuffle_indexes:
img_out = return_shuffled_channels(img, shuffle_index)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if add_red_textline:
if add_red_textlines:
img_red_context = return_image_with_red_elements(img, img_bin_corr)
img_out = scale_padd_image_for_ocr(img_red_context, height, width)
img_out = scale_padd_image_for_ocr(img_red_context, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if white_noise_strap:
img_out = return_image_with_strapped_white_noises(img)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_skewing:
for index, des_scale_ind in enumerate(skewing_amplitudes):
try:
img_out = do_deskewing(img, des_scale_ind)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img)
@ -1577,18 +1580,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_skewing_bin:
for index, des_scale_ind in enumerate(skewing_amplitudes):
try:
img_out = do_deskewing(img_bin_corr, des_scale_ind)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img_bin_corr)
@ -1597,18 +1600,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_left_in_depth:
try:
img_out = do_left_in_depth(img)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img)
@ -1617,18 +1620,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_left_in_depth_bin:
try:
img_out = do_left_in_depth(img_bin_corr)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img_bin_corr)
@ -1637,18 +1640,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_right_in_depth:
try:
img_out = do_right_in_depth(img)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img)
@ -1657,18 +1660,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_right_in_depth_bin:
try:
img_out = do_right_in_depth(img_bin_corr)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img_bin_corr)
@ -1677,18 +1680,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_up_in_depth:
try:
img_out = do_up_in_depth(img)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img)
@ -1697,18 +1700,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_up_in_depth_bin:
try:
img_out = do_up_in_depth(img_bin_corr)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img_bin_corr)
@ -1717,18 +1720,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_down_in_depth:
try:
img_out = do_down_in_depth(img)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img)
@ -1737,18 +1740,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if textline_down_in_depth_bin:
try:
img_out = do_down_in_depth(img_bin_corr)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
except:
img_out = np.copy(img_bin_corr)
@ -1757,70 +1760,70 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if pepper_bin_aug:
for index, pepper_ind in enumerate(pepper_indexes):
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
if pepper_aug:
for index, pepper_ind in enumerate(pepper_indexes):
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
img_out = scale_padd_image_for_ocr(img_out, height, width)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
else:
img_out = scale_padd_image_for_ocr(img, height, width)
img_out = scale_padd_image_for_ocr(img, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
batchcount+=1
if batchcount>=batchsize:
if batchcount>=n_batch:
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
def return_muliplier_based_on_augmnentations(augmentation=False, color_padding_rotation=False, rotation=False, bluring_aug=False,
degrading=False, bin_deg=False, brightening=False, w_padding=False,rgb_fging=False, rgb_bkding=False, binarization=False, image_inversion=False, channel_shuffling=False, add_red_textline=False, white_noise_strap=False,
textline_skewing=False, textline_skewing_bin=False, textline_left_in_depth=False, textline_left_in_depth_bin=False, textline_right_in_depth=False, textline_right_in_depth_bin=False, textline_up_in_depth=False, textline_up_in_depth_bin=False, textline_down_in_depth=False, textline_down_in_depth_bin=False, pepper_bin_aug=False, pepper_aug=False, deg_scales=None, number_of_backgrounds_per_image=None, thethas=None, brightness=None, padd_colors=None):
def return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug,
degrading, bin_deg, brightening, padding_white,adding_rgb_foreground, adding_rgb_background, binarization, image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes):
aug_multip = 1
if augmentation:
@ -1828,7 +1831,7 @@ def return_muliplier_based_on_augmnentations(augmentation=False, color_padding_r
aug_multip = aug_multip + 1
if image_inversion:
aug_multip = aug_multip + 1
if add_red_textline:
if add_red_textlines:
aug_multip = aug_multip + 1
if white_noise_strap:
aug_multip = aug_multip + 1
@ -1848,29 +1851,29 @@ def return_muliplier_based_on_augmnentations(augmentation=False, color_padding_r
aug_multip = aug_multip + 1
if textline_down_in_depth_bin:
aug_multip = aug_multip + 1
if rgb_fging:
if adding_rgb_foreground:
aug_multip = aug_multip + number_of_backgrounds_per_image
if rgb_bkding:
if adding_rgb_background:
aug_multip = aug_multip + number_of_backgrounds_per_image
if bin_deg:
aug_multip = aug_multip + len(deg_scales)
aug_multip = aug_multip + len(degrade_scales)
if degrading:
aug_multip = aug_multip + len(deg_scales)
if rotation:
aug_multip = aug_multip + len(thethas)
aug_multip = aug_multip + len(degrade_scales)
if rotation_not_90:
aug_multip = aug_multip + len(thetha)
if textline_skewing:
aug_multip = aug_multip + len(skewing_amplitudes)
if textline_skewing_bin:
aug_multip = aug_multip + len(skewing_amplitudes)
if color_padding_rotation:
aug_multip = aug_multip + len(thetha_padd)*len(padd_colors)
if channel_shuffling:
if channels_shuffling:
aug_multip = aug_multip + len(shuffle_indexes)
if bluring_aug:
if blur_aug:
aug_multip = aug_multip + len(blurs)
if brightening:
aug_multip = aug_multip + len(brightness)
if w_padding:
if padding_white:
aug_multip = aug_multip + len(white_padds)*len(padd_colors)
if pepper_aug:
aug_multip = aug_multip + len(pepper_indexes)

View file

@ -1,17 +1,17 @@
{
"backbone_type" : "transformer",
"task": "segmentation",
"task": "cnn-rnn-ocr",
"n_classes" : 2,
"max_len": 280,
"n_epochs" : 0,
"input_height" : 448,
"input_width" : 448,
"input_height" : 32,
"input_width" : 512,
"weight_decay" : 1e-6,
"n_batch" : 1,
"learning_rate": 1e-4,
"patches" : false,
"pretraining" : true,
"augmentation" : true,
"augmentation" : false,
"flip_aug" : false,
"blur_aug" : false,
"scaling" : false,
@ -49,12 +49,12 @@
"weighted_loss": false,
"is_loss_soft_dice": false,
"data_is_provided": false,
"dir_train": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new",
"dir_train": "/home/vahid/extracted_lines/1919_bin/train",
"dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new",
"dir_output": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/output_new",
"dir_output": "/home/vahid/extracted_lines/1919_bin/output",
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
"dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin",
"characters_txt_file":"dir_of_characters_txt_file_for_ocr"
"characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
}