From 95bbdf804058c255944554e1ad2ba608a5929fd2 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 21 Aug 2024 16:17:59 +0200 Subject: [PATCH] updating augmentations --- train.py | 8 +++++--- utils.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 71f31f3..fa08a98 100644 --- a/train.py +++ b/train.py @@ -53,6 +53,7 @@ def config_params(): degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json. brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. + rgb_background = False dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". dir_output = None # Directory where the output model will be saved. @@ -95,7 +96,7 @@ def run(_config, n_classes, n_epochs, input_height, index_start, dir_of_start_model, is_loss_soft_dice, n_batch, patches, augmentation, flip_aug, blur_aug, padding_white, padding_black, scaling, degrading, - brightening, binarization, blur_k, scales, degrade_scales, + brightening, binarization, rgb_background, blur_k, scales, degrade_scales, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, continue_training, transformer_projection_dim, @@ -108,6 +109,7 @@ def run(_config, n_classes, n_epochs, input_height, if data_is_provided: dir_train_flowing = os.path.join(dir_output, 'train') dir_eval_flowing = os.path.join(dir_output, 'eval') + dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images') dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels') @@ -161,7 +163,7 @@ def run(_config, n_classes, n_epochs, input_height, # writing patches into a sub-folder in order to be flowed from directory. provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, - blur_aug, padding_white, padding_black, flip_aug, binarization, + blur_aug, padding_white, padding_black, flip_aug, binarization, rgb_background, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, @@ -169,7 +171,7 @@ def run(_config, n_classes, n_epochs, input_height, provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, - blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, + blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, rgb_background, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches) diff --git a/utils.py b/utils.py index 2278849..cf7a65c 100644 --- a/utils.py +++ b/utils.py @@ -695,6 +695,47 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 + + if rotation_not_90: + for thetha_i in thetha: + img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/'+im), + cv2.imread(dir_of_label_file), thetha_i) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_max_rotated, input_height, input_width)) + + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_max_rotated, input_height, input_width)) + indexer += 1 + + if channels_shuffling: + for shuffle_index in shuffle_indexes: + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', + (resize_image(return_shuffled_channels(cv2.imread(dir_img + '/' + im), shuffle_index), input_height, input_width))) + + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + indexer += 1 + + if scaling: + for sc_ind in scales: + img_scaled, label_scaled = scale_image_for_no_patch(cv2.imread(dir_img + '/'+im), + cv2.imread(dir_of_label_file), sc_ind) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_scaled, input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_scaled, input_height, input_width)) + indexer += 1 + + if rgb_color_background: + img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png') + for i_n in range(number_of_backgrounds_per_image): + background_image_chosen_name = random.choice(list_all_possible_background_images) + img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name) + img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_with_overlayed_background, input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) + + if patches: