From 7b4d14b19f536614545b209bf3834b6b84a67d1d Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Tue, 29 Oct 2024 17:06:22 +0100 Subject: [PATCH] adding shifting augmentation --- train/train.py | 7 ++++--- train/utils.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/train/train.py b/train/train.py index 848ff6a..7e3e390 100644 --- a/train/train.py +++ b/train/train.py @@ -50,6 +50,7 @@ def config_params(): padding_white = False # If true, white padding will be applied to the image. padding_black = False # If true, black padding will be applied to the image. scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json. + shifting = False degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json. brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json. binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. 
@@ -104,7 +105,7 @@ def run(_config, n_classes, n_epochs, input_height, input_width, weight_decay, weighted_loss, index_start, dir_of_start_model, is_loss_soft_dice, n_batch, patches, augmentation, flip_aug, - blur_aug, padding_white, padding_black, scaling, degrading,channels_shuffling, + blur_aug, padding_white, padding_black, scaling, shifting, degrading,channels_shuffling, brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, @@ -183,7 +184,7 @@ def run(_config, n_classes, n_epochs, input_height, provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background,adding_rgb_foreground, add_red_textlines, channels_shuffling, - scaling, degrading, brightening, scales, degrade_scales, brightness, + scaling, shifting, degrading, brightening, scales, degrade_scales, brightness, flip_index,shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, patches=patches, dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds, dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs) @@ -191,7 +192,7 @@ def run(_config, n_classes, n_epochs, input_height, provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width, blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, 
add_red_textlines, channels_shuffling, - scaling, degrading, brightening, scales, degrade_scales, brightness, + scaling, shifting, degrading, brightening, scales, degrade_scales, brightness, flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches,dir_img_bin=dir_img_bin,number_of_backgrounds_per_image=number_of_backgrounds_per_image,list_all_possible_background_images=list_all_possible_background_images, dir_rgb_backgrounds=dir_rgb_backgrounds,dir_rgb_foregrounds=dir_rgb_foregrounds,list_all_possible_foreground_rgbs=list_all_possible_foreground_rgbs ) diff --git a/train/utils.py b/train/utils.py index 3d42b64..d7ddb99 100644 --- a/train/utils.py +++ b/train/utils.py @@ -78,7 +78,50 @@ def return_image_with_red_elements(img, img_bin): img_final[:,:,2][img_bin[:,:,0]==0] = 255 return img_final +def shift_image_and_label(img, label, type_shift): + h_n = int(img.shape[0]*1.06) + w_n = int(img.shape[1]*1.06) + + channel0_avg = int( np.mean(img[:,:,0]) ) + channel1_avg = int( np.mean(img[:,:,1]) ) + channel2_avg = int( np.mean(img[:,:,2]) ) + h_diff = abs( img.shape[0] - h_n ) + w_diff = abs( img.shape[1] - w_n ) + + h_start = int(h_diff / 2.) + w_start = int(w_diff / 2.) 
+ + img_scaled_padded = np.zeros((h_n, w_n, 3)) + label_scaled_padded = np.zeros((h_n, w_n, 3)) + + img_scaled_padded[:,:,0] = channel0_avg + img_scaled_padded[:,:,1] = channel1_avg + img_scaled_padded[:,:,2] = channel2_avg + + img_scaled_padded[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1],:] = img[:,:,:] + label_scaled_padded[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1],:] = label[:,:,:] + + + if type_shift=="xpos": + img_dis = img_scaled_padded[h_start:h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + label_dis = label_scaled_padded[h_start:h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + elif type_shift=="xmin": + img_dis = img_scaled_padded[h_start:h_start+img.shape[0],:img.shape[1],:] + label_dis = label_scaled_padded[h_start:h_start+img.shape[0],:img.shape[1],:] + elif type_shift=="ypos": + img_dis = img_scaled_padded[2*h_start:2*h_start+img.shape[0],w_start:w_start+img.shape[1],:] + label_dis = label_scaled_padded[2*h_start:2*h_start+img.shape[0],w_start:w_start+img.shape[1],:] + elif type_shift=="ymin": + img_dis = img_scaled_padded[:img.shape[0],w_start:w_start+img.shape[1],:] + label_dis = label_scaled_padded[:img.shape[0],w_start:w_start+img.shape[1],:] + elif type_shift=="xypos": + img_dis = img_scaled_padded[2*h_start:2*h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + label_dis = label_scaled_padded[2*h_start:2*h_start+img.shape[0],2*w_start:2*w_start+img.shape[1],:] + elif type_shift=="xymin": + img_dis = img_scaled_padded[:img.shape[0],:img.shape[1],:] + label_dis = label_scaled_padded[:img.shape[0],:img.shape[1],:] + return img_dis, label_dis def scale_image_for_no_patch(img, label, scale): h_n = int(img.shape[0]*scale) @@ -660,7 +703,7 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs, dir_flow_train_labels, input_height, input_width, blur_k, blur_aug, - padding_white, 
padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, channels_shuffling, scaling, degrading, + padding_white, padding_black, flip_aug, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, channels_shuffling, scaling, shifting, degrading, brightening, scales, degrade_scales, brightness, flip_index, shuffle_indexes, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False, dir_img_bin=None,number_of_backgrounds_per_image=None,list_all_possible_background_images=None, dir_rgb_backgrounds=None, dir_rgb_foregrounds=None, list_all_possible_foreground_rgbs=None): @@ -759,6 +802,16 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_scaled, input_height, input_width)) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_scaled, input_height, input_width)) indexer += 1 + if shifting: + shift_types = ['xpos', 'xmin', 'ypos', 'ymin', 'xypos', 'xymin'] + for st_ind in shift_types: + img_shifted, label_shifted = shift_image_and_label(cv2.imread(dir_img + '/'+im), + cv2.imread(dir_of_label_file), st_ind) + + cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(img_shifted, input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(label_shifted, input_height, input_width)) + indexer += 1 + if adding_rgb_background: img_bin_corr = cv2.imread(dir_img_bin + '/' + img_name+'.png')