diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py
index fdc5437..ff714b4 100644
--- a/src/eynollah/training/models.py
+++ b/src/eynollah/training/models.py
@@ -758,3 +758,86 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
     model = Model(img_input , o)
     return model
+
+def cnn_rnn_ocr_model(image_height, image_width, n_classes, max_seq):
+    """CNN-BiLSTM OCR model trained with CTC loss.
+
+    A VGG-style convolutional stack collapses the image height to 1, BiLSTMs
+    process the resulting sequence at three horizontal scales, the scales are
+    merged, and a softmax over n_classes is emitted for a sequence of length
+    max_seq.
+    """
+    input_img = tensorflow.keras.Input(shape=(image_height, image_width, 3), name="image")
+    labels = tensorflow.keras.layers.Input(name="label", shape=(None,))
+
+    # Convolutional feature extractor; the (1,2) poolings halve the width only.
+    x = tensorflow.keras.layers.Conv2D(64, kernel_size=(3,3), padding="same")(input_img)
+    x = tensorflow.keras.layers.BatchNormalization(name="bn1")(x)
+    x = tensorflow.keras.layers.Activation("relu", name="relu1")(x)
+    x = tensorflow.keras.layers.Conv2D(64, kernel_size=(3,3), padding="same")(x)
+    x = tensorflow.keras.layers.BatchNormalization(name="bn2")(x)
+    x = tensorflow.keras.layers.Activation("relu", name="relu2")(x)
+    x = tensorflow.keras.layers.MaxPool2D(pool_size=(1,2), strides=(1,2))(x)
+
+    x = tensorflow.keras.layers.Conv2D(128, kernel_size=(3,3), padding="same")(x)
+    x = tensorflow.keras.layers.BatchNormalization(name="bn3")(x)
+    x = tensorflow.keras.layers.Activation("relu", name="relu3")(x)
+    x = tensorflow.keras.layers.Conv2D(128, kernel_size=(3,3), padding="same")(x)
+    x = tensorflow.keras.layers.BatchNormalization(name="bn4")(x)
+    x = tensorflow.keras.layers.Activation("relu", name="relu4")(x)
+    x = tensorflow.keras.layers.MaxPool2D(pool_size=(1,2), strides=(1,2))(x)
+
+    x = tensorflow.keras.layers.Conv2D(256, kernel_size=(3,3), padding="same")(x)
+    x = tensorflow.keras.layers.BatchNormalization(name="bn5")(x)
+    x = tensorflow.keras.layers.Activation("relu", name="relu5")(x)
+    x = tensorflow.keras.layers.Conv2D(256, kernel_size=(3,3), padding="same")(x)
+    x = tensorflow.keras.layers.BatchNormalization(name="bn6")(x)
+    x = tensorflow.keras.layers.Activation("relu", name="relu6")(x)
+    x = tensorflow.keras.layers.MaxPool2D(pool_size=(2,2), strides=(2,2))(x)
+
+    x = tensorflow.keras.layers.Conv2D(512, kernel_size=(3,3), padding="same")(x)
+    x = tensorflow.keras.layers.BatchNormalization(name="bn7")(x)
+    x = tensorflow.keras.layers.Activation("relu", name="relu7")(x)
+    # The (16,1) valid convolution collapses the remaining height to 1.
+    x = tensorflow.keras.layers.Conv2D(512, kernel_size=(16,1))(x)
+    x = tensorflow.keras.layers.BatchNormalization(name="bn8")(x)
+    x = tensorflow.keras.layers.Activation("relu", name="relu8")(x)
+    # Pool the sequence to half and quarter horizontal resolution.
+    x2d = tensorflow.keras.layers.MaxPool2D(pool_size=(1,2), strides=(1,2))(x)
+    x4d = tensorflow.keras.layers.MaxPool2D(pool_size=(1,2), strides=(1,2))(x2d)
+
+    new_shape = (x.shape[2], x.shape[3])
+    new_shape2 = (x2d.shape[2], x2d.shape[3])
+    new_shape4 = (x4d.shape[2], x4d.shape[3])
+
+    # Drop the singleton height dimension: (1, W, C) -> (W, C).
+    x = tensorflow.keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
+    x2d = tensorflow.keras.layers.Reshape(target_shape=new_shape2, name="reshape2")(x2d)
+    x4d = tensorflow.keras.layers.Reshape(target_shape=new_shape4, name="reshape4")(x4d)
+
+    # BiLSTMs over the full, half and quarter resolution sequences.
+    xrnnorg = tensorflow.keras.layers.Bidirectional(tensorflow.keras.layers.LSTM(512, return_sequences=True, dropout=0.25))(x)
+    xrnn2d = tensorflow.keras.layers.Bidirectional(tensorflow.keras.layers.LSTM(512, return_sequences=True, dropout=0.25))(x2d)
+    xrnn4d = tensorflow.keras.layers.Bidirectional(tensorflow.keras.layers.LSTM(512, return_sequences=True, dropout=0.25))(x4d)
+
+    # Re-add a singleton height dimension so the coarser sequences can be upsampled.
+    xrnn2d = tensorflow.keras.layers.Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
+    xrnn4d = tensorflow.keras.layers.Reshape(target_shape=(1,
xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
+
+    # Upsample the half and quarter resolution sequences back to full length and merge.
+    xrnn2dup = tensorflow.keras.layers.UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
+    xrnn4dup = tensorflow.keras.layers.UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
+
+    xrnn2dup = tensorflow.keras.layers.Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
+    xrnn4dup = tensorflow.keras.layers.Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
+
+    addition = tensorflow.keras.layers.Add()([xrnnorg, xrnn2dup, xrnn4dup])
+
+    addition_rnn = tensorflow.keras.layers.Bidirectional(tensorflow.keras.layers.LSTM(512, return_sequences=True, dropout=0.25))(addition)
+
+    # Map the sequence to max_seq time steps, then to per-step class scores.
+    out = tensorflow.keras.layers.Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
+    out = tensorflow.keras.layers.BatchNormalization(name="bn9")(out)
+    out = tensorflow.keras.layers.Activation("relu", name="relu9")(out)
+    #out = tensorflow.keras.layers.Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
+
+    out = tensorflow.keras.layers.Dense(
+        n_classes, activation="softmax", name="dense2"
+    )(out)
+
+    # Add CTC layer for calculating CTC loss at each step.
+    output = CTCLayer(name="ctc_loss")(labels, out)
+
+    model = tensorflow.keras.models.Model(inputs=[input_img, labels], outputs=output, name="handwriting_recognizer")
+
+    return model
+
diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 97736e0..b701ad1 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -101,6 +101,20 @@ def config_params():
     degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json.
     brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json.
     binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images.
+    image_inversion = False # If true, and if binarized images are available, image inversion will be applied.
+    white_noise_strap = False # If true, white noise will be applied to some vertical strips of the textline image.
+    textline_skewing = False # If true, textline images will be skewed for augmentation.
+    textline_skewing_bin = False # If true, the skewing augmentation will also be applied to binarized textline images, if available.
+    textline_left_in_depth = False # If true, the left side of the textline image will be warped to appear receding in depth.
+    textline_left_in_depth_bin = False # If true, the left side of the binarized textline image (if available) will be warped to appear receding in depth.
+    textline_right_in_depth = False # If true, the right side of the textline image will be warped to appear receding in depth.
+    textline_right_in_depth_bin = False # If true, the right side of the binarized textline image (if available) will be warped to appear receding in depth.
+    textline_up_in_depth = False # If true, the upper side of the textline image will be warped to appear receding in depth.
+    textline_up_in_depth_bin = False # If true, the upper side of the binarized textline image (if available) will be warped to appear receding in depth.
+    textline_down_in_depth = False # If true, the lower side of the textline image will be warped to appear receding in depth.
+    textline_down_in_depth_bin = False # If true, the lower side of the binarized textline image (if available) will be warped to appear receding in depth.
+    pepper_bin_aug = False # If true, salt-and-pepper noise will be added to the binarized textline image (if available).
+    pepper_aug = False # If true, salt-and-pepper noise will be added to the textline image.
+    max_len = None # Maximum length of the transcription labels, needed for the cnn-rnn-ocr task ("max_len" in config_params.json).
+    skewing_amplitudes = None # Amplitudes for the textline skewing augmentations.
+    pepper_indexes = None # Noise probabilities for the salt-and-pepper augmentations.
+    thetha_padd = None # Rotation angles used by the color_padding_rotation augmentation.
+    white_padds = None # Padding sizes (relative to the image height) used by the w_padding augmentation.
     adding_rgb_background = False
     adding_rgb_foreground = False
     add_red_textlines = False
@@ -111,7 +125,9 @@ def config_params():
     pretraining = False # Set to true to load pretrained weights of ResNet50 encoder.
     scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image.
     scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image.
+    bin_deg = False # If true, a combination of degrading and binarization will be applied to the image.
     rotation = False # If true, a 90 degree rotation will be implemented.
+    color_padding_rotation = False # If true, padding followed by a small rotation will be applied.
     rotation_not_90 = False # If true, rotation based on the angles provided with "thetha" will be applied.
     scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
     scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
@@ -119,6 +135,7 @@
     shuffle_indexes = None
     blur_k = None # Blur image for augmentation.
     scales = None # Scale patches for augmentation.
+    padd_colors = None # Padding colors; a list whose elements can only be "white" and "black", e.g. ["white", "black"] or just ["white"].
     degrade_scales = None # Degrade image for augmentation.
     brightness = None # Brighten image for augmentation.
     flip_index = None # Flip image for augmentation.
@@ -145,6 +162,7 @@
     number_of_backgrounds_per_image = 1
     dir_rgb_backgrounds = None
     dir_rgb_foregrounds = None
+    characters_txt_file = None # Path to the characters text file (a .txt file with the character vocabulary as a JSON-encoded list), needed for cnn-rnn-ocr model training.
 @ex.automain
 def run(_config, n_classes, n_epochs, input_height,
@@ -159,7 +177,10 @@ def run(_config, n_classes, n_epochs, input_height,
         transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
         transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
-        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
+        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image, dir_rgb_backgrounds,
+        dir_rgb_foregrounds, characters_txt_file, max_len, color_padding_rotation, bin_deg, image_inversion, white_noise_strap, textline_skewing, textline_skewing_bin,
+        textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin,
+        textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors, skewing_amplitudes, pepper_indexes, thetha_padd, white_padds):
     if dir_rgb_backgrounds:
         list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@@ -375,6 +396,34 @@ def run(_config, n_classes, n_epochs, input_height,
             #os.system('rm -rf '+dir_eval_flowing)
             #model.save(dir_output+'/'+'model'+'.h5')
+
+    elif task=="cnn-rnn-ocr":
+        import json
+        from tensorflow.keras.layers import StringLookup
+
+        dir_img, dir_lab = get_dirs_or_files(dir_train)
+
+        with open(characters_txt_file, 'r') as char_txt_f:
+            characters = json.load(char_txt_f)
+
+        AUTOTUNE = tf.data.AUTOTUNE
+
+        # Mapping characters to integers.
+        char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
+
+        # Mapping integers back to original characters.
+        num_to_char = StringLookup(
+            vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
+        )
+
+        padding_token = len(characters) + 5
+        ls_files_images = os.listdir(dir_img)
+
+        train_ds = data_gen_ocr(padding_token, batchsize=n_batch, height=input_height, width=input_width, max_len=max_len, dir_ins=dir_train,
+                                ls_files_images=ls_files_images, augmentation=augmentation, color_padding_rotation=color_padding_rotation,
+                                rotation=rotation_not_90, bluring_aug=blurring, degrading=degrading, bin_deg=bin_deg, brightening=brightening,
+                                w_padding=padding_white, rgb_fging=adding_rgb_foreground, rgb_bkding=adding_rgb_background,
+                                binarization=binarization, image_inversion=image_inversion, channel_shuffling=channels_shuffling,
+                                add_red_textline=add_red_textlines, white_noise_strap=white_noise_strap,
+                                textline_skewing=textline_skewing, textline_skewing_bin=textline_skewing_bin,
+                                textline_left_in_depth=textline_left_in_depth, textline_left_in_depth_bin=textline_left_in_depth_bin,
+                                textline_right_in_depth=textline_right_in_depth, textline_right_in_depth_bin=textline_right_in_depth_bin,
+                                textline_up_in_depth=textline_up_in_depth, textline_up_in_depth_bin=textline_up_in_depth_bin,
+                                textline_down_in_depth=textline_down_in_depth, textline_down_in_depth_bin=textline_down_in_depth_bin,
+                                pepper_bin_aug=pepper_bin_aug, pepper_aug=pepper_aug, deg_scales=degrade_scales,
+                                number_of_backgrounds_per_image=number_of_backgrounds_per_image, thethas=thetha, brightness=brightness,
+                                padd_colors=padd_colors, shuffle_indexes=shuffle_indexes, pepper_indexes=pepper_indexes,
+                                skewing_amplitudes=skewing_amplitudes, thetha_padd=thetha_padd, blurs=blur_k, white_padds=white_padds,
+                                dir_rgb_backgrounds=dir_rgb_backgrounds, dir_rgb_foregrounds=dir_rgb_foregrounds)
+
     elif task=='classification':
         configuration()
         model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
diff --git a/src/eynollah/training/utils.py b/src/eynollah/training/utils.py
index 1278be5..fed3e47 100644
--- a/src/eynollah/training/utils.py
+++ b/src/eynollah/training/utils.py
@@ -10,7 +10,213 @@ from scipy.ndimage.filters import gaussian_filter
 from tqdm import tqdm
 import imutils
 from tensorflow.keras.utils import to_categorical
-from PIL import Image, ImageEnhance
+from PIL import Image, ImageFile, ImageEnhance
+
+# Allow PIL to load truncated image files instead of raising an error.
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
+    """
+    Add salt-and-pepper noise to an image.
+
+    Parameters:
+    img: ndarray
+        Input image.
+    salt_prob: float
+        Probability of salt noise.
+    pepper_prob: float
+        Probability of pepper noise.
+
+    Returns:
+    noisy_image: ndarray
+        Image with salt-and-pepper noise.
+    """
+    # Make a copy of the image
+    noisy_image = np.copy(img)
+
+    # Number of noisy pixels of each kind
+    total_pixels = img.size
+    num_salt = int(salt_prob * total_pixels)
+    num_pepper = int(pepper_prob * total_pixels)
+
+    # Add salt noise (randint's upper bound is exclusive, so every row/column can be hit)
+    coords = [np.random.randint(0, i, num_salt) for i in img.shape[:2]]
+    noisy_image[coords[0], coords[1]] = 255 # white pixels
+
+    # Add pepper noise
+    coords = [np.random.randint(0, i, num_pepper) for i in img.shape[:2]]
+    noisy_image[coords[0], coords[1]] = 0 # black pixels
+
+    return noisy_image
+
+def invert_image(img):
+    img_inv = 255 - img
+    return img_inv
+
+def return_image_with_strapped_white_noises(img):
+    img_w_noised = np.copy(img)
+    img_h, img_width = img.shape[0], img.shape[1]
+    n = 9
+    p = 0.3
+    num_windows = np.random.binomial(n, p, 1)[0]
+
+    if num_windows<1:
+        num_windows = 1
+
+    loc_of_windows = np.random.uniform(0, img_width, num_windows).astype(np.int64)
+    width_windows = np.random.uniform(10, 50, num_windows).astype(np.int64)
+
+    for i, loc in enumerate(loc_of_windows):
+        noise = np.random.normal(0, 50, (img_h, width_windows[i], 3))
+
+        try:
+            img_w_noised[:, loc:loc+width_windows[i], : ] = noise[:,:,:]
+        except:
+            # Strips extending past the right edge of the image are skipped.
+            pass
+    return img_w_noised
+
+def do_padding_for_ocr(img, percent_height, padding_color):
+    padding_size = int( img.shape[0]*percent_height/2.
)
+    height_new = img.shape[0] + 2*padding_size
+    width_new = img.shape[1] + 2*padding_size
+
+    h_start = padding_size
+    w_start = padding_size
+
+    if padding_color == 'white':
+        img_new = np.ones((height_new, width_new, img.shape[2])).astype(float) * 255
+    elif padding_color == 'black':
+        img_new = np.zeros((height_new, width_new, img.shape[2])).astype(float)
+    else:
+        raise ValueError("padding_color must be 'white' or 'black', got: " + str(padding_color))
+
+    img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :])
+
+    return img_new
+
+def do_deskewing(img, amplitude):
+    # Despite the name, this applies a sinusoidal warp to the textline;
+    # it is used by the textline_skewing augmentations.
+    height, width = img.shape[:2]
+
+    # Generate sinusoidal wave distortion with reduced amplitude
+    #amplitude = 8 # 5 # Reduce the amplitude for less curvature
+    frequency = 300 # Increase frequency to stretch the curve
+    x_indices = np.tile(np.arange(width), (height, 1))
+    y_indices = np.arange(height).reshape(-1, 1) + amplitude * np.sin(2 * np.pi * x_indices / frequency)
+
+    # Convert indices to float32 for remapping
+    map_x = x_indices.astype(np.float32)
+    map_y = y_indices.astype(np.float32)
+
+    # Apply the remap to create the curve
+    curved_image = cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
+    return curved_image
+
+def do_left_in_depth(img):
+    height, width = img.shape[:2]
+
+    # Define the original corner points of the image
+    src_points = np.float32([
+        [0, 0],          # Top-left corner
+        [width, 0],      # Top-right corner
+        [0, height],     # Bottom-left corner
+        [width, height]  # Bottom-right corner
+    ])
+
+    # Define the new corner points for a subtle tilt that makes the left side recede
+    dst_points = np.float32([
+        [2, 13],          # Top-left shifted slightly inward and down
+        [width, 0],       # Top-right unchanged
+        [2, height-13],   # Bottom-left shifted slightly inward and up
+        [width, height]   # Bottom-right unchanged
+    ])
+
+    # Compute the perspective transformation matrix
+    matrix = cv2.getPerspectiveTransform(src_points, dst_points)
+
+    # Apply the perspective warp
+    warped_image = cv2.warpPerspective(img, matrix, (width, height))
+    return warped_image
+
+def do_right_in_depth(img):
+    height, width = img.shape[:2]
+
+    # Define the original corner points of the image
+    src_points = np.float32([
+        [0, 0],          # Top-left corner
+        [width, 0],      # Top-right corner
+        [0, height],     # Bottom-left corner
+        [width, height]  # Bottom-right corner
+    ])
+
+    # Define the new corner points for a subtle tilt that makes the right side recede
+    dst_points = np.float32([
+        [0, 0],               # Top-left unchanged
+        [width, 13],          # Top-right shifted slightly down
+        [0, height],          # Bottom-left unchanged
+        [width, height - 13]  # Bottom-right shifted slightly up
+    ])
+
+    # Compute the perspective transformation matrix
+    matrix = cv2.getPerspectiveTransform(src_points, dst_points)
+
+    # Apply the perspective warp
+    warped_image = cv2.warpPerspective(img, matrix, (width, height))
+    return warped_image
+
+def do_up_in_depth(img):
+    # Get the dimensions of the image
+    height, width = img.shape[:2]
+
+    # Define the original corner points of the image
+    src_points = np.float32([
+        [0, 0],          # Top-left corner
+        [width, 0],      # Top-right corner
+        [0, height],     # Bottom-left corner
+        [width, height]  # Bottom-right corner
+    ])
+
+    # Define the new corner points to simulate a tilted perspective:
+    # shrink the top edge so the upper part of the line appears farther away
+    dst_points = np.float32([
+        [50, 0],          # Top-left moved inward
+        [width - 50, 0],  # Top-right moved inward
+        [0, height],      # Bottom-left remains the same
+        [width, height]   # Bottom-right remains the same
+    ])
+
+    # Compute the perspective
transformation matrix
+    matrix = cv2.getPerspectiveTransform(src_points, dst_points)
+
+    # Apply the perspective warp
+    warped_image = cv2.warpPerspective(img, matrix, (width, height))
+    return warped_image
+
+
+def do_down_in_depth(img):
+    # Get the dimensions of the image
+    height, width = img.shape[:2]
+
+    # Define the original corner points of the image
+    src_points = np.float32([
+        [0, 0],          # Top-left corner
+        [width, 0],      # Top-right corner
+        [0, height],     # Bottom-left corner
+        [width, height]  # Bottom-right corner
+    ])
+
+    # Define the new corner points to simulate a tilted perspective:
+    # shrink the bottom edge so the lower part of the line appears farther away
+    dst_points = np.float32([
+        [0, 0],               # Top-left remains the same
+        [width, 0],           # Top-right remains the same
+        [50, height],         # Bottom-left moved inward
+        [width - 50, height]  # Bottom-right moved inward
+    ])
+
+    # Compute the perspective transformation matrix
+    matrix = cv2.getPerspectiveTransform(src_points, dst_points)
+
+    # Apply the perspective warp
+    warped_image = cv2.warpPerspective(img, matrix, (width, height))
+    return warped_image
 
 def return_shuffled_channels(img, channels_order):
@@ -1055,3 +1261,620 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow
                 cv2.flip( cv2.imread(dir_img + '/' + im), f_i),
                 cv2.flip(cv2.imread(dir_of_label_file), f_i),
                 input_height, input_width, indexer=indexer, scaler=sc_ind)
+
+
+
+def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len=None, dir_ins=None, ls_files_images=None,
+                 augmentation=False, color_padding_rotation=False, rotation=False, bluring_aug=False, degrading=False, bin_deg=False, brightening=False, w_padding=False,
+                 rgb_fging=False, rgb_bkding=False, binarization=False, image_inversion=False, channel_shuffling=False, add_red_textline=False, white_noise_strap=False,
+                 textline_skewing=False, textline_skewing_bin=False, textline_left_in_depth=False, textline_left_in_depth_bin=False, textline_right_in_depth=False,
+                 textline_right_in_depth_bin=False, textline_up_in_depth=False, textline_up_in_depth_bin=False, textline_down_in_depth=False, textline_down_in_depth_bin=False,
+                 pepper_bin_aug=False, pepper_aug=False, deg_scales=None, number_of_backgrounds_per_image=None, thethas=None, brightness=None, padd_colors=None,
+                 shuffle_indexes=None, pepper_indexes=None, skewing_amplitudes=None, thetha_padd=None, blurs=None, white_padds=None,
+                 dir_rgb_backgrounds=None, dir_rgb_foregrounds=None, list_all_possible_background_images=None, list_all_possible_foreground_rgbs=None):
+    # Infinite batch generator for the cnn-rnn-ocr task. The lists for the
+    # enabled augmentations (pepper_indexes, skewing_amplitudes, thetha_padd,
+    # blurs, white_padds, ...) must be supplied by the caller.
+    # vectorize_label and scale_padd_image_for_ocr are assumed to be defined
+    # elsewhere in this module.
+    if dir_rgb_backgrounds is not None and list_all_possible_background_images is None:
+        list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
+    if dir_rgb_foregrounds is not None and list_all_possible_foreground_rgbs is None:
+        list_all_possible_foreground_rgbs = os.listdir(dir_rgb_foregrounds)
+
+    random.shuffle(ls_files_images)
+
+    ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+    ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+    batchcount = 0
+    while True:
+        for i in ls_files_images:
+            f_name = i.split('.')[0]
+
+            with open(os.path.join(dir_ins, "labels/"+f_name+'.txt'), 'r') as f_lab:
+                txt_inp = f_lab.read().split('\n')[0]
+
+            img = cv2.imread(os.path.join(dir_ins, "images/"+i))
+            img_bin_corr = cv2.imread(os.path.join(dir_ins, "images_bin/"+f_name+'.png'))
+
+            if augmentation:
+                # The unaugmented image is always emitted first.
+                img_out = scale_padd_image_for_ocr(img, height, width)
+
+                ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                batchcount+=1
+
+                if batchcount>=batchsize:
+                    ret_x = ret_x/255.
+                    yield {"image": ret_x, "label": ret_y}
+                    ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                    ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                    batchcount = 0
+
+                if color_padding_rotation:
+                    for index, thetha in enumerate(thetha_padd):
+                        for padd_col in padd_colors:
+                            # Pad first, then rotate by a small angle.
+                            img_out = rotation_not_90_func(do_padding_for_ocr(img, 1.2, padd_col), thetha)
+
+                            img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                            ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                            ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                            batchcount+=1
+
+                            if batchcount>=batchsize:
+                                ret_x = ret_x/255.
+                                yield {"image": ret_x, "label": ret_y}
+                                ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                                ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                                batchcount = 0
+
+                if rotation:
+                    for index, thetha in enumerate(thethas):
+                        img_out = rotation_not_90_func(img, thetha)
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if bluring_aug:
+                    for index, blur_type in enumerate(blurs):
+                        img_out = bluring(img, blur_type)
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if degrading:
+                    for index, deg_scale_ind in enumerate(deg_scales):
+                        try:
+                            img_out = do_degrading(img, deg_scale_ind)
+                        except:
+                            img_out = np.copy(img)
+                        # Scale in both cases so the batch slot shapes match.
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if bin_deg:
+                    for index, deg_scale_ind in enumerate(deg_scales):
+                        try:
+                            img_out = do_degrading(img_bin_corr, deg_scale_ind)
+                        except:
+                            img_out = np.copy(img_bin_corr)
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if brightening:
+                    for index, bright_scale_ind in enumerate(brightness):
+                        try:
+                            # do_brightening expects the image path, not an array.
+                            img_out = do_brightening(os.path.join(dir_ins, "images/"+i), bright_scale_ind)
+                        except:
+                            img_out = np.copy(img)
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if w_padding:
+                    for index, padding_size in enumerate(white_padds):
+                        for padd_col in padd_colors:
+                            img_out = do_padding_for_ocr(img, padding_size, padd_col)
+                            img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                            ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                            ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                            batchcount+=1
+
+                            if batchcount>=batchsize:
+                                ret_x = ret_x/255.
+                                yield {"image": ret_x, "label": ret_y}
+                                ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                                ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                                batchcount = 0
+
+                if rgb_fging:
+                    for i_n in range(number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(list_all_possible_background_images)
+                        foreground_rgb_chosen_name = random.choice(list_all_possible_foreground_rgbs)
+
+                        img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        foreground_rgb_chosen = np.load(dir_rgb_foregrounds + '/' + foreground_rgb_chosen_name)
+
+                        img_with_overlayed_background = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
+
+                        img_out = scale_padd_image_for_ocr(img_with_overlayed_background, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if rgb_bkding:
+                    for i_n in range(number_of_backgrounds_per_image):
+                        background_image_chosen_name = random.choice(list_all_possible_background_images)
+                        img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name)
+                        img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
+
+                        img_out = scale_padd_image_for_ocr(img_with_overlayed_background, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if binarization:
+                    img_out = scale_padd_image_for_ocr(img_bin_corr, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if image_inversion:
+                    img_out = invert_image(img_bin_corr)
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :, :, :] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if channel_shuffling:
+                    for shuffle_index in shuffle_indexes:
+                        img_out = return_shuffled_channels(img, shuffle_index)
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if add_red_textline:
+                    img_red_context = return_image_with_red_elements(img, img_bin_corr)
+
+                    img_out = scale_padd_image_for_ocr(img_red_context, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if white_noise_strap:
+                    img_out = return_image_with_strapped_white_noises(img)
+
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if textline_skewing:
+                    for index, des_scale_ind in enumerate(skewing_amplitudes):
+                        try:
+                            img_out = do_deskewing(img, des_scale_ind)
+                        except:
+                            img_out = np.copy(img)
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if textline_skewing_bin:
+                    for index, des_scale_ind in enumerate(skewing_amplitudes):
+                        try:
+                            img_out = do_deskewing(img_bin_corr, des_scale_ind)
+                        except:
+                            img_out = np.copy(img_bin_corr)
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if textline_left_in_depth:
+                    try:
+                        img_out = do_left_in_depth(img)
+                    except:
+                        img_out = np.copy(img)
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if textline_left_in_depth_bin:
+                    try:
+                        img_out = do_left_in_depth(img_bin_corr)
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if textline_right_in_depth:
+                    try:
+                        img_out = do_right_in_depth(img)
+                    except:
+                        img_out = np.copy(img)
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if textline_right_in_depth_bin:
+                    try:
+                        img_out = do_right_in_depth(img_bin_corr)
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if textline_up_in_depth:
+                    try:
+                        img_out = do_up_in_depth(img)
+                    except:
+                        img_out = np.copy(img)
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if textline_up_in_depth_bin:
+                    try:
+                        img_out = do_up_in_depth(img_bin_corr)
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if textline_down_in_depth:
+                    try:
+                        img_out = do_down_in_depth(img)
+                    except:
+                        img_out = np.copy(img)
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if textline_down_in_depth_bin:
+                    try:
+                        img_out = do_down_in_depth(img_bin_corr)
+                    except:
+                        img_out = np.copy(img_bin_corr)
+                    img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                    ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                    ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                    batchcount+=1
+
+                    if batchcount>=batchsize:
+                        ret_x = ret_x/255.
+                        yield {"image": ret_x, "label": ret_y}
+                        ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                        ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                        batchcount = 0
+
+                if pepper_bin_aug:
+                    for index, pepper_ind in enumerate(pepper_indexes):
+                        img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+                if pepper_aug:
+                    for index, pepper_ind in enumerate(pepper_indexes):
+                        img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
+                        img_out = scale_padd_image_for_ocr(img_out, height, width)
+
+                        ret_x[batchcount, :,:,:] = img_out[:,:,:]
+                        ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                        batchcount+=1
+
+                        if batchcount>=batchsize:
+                            ret_x = ret_x/255.
+                            yield {"image": ret_x, "label": ret_y}
+                            ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                            ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                            batchcount = 0
+
+            else:
+                # No augmentation: emit only the scaled and padded image.
+                img_out = scale_padd_image_for_ocr(img, height, width)
+                ret_x[batchcount, :,:,:] = img_out[:,:,:]
+
+                ret_y[batchcount, :] = vectorize_label(txt_inp)
+
+                batchcount+=1
+
+                if batchcount>=batchsize:
+                    ret_x = ret_x/255.
+                    yield {"image": ret_x, "label": ret_y}
+                    ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
+                    ret_y = np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
+                    batchcount = 0
+
+
+def return_muliplier_based_on_augmnentations(augmentation=False, color_padding_rotation=False, rotation=False, bluring_aug=False,
+                                             degrading=False, bin_deg=False, brightening=False, w_padding=False, rgb_fging=False, rgb_bkding=False,
+                                             binarization=False, image_inversion=False, channel_shuffling=False, add_red_textline=False, white_noise_strap=False,
+                                             textline_skewing=False, textline_skewing_bin=False, textline_left_in_depth=False, textline_left_in_depth_bin=False,
+                                             textline_right_in_depth=False, textline_right_in_depth_bin=False, textline_up_in_depth=False, textline_up_in_depth_bin=False,
+                                             textline_down_in_depth=False, textline_down_in_depth_bin=False, pepper_bin_aug=False, pepper_aug=False,
+                                             deg_scales=None, number_of_backgrounds_per_image=None, thethas=None, brightness=None, padd_colors=None,
+                                             shuffle_indexes=None, pepper_indexes=None, skewing_amplitudes=None, thetha_padd=None, blurs=None, white_padds=None):
+    # Returns how many samples data_gen_ocr emits per input image, so the
+    # caller can size steps_per_epoch. The list arguments must match those
+    # passed to data_gen_ocr.
+    aug_multip = 1
+
+    if augmentation:
+        if binarization:
+            aug_multip = aug_multip + 1
+        if image_inversion:
+            aug_multip = aug_multip + 1
+        if add_red_textline:
+            aug_multip = aug_multip + 1
+        if white_noise_strap:
+            aug_multip = aug_multip + 1
+        if textline_right_in_depth:
+            aug_multip = aug_multip + 1
+        if textline_left_in_depth:
+            aug_multip = aug_multip + 1
+        if textline_up_in_depth:
+            aug_multip = aug_multip + 1
+        if textline_down_in_depth:
+            aug_multip = aug_multip + 1
+        if textline_right_in_depth_bin:
+            aug_multip = aug_multip + 1
+        if textline_left_in_depth_bin:
+            aug_multip = aug_multip + 1
+        if textline_up_in_depth_bin:
+            aug_multip = aug_multip + 1
+        if textline_down_in_depth_bin:
+            aug_multip = aug_multip + 1
+        if rgb_fging:
+            aug_multip = aug_multip + number_of_backgrounds_per_image
+        if rgb_bkding:
+            aug_multip = aug_multip + number_of_backgrounds_per_image
+        if bin_deg:
+            aug_multip = aug_multip + len(deg_scales)
+        if degrading:
+            aug_multip = aug_multip + len(deg_scales)
+        if rotation:
+            aug_multip = aug_multip + len(thethas)
+        if textline_skewing:
+            aug_multip = aug_multip + len(skewing_amplitudes)
+        if textline_skewing_bin:
+            aug_multip = aug_multip + len(skewing_amplitudes)
+        if color_padding_rotation:
+            aug_multip = aug_multip + len(thetha_padd)*len(padd_colors)
+        if channel_shuffling:
+            aug_multip = aug_multip + len(shuffle_indexes)
+        if bluring_aug:
+            aug_multip = aug_multip + len(blurs)
+        if brightening:
+            aug_multip = aug_multip + len(brightness)
+        if w_padding:
+            aug_multip = aug_multip + len(white_padds)*len(padd_colors)
+        if pepper_aug:
+            aug_multip = aug_multip + len(pepper_indexes)
+        if pepper_bin_aug:
+            aug_multip = aug_multip + len(pepper_indexes)
+
+    return aug_multip
diff --git a/train/config_params.json b/train/config_params.json
index 1db8026..1c94afc 100644
--- a/train/config_params.json
+++ b/train/config_params.json
@@ -2,6 +2,7 @@
     "backbone_type" : "transformer",
     "task": "segmentation",
     "n_classes" : 2,
+    "max_len": 280,
     "n_epochs" : 0,
     "input_height" : 448,
     "input_width" : 448,
@@ -34,7 +35,7 @@
     "transformer_layers": 1,
     "transformer_num_heads": 1,
     "transformer_cnn_first": false,
-    "blur_k" : ["blur","guass","median"],
+    "blur_k" : ["blur","gauss","median"],
     "scales" : [0.6, 0.7, 0.8, 0.9],
     "brightness" : [1.3, 1.5, 1.7, 2],
     "degrade_scales" : [0.2, 0.4],
@@ -53,6 +54,7 @@
     "dir_output": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/output_new",
    "dir_rgb_backgrounds":
"/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", "dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground", - "dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin" + "dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin", + "characters_txt_file":"dir_of_characters_txt_file_for_ocr" }