adding enhancement training

This commit is contained in:
vahidrezanezhad 2024-05-06 18:31:48 +02:00
parent dbb84507ed
commit 38db3e9289
5 changed files with 119 additions and 68 deletions

View file

@ -1,5 +1,6 @@
import os
import sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.compat.v1.keras.backend import set_session
import warnings
@ -91,7 +92,7 @@ def run(_config, n_classes, n_epochs, input_height,
num_patches_xy, model_name, flip_index, dir_eval, dir_output,
pretraining, learning_rate, task, f1_threshold_classification):
if task == "segmentation":
if task == "segmentation" or "enhancement":
num_patches = num_patches_xy[0]*num_patches_xy[1]
if data_is_provided:
@ -153,7 +154,7 @@ def run(_config, n_classes, n_epochs, input_height,
blur_aug, padding_white, padding_black, flip_aug, binarization,
scaling, degrading, brightening, scales, degrade_scales, brightness,
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation,
rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation,
patches=patches)
provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,
@ -161,7 +162,7 @@ def run(_config, n_classes, n_epochs, input_height,
blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization,
scaling, degrading, brightening, scales, degrade_scales, brightness,
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches)
rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches)
if weighted_loss:
weights = np.zeros(n_classes)
@ -191,45 +192,49 @@ def run(_config, n_classes, n_epochs, input_height,
if continue_training:
if model_name=='resnet50_unet':
if is_loss_soft_dice:
if is_loss_soft_dice and task == "segmentation":
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
if weighted_loss:
if weighted_loss and task == "segmentation":
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss:
model = load_model(dir_of_start_model , compile=True)
elif model_name=='hybrid_transformer_cnn':
if is_loss_soft_dice:
if is_loss_soft_dice and task == "segmentation":
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
if weighted_loss:
if weighted_loss and task == "segmentation":
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss:
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
else:
index_start = 0
if model_name=='resnet50_unet':
model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)
model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
elif model_name=='hybrid_transformer_cnn':
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining)
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
#if you want to see the model structure just uncomment model summary.
#model.summary()
if not is_loss_soft_dice and not weighted_loss:
model.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
if is_loss_soft_dice:
model.compile(loss=soft_dice_loss,
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
if weighted_loss:
model.compile(loss=weighted_categorical_crossentropy(weights),
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
if task == "segmentation":
if not is_loss_soft_dice and not weighted_loss:
model.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
if is_loss_soft_dice:
model.compile(loss=soft_dice_loss,
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
if weighted_loss:
model.compile(loss=weighted_categorical_crossentropy(weights),
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
elif task == "enhancement":
model.compile(loss='mean_squared_error',
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
# generating train and evaluation data
train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
input_height=input_height, input_width=input_width, n_classes=n_classes)
input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
input_height=input_height, input_width=input_width, n_classes=n_classes)
input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
##score_best=[]