From 38db3e9289a1d5b2f17e26f7a857ee4030f3901e Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 6 May 2024 18:31:48 +0200 Subject: [PATCH] adding enhancement training --- config_params.json | 20 +++++------ gt_for_enhancement_creator.py | 31 ++++++++++++++++++ models.py | 27 ++++++++++----- train.py | 47 ++++++++++++++------------ utils.py | 62 +++++++++++++++++++---------------- 5 files changed, 119 insertions(+), 68 deletions(-) create mode 100644 gt_for_enhancement_creator.py diff --git a/config_params.json b/config_params.json index 43ad1bc..1c7a940 100644 --- a/config_params.json +++ b/config_params.json @@ -1,15 +1,15 @@ { "model_name" : "resnet50_unet", - "task": "classification", - "n_classes" : 2, - "n_epochs" : 7, - "input_height" : 224, - "input_width" : 224, + "task": "enhancement", + "n_classes" : 3, + "n_epochs" : 3, + "input_height" : 448, + "input_width" : 448, "weight_decay" : 1e-6, - "n_batch" : 6, + "n_batch" : 3, "learning_rate": 1e-4, "f1_threshold_classification": 0.8, - "patches" : false, + "patches" : true, "pretraining" : true, "augmentation" : false, "flip_aug" : false, @@ -35,7 +35,7 @@ "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "/home/vahid/Downloads/image_classification_data/train", - "dir_eval": "/home/vahid/Downloads/image_classification_data/eval", - "dir_output": "/home/vahid/Downloads/image_classification_data/output" + "dir_train": "./training_data_sample_enhancement", + "dir_eval": "./eval", + "dir_output": "./out" } diff --git a/gt_for_enhancement_creator.py b/gt_for_enhancement_creator.py new file mode 100644 index 0000000..9a4274f --- /dev/null +++ b/gt_for_enhancement_creator.py @@ -0,0 +1,31 @@ +import cv2 +import os + +def resize_image(seg_in, input_height, input_width): + return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST) + + +dir_imgs = './training_data_sample_enhancement/images' +dir_out_imgs = 
'./training_data_sample_enhancement/images_gt' +dir_out_labs = './training_data_sample_enhancement/labels_gt' + +ls_imgs = os.listdir(dir_imgs) + + +ls_scales = [ 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9] + + +for img in ls_imgs: + img_name = img.split('.')[0] + img_type = img.split('.')[1] + image = cv2.imread(os.path.join(dir_imgs, img)) + for i, scale in enumerate(ls_scales): + height_sc = int(image.shape[0]*scale) + width_sc = int(image.shape[1]*scale) + + image_down_scaled = resize_image(image, height_sc, width_sc) + image_back_to_org_scale = resize_image(image_down_scaled, image.shape[0], image.shape[1]) + + cv2.imwrite(os.path.join(dir_out_imgs, img_name+'_'+str(i)+'.'+img_type), image_back_to_org_scale) + cv2.imwrite(os.path.join(dir_out_labs, img_name+'_'+str(i)+'.'+img_type), image) + diff --git a/models.py b/models.py index a6de1ef..4cceacd 100644 --- a/models.py +++ b/models.py @@ -168,7 +168,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) return x -def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False): +def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): assert input_height % 32 == 0 assert input_width % 32 == 0 @@ -259,14 +259,17 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_dec o = Activation('relu')(o) o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) model = Model(img_input, o) return model -def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False): +def resnet50_unet(n_classes, input_height=224, input_width=224, 
task="segmentation", weight_decay=1e-6, pretraining=False): assert input_height % 32 == 0 assert input_width % 32 == 0 @@ -354,15 +357,18 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e- o = Activation('relu')(o) o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) model = Model(img_input, o) return model -def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): +def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): inputs = layers.Input(shape=(input_height, input_width, 3)) IMAGE_ORDERING = 'channels_last' bn_axis=3 @@ -465,8 +471,11 @@ def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_ o = Activation('relu')(o) o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) model = Model(inputs=inputs, outputs=o) diff --git a/train.py b/train.py index efcd3ac..595debe 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,6 @@ import os import sys +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' import tensorflow as tf from tensorflow.compat.v1.keras.backend import set_session import warnings @@ -91,7 +92,7 @@ def run(_config, n_classes, n_epochs, input_height, num_patches_xy, model_name, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification): - if task == 
"segmentation": + if task in ("segmentation", "enhancement"): num_patches = num_patches_xy[0]*num_patches_xy[1] if data_is_provided: @@ -153,7 +154,7 @@ def run(_config, n_classes, n_epochs, input_height, blur_aug, padding_white, padding_black, flip_aug, binarization, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, - rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation, + rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation, patches=patches) provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val, @@ -161,7 +162,7 @@ def run(_config, n_classes, n_epochs, input_height, blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, - rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches) + rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches) if weighted_loss: weights = np.zeros(n_classes) @@ -191,45 +192,49 @@ def run(_config, n_classes, n_epochs, input_height, if continue_training: if model_name=='resnet50_unet': - if is_loss_soft_dice: + if is_loss_soft_dice and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) - if weighted_loss: + if weighted_loss and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: model = load_model(dir_of_start_model , compile=True) elif model_name=='hybrid_transformer_cnn': - if is_loss_soft_dice: + if is_loss_soft_dice and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, 
"Patches": Patches,'soft_dice_loss': soft_dice_loss}) - if weighted_loss: + if weighted_loss and task == "segmentation": model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) else: index_start = 0 if model_name=='resnet50_unet': - model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining) elif model_name=='hybrid_transformer_cnn': - model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining) + model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining) #if you want to see the model structure just uncomment model summary. 
#model.summary() - - if not is_loss_soft_dice and not weighted_loss: - model.compile(loss='categorical_crossentropy', - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - if is_loss_soft_dice: - model.compile(loss=soft_dice_loss, - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) - if weighted_loss: - model.compile(loss=weighted_categorical_crossentropy(weights), - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if task == "segmentation": + if not is_loss_soft_dice and not weighted_loss: + model.compile(loss='categorical_crossentropy', + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if is_loss_soft_dice: + model.compile(loss=soft_dice_loss, + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + if weighted_loss: + model.compile(loss=weighted_categorical_crossentropy(weights), + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + elif task == "enhancement": + model.compile(loss='mean_squared_error', + optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + # generating train and evaluation data train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch, - input_height=input_height, input_width=input_width, n_classes=n_classes) + input_height=input_height, input_width=input_width, n_classes=n_classes, task=task) val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch, - input_height=input_height, input_width=input_width, n_classes=n_classes) + input_height=input_height, input_width=input_width, n_classes=n_classes, task=task) ##img_validation_patches = os.listdir(dir_flow_eval_imgs) ##score_best=[] diff --git a/utils.py b/utils.py index af3c5f8..0c5a458 100644 --- a/utils.py +++ b/utils.py @@ -268,7 +268,7 @@ def IoU(Yi, y_predi): return mIoU -def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes): +def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'): c = 0 n = [f for f in 
os.listdir(img_folder) if not f.startswith('.')] # os.listdir(img_folder) #List of training images random.shuffle(n) @@ -277,8 +277,6 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float') for i in range(c, c + batch_size): # initially from 0 to 16, c = 0. - # print(img_folder+'/'+n[i]) - try: filename = n[i].split('.')[0] @@ -287,11 +285,14 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize img[i - c] = train_img # add to array - img[0], img[1], and so on. - train_mask = cv2.imread(mask_folder + '/' + filename + '.png') - # print(mask_folder+'/'+filename+'.png') - # print(train_mask.shape) - train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width, - n_classes) + if task == "segmentation": + train_mask = cv2.imread(mask_folder + '/' + filename + '.png') + train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width, + n_classes) + elif task == "enhancement": + train_mask = cv2.imread(mask_folder + '/' + filename + '.png')/255. 
+ train_mask = resize_image(train_mask, input_height, input_width) + # train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] mask[i - c] = train_mask @@ -539,14 +540,19 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow padding_white, padding_black, flip_aug, binarization, scaling, degrading, brightening, scales, degrade_scales, brightness, flip_index, scaling_bluring, scaling_brightness, scaling_binarization, rotation, - rotation_not_90, thetha, scaling_flip, augmentation=False, patches=False): + rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False): indexer = 0 for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)): img_name = im.split('.')[0] + if task == "segmentation": + dir_of_label_file = os.path.join(dir_seg, img_name + '.png') + elif task=="enhancement": + dir_of_label_file = os.path.join(dir_seg, im) + if not patches: cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width)) - cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) + cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 if augmentation: @@ -556,7 +562,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) ) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), input_height, input_width)) + resize_image(cv2.flip(cv2.imread(dir_of_label_file), f_i), input_height, input_width)) indexer += 1 if blur_aug: @@ -565,7 +571,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, 
dir_seg, dir_flow (resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, input_width))) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 if binarization: @@ -573,26 +579,26 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width)) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', - resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) + resize_image(cv2.imread(dir_of_label_file), input_height, input_width)) indexer += 1 if patches: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, - cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_img + '/' + im), cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer) if augmentation: if rotation: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, rotation_90(cv2.imread(dir_img + '/' + im)), - rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')), + rotation_90(cv2.imread(dir_of_label_file)), input_height, input_width, indexer=indexer) if rotation_not_90: for thetha_i in thetha: img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/'+im), - cv2.imread(dir_seg + '/'+img_name + '.png'), thetha_i) + cv2.imread(dir_of_label_file), thetha_i) indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, img_max_rotated, label_max_rotated, @@ -601,24 +607,24 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow for f_i in flip_index: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, cv2.flip(cv2.imread(dir_img + '/' + im), f_i), - cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), 
f_i), + cv2.flip(cv2.imread(dir_of_label_file), f_i), input_height, input_width, indexer=indexer) if blur_aug: for blur_i in blur_k: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, bluring(cv2.imread(dir_img + '/' + im), blur_i), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer) if padding_black: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, do_padding_black(cv2.imread(dir_img + '/' + im)), - do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')), + do_padding_label(cv2.imread(dir_of_label_file)), input_height, input_width, indexer=indexer) if padding_white: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, do_padding_white(cv2.imread(dir_img + '/'+im)), - do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')), + do_padding_label(cv2.imread(dir_of_label_file)), input_height, input_width, indexer=indexer) if brightening: @@ -626,7 +632,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow try: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, do_brightening(dir_img + '/' +im, factor), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer) except: pass @@ -634,20 +640,20 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow for sc_ind in scales: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, cv2.imread(dir_img + '/' + im) , - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer, scaler=sc_ind) if degrading: for degrade_scale_ind in degrade_scales: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, 
input_width, indexer=indexer) if binarization: indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, otsu_copy(cv2.imread(dir_img + '/' + im)), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer) if scaling_brightness: @@ -657,7 +663,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, do_brightening(dir_img + '/' + im, factor) - ,cv2.imread(dir_seg + '/' + img_name + '.png') + ,cv2.imread(dir_of_label_file) ,input_height, input_width, indexer=indexer, scaler=sc_ind) except: pass @@ -667,14 +673,14 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow for blur_i in blur_k: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, bluring(cv2.imread(dir_img + '/' + im), blur_i), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer, scaler=sc_ind) if scaling_binarization: for sc_ind in scales: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, otsu_copy(cv2.imread(dir_img + '/' + im)), - cv2.imread(dir_seg + '/' + img_name + '.png'), + cv2.imread(dir_of_label_file), input_height, input_width, indexer=indexer, scaler=sc_ind) if scaling_flip: @@ -682,5 +688,5 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow for f_i in flip_index: indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, cv2.flip( cv2.imread(dir_img + '/' + im), f_i), - cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), + cv2.flip(cv2.imread(dir_of_label_file), f_i), input_height, input_width, indexer=indexer, scaler=sc_ind)