diff --git a/config_params.json b/config_params.json index a89cbb5..e5f652d 100644 --- a/config_params.json +++ b/config_params.json @@ -1,19 +1,22 @@ { "backbone_type" : "transformer", - "task": "binarization", + "task": "segmentation", "n_classes" : 2, - "n_epochs" : 2, - "input_height" : 224, - "input_width" : 224, + "n_epochs" : 0, + "input_height" : 448, + "input_width" : 448, "weight_decay" : 1e-6, "n_batch" : 1, "learning_rate": 1e-4, - "patches" : true, + "patches" : false, "pretraining" : true, - "augmentation" : false, + "augmentation" : true, "flip_aug" : false, "blur_aug" : false, "scaling" : true, + "adding_rgb_background": true, + "add_red_textlines": true, + "channels_shuffling": true, "degrading": false, "brightening": false, "binarization" : false, @@ -31,18 +34,23 @@ "transformer_num_heads": 1, "transformer_cnn_first": false, "blur_k" : ["blur","guass","median"], - "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], + "scales" : [0.6, 0.7, 0.8, 0.9], "brightness" : [1.3, 1.5, 1.7, 2], "degrade_scales" : [0.2, 0.4], "flip_index" : [0, 1, -1], - "thetha" : [10, -10], + "shuffle_indexes" : [ [0,2,1], [1,2,0], [1,0,2] , [2,1,0]], + "thetha" : [5, -5], + "number_of_backgrounds_per_image": 2, "continue_training": false, "index_start" : 0, "dir_of_start_model" : " ", "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "./train", - "dir_eval": "./eval", - "dir_output": "./output" + "dir_train": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new", + "dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new", + "dir_output": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/output_new", + "dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background", + "dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin" + } diff --git a/train.py b/train.py index fa08a98..5dfad07 100644 --- a/train.py +++ b/train.py @@ -53,7 +53,9 @@ def config_params(): degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json. brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. - rgb_background = False + adding_rgb_background = False + add_red_textlines = False + channels_shuffling = False dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". dir_output = None # Directory where the output model will be saved. @@ -65,6 +67,7 @@ def config_params(): scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image. scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image. thetha = None # Rotate image by these angles for augmentation. + shuffle_indexes = None blur_k = None # Blur image for augmentation. scales = None # Scale patches for augmentation. degrade_scales = None # Degrade image for augmentation. @@ -88,6 +91,10 @@ def config_params(): f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output. classification_classes_name = None # Dictionary of classification classes names. backbone_type = None # As backbone we have 2 types of backbones. A vision transformer alongside a CNN and we call it "transformer" and only CNN called "nontransformer" + + dir_img_bin = None + number_of_backgrounds_per_image = 1 + dir_rgb_backgrounds = None @ex.automain @@ -95,15 +102,20 @@ def run(_config, n_classes, n_epochs, input_height, input_width, weight_decay, weighted_loss, index_start, dir_of_start_model, is_loss_soft_dice, n_batch, patches, augmentation, flip_aug, - blur_aug, padding_white, padding_black, scaling, degrading, - brightening, binarization, rgb_background, blur_k, scales, degrade_scales, + blur_aug, padding_white, padding_black, scaling, degrading,channels_shuffling, + brightening, binarization, adding_rgb_background, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first, transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, - pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): + pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds): + + if dir_rgb_backgrounds: + list_all_possible_background_images = os.listdir(dir_rgb_backgrounds) + else: + list_all_possible_background_images = None if task == "segmentation" or task == "enhancement" or task == "binarization": if data_is_provided: @@ -163,18 +175,18 @@ def run(_config, n_classes, n_epochs, input_height, # writing patches into a sub-folder in order to be flowed from directory. provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, - blur_aug, padding_white, padding_black, flip_aug, binarization, rgb_background, + blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background,add_red_textlines, channels_shuffling, scaling, degrading, brightening, scales, degrade_scales, brightness, - flip_index, scaling_bluring, scaling_brightness, scaling_binarization, + flip_index,shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, - patches=patches) + patches=patches, dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds) provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, - blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, rgb_background, + blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background, add_red_textlines, channels_shuffling, scaling, degrading, brightening, scales, degrade_scales, brightness, - flip_index, scaling_bluring, scaling_brightness, scaling_binarization, - rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches) + flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, + rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches,dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds) if weighted_loss: weights = np.zeros(n_classes) diff --git a/utils.py b/utils.py index cf7a65c..20fda29 100644 --- a/utils.py +++ b/utils.py @@ -51,6 +51,16 @@ def return_binary_image_with_given_rgb_background_red_textlines(img_bin, img_rgb return img_final +def return_image_with_red_elements(img, img_bin): + img_final = np.copy(img) + + img_final[:,:,0][img_bin[:,:,0]==0] = 0 + img_final[:,:,1][img_bin[:,:,0]==0] = 0 + img_final[:,:,2][img_bin[:,:,0]==0] = 255 + return img_final + + + def scale_image_for_no_patch(img, label, scale): h_n = int(img.shape[0]*scale) w_n = int(img.shape[1]*scale) @@ -631,10 +641,10 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, blur_aug, - padding_white, padding_black, flip_aug, binarization, scaling, degrading, - brightening, scales, degrade_scales, brightness, flip_index, + padding_white, padding_black, flip_aug, binarization, adding_rgb_background, add_red_textlines, channels_shuffling, scaling, degrading, + brightening, scales, degrade_scales, brightness, flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, - rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False): + rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False, dir_img_bin=None,number_of_backgrounds_per_image=None,list_all_possible_background_images=None, dir_rgb_backgrounds=None): indexer = 0 for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)): @@ -724,17 +734,29 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_scaled, input_height, input_width)) indexer += 1 - if rgb_color_background: + if adding_rgb_background: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') for i_n in range(number_of_backgrounds_per_image): background_image_chosen_name = random.choice(list_all_possible_background_images) img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name) - img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background) + img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen) cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_with_overlayed_background, input_height, input_width)) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + indexer += 1 + + if add_red_textlines: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + img_red_context = return_image_with_red_elements(cv2.imread(dir_img + '/'+im), img_bin_corr) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_red_context, input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + + indexer += 1 +