binarization as a separate task of segmentation

This commit is contained in:
vahidrezanezhad 2024-06-11 17:48:30 +02:00
parent 41a0e15e79
commit 2aa216e388
2 changed files with 9 additions and 8 deletions

View file

@ -96,7 +96,7 @@ def run(_config, n_classes, n_epochs, input_height,
transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
if task == "segmentation" or task == "enhancement":
if task == "segmentation" or task == "enhancement" or task == "binarization":
if data_is_provided:
dir_train_flowing = os.path.join(dir_output, 'train')
dir_eval_flowing = os.path.join(dir_output, 'eval')
@ -194,16 +194,16 @@ def run(_config, n_classes, n_epochs, input_height,
if continue_training:
if backbone_type=='nontransformer':
if is_loss_soft_dice and task == "segmentation":
if is_loss_soft_dice and (task == "segmentation" or task == "binarization"):
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
if weighted_loss and task == "segmentation":
if weighted_loss and (task == "segmentation" or task == "binarization"):
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss:
model = load_model(dir_of_start_model , compile=True)
elif backbone_type=='transformer':
if is_loss_soft_dice and task == "segmentation":
if is_loss_soft_dice and (task == "segmentation" or task == "binarization"):
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
if weighted_loss and task == "segmentation":
if weighted_loss and (task == "segmentation" or task == "binarization"):
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss:
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
@ -224,8 +224,9 @@ def run(_config, n_classes, n_epochs, input_height,
#if you want to see the model structure just uncomment model summary.
#model.summary()
if task == "segmentation":
if (task == "segmentation" or task == "binarization"):
if not is_loss_soft_dice and not weighted_loss:
model.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])