Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git (synced 2025-06-09 20:00:05 +02:00)
adding enhancement training
parent dbb84507ed · commit 38db3e9289
5 changed files with 119 additions and 68 deletions
train.py (47 changes; only this file's diff is shown below)
--- a/train.py
+++ b/train.py
@@ -1,5 +1,6 @@
 import os
+import sys
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 import tensorflow as tf
 from tensorflow.compat.v1.keras.backend import set_session
 import warnings
@@ -91,7 +92,7 @@ def run(_config, n_classes, n_epochs, input_height,
         num_patches_xy, model_name, flip_index, dir_eval, dir_output,
         pretraining, learning_rate, task, f1_threshold_classification):
 
-    if task == "segmentation":
+    if task == "segmentation" or "enhancement":
 
         num_patches = num_patches_xy[0]*num_patches_xy[1]
         if data_is_provided:
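A note on this hunk: in Python, `task == "segmentation" or "enhancement"` parses as `(task == "segmentation") or ("enhancement")`, and the non-empty string `"enhancement"` is always truthy, so the new condition is taken for every task value. Below is a minimal runnable sketch of the difference; the membership test is our assumption about the intended check, not part of the commit.

    def patch_extraction_enabled(task):
        # As committed: the right operand of `or` is the bare string
        # "enhancement", which is truthy on its own, so this expression
        # is truthy no matter what `task` holds.
        as_committed = task == "segmentation" or "enhancement"
        # Presumably intended: restrict the branch to the two tasks.
        intended = task in ("segmentation", "enhancement")
        return as_committed, intended

    print(patch_extraction_enabled("classification"))  # ('enhancement', False) -- branch wrongly taken
    print(patch_extraction_enabled("segmentation"))    # (True, True)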
@@ -153,7 +154,7 @@ def run(_config, n_classes, n_epochs, input_height,
                             blur_aug, padding_white, padding_black, flip_aug, binarization,
                             scaling, degrading, brightening, scales, degrade_scales, brightness,
                             flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
-                            rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation,
+                            rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation,
                             patches=patches)
 
             provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,
@@ -161,7 +162,7 @@ def run(_config, n_classes, n_epochs, input_height,
                             blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization,
                             scaling, degrading, brightening, scales, degrade_scales, brightness,
                             flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
-                            rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches)
+                            rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches)
 
         if weighted_loss:
             weights = np.zeros(n_classes)
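The two hunks above thread `task` into `provide_patches` for the training and validation sets. The function body is not part of this diff; the sketch below is a hypothetical helper, not the repository's code, illustrating why a patch provider needs the task: it assumes segmentation labels are class-index maps to be one-hot encoded, while enhancement labels are themselves target images.

    import numpy as np

    def prepare_label(label, task, n_classes):
        # Hypothetical helper (not from the commit): switches label
        # handling on the `task` flag inside a patch provider.
        label = np.asarray(label)
        if task == "segmentation":
            # class-index map (H, W) -> one-hot stack (H, W, n_classes)
            return np.eye(n_classes, dtype=np.float32)[label]
        if task == "enhancement":
            # the label is itself an image: a per-pixel regression target
            return label.astype(np.float32) / 255.0
        raise ValueError(f"unknown task: {task}")

    seg = prepare_label(np.zeros((4, 4), dtype=int), "segmentation", n_classes=3)
    enh = prepare_label(np.full((4, 4, 3), 128, dtype=np.uint8), "enhancement", 3)
    print(seg.shape, enh.dtype)  # (4, 4, 3) float32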
@@ -191,45 +192,49 @@ def run(_config, n_classes, n_epochs, input_height,
 
     if continue_training:
         if model_name=='resnet50_unet':
-            if is_loss_soft_dice:
+            if is_loss_soft_dice and task == "segmentation":
                 model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
-            if weighted_loss:
+            if weighted_loss and task == "segmentation":
                 model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
             if not is_loss_soft_dice and not weighted_loss:
                 model = load_model(dir_of_start_model , compile=True)
         elif model_name=='hybrid_transformer_cnn':
-            if is_loss_soft_dice:
+            if is_loss_soft_dice and task == "segmentation":
                 model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
-            if weighted_loss:
+            if weighted_loss and task == "segmentation":
                 model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
             if not is_loss_soft_dice and not weighted_loss:
                 model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
     else:
         index_start = 0
         if model_name=='resnet50_unet':
-            model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)
+            model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
         elif model_name=='hybrid_transformer_cnn':
-            model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining)
+            model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
 
     #if you want to see the model structure just uncomment model summary.
     #model.summary()
 
 
-    if not is_loss_soft_dice and not weighted_loss:
-        model.compile(loss='categorical_crossentropy',
-                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
-    if is_loss_soft_dice:
-        model.compile(loss=soft_dice_loss,
-                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
-    if weighted_loss:
-        model.compile(loss=weighted_categorical_crossentropy(weights),
-                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+    if task == "segmentation":
+        if not is_loss_soft_dice and not weighted_loss:
+            model.compile(loss='categorical_crossentropy',
+                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+        if is_loss_soft_dice:
+            model.compile(loss=soft_dice_loss,
+                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+        if weighted_loss:
+            model.compile(loss=weighted_categorical_crossentropy(weights),
+                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+    elif task == "enhancement":
+        model.compile(loss='mean_squared_error',
+                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
 
 
     # generating train and evaluation data
     train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
-                         input_height=input_height, input_width=input_width, n_classes=n_classes)
+                         input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
     val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
-                       input_height=input_height, input_width=input_width, n_classes=n_classes)
+                       input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
 
     ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
     ##score_best=[]
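The reworked compile step is the core of the change: segmentation keeps its three loss variants, while enhancement is treated as per-pixel regression against the clean target image, hence `mean_squared_error`. (The diff keeps the legacy `Adam(lr=...)` keyword; newer Keras releases spell it `learning_rate`.) Below is a minimal runnable sketch of the dispatch with a stand-in one-layer model; the soft-dice and weighted variants are omitted, and the stand-in model is our assumption, not the repository's network.

    import tensorflow as tf
    from tensorflow.keras import layers

    def compile_for_task(model, task, learning_rate):
        if task == "segmentation":
            # per-pixel classification over softmax class channels
            model.compile(loss="categorical_crossentropy",
                          optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                          metrics=["accuracy"])
        elif task == "enhancement":
            # per-pixel regression against the clean target image
            model.compile(loss="mean_squared_error",
                          optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                          metrics=["accuracy"])
        return model

    # stand-in model: a single 1x1 convolution mapping RGB to RGB
    dummy = tf.keras.Sequential([layers.Conv2D(3, 1, input_shape=(None, None, 3))])
    compile_for_task(dummy, "enhancement", learning_rate=1e-4)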
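The `continue_training` branches gain `task == "segmentation"` guards for a related reason: the custom losses exist only for segmentation runs, and Keras cannot deserialize a custom loss by name alone, so it must be supplied through `custom_objects` when reloading a saved model. A small self-contained demonstration follows; the dice formula here is a generic stand-in, not the repository's `soft_dice_loss`.

    import tensorflow as tf
    from tensorflow.keras import layers, models

    def soft_dice_loss(y_true, y_pred):
        # generic stand-in for the repository's custom loss
        intersection = tf.reduce_sum(y_true * y_pred)
        return 1.0 - (2.0 * intersection + 1.0) / (
            tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + 1.0)

    m = models.Sequential([layers.Dense(2, input_shape=(4,))])
    m.compile(loss=soft_dice_loss, optimizer="adam")
    m.save("start_model.h5")

    # Loading with compile=True fails unless the custom loss is mapped
    # back in, which is what the resnet50_unet / hybrid_transformer_cnn
    # branches above do via custom_objects.
    m2 = models.load_model("start_model.h5", compile=True,
                           custom_objects={"soft_dice_loss": soft_dice_loss})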