|
|
|
@ -96,7 +96,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|
|
|
|
transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
|
|
|
|
|
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
|
|
|
|
|
|
|
|
|
|
if task == "segmentation" or task == "enhancement":
|
|
|
|
|
if task == "segmentation" or task == "enhancement" or task == "binarization":
|
|
|
|
|
if data_is_provided:
|
|
|
|
|
dir_train_flowing = os.path.join(dir_output, 'train')
|
|
|
|
|
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
|
|
|
@ -194,16 +194,16 @@ def run(_config, n_classes, n_epochs, input_height,
|
|
|
|
|
|
|
|
|
|
if continue_training:
|
|
|
|
|
if backbone_type=='nontransformer':
|
|
|
|
|
if is_loss_soft_dice and task == "segmentation":
|
|
|
|
|
if is_loss_soft_dice and (task == "segmentation" or task == "binarization"):
|
|
|
|
|
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
|
|
|
|
if weighted_loss and task == "segmentation":
|
|
|
|
|
if weighted_loss and (task == "segmentation" or task == "binarization"):
|
|
|
|
|
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
|
|
|
|
if not is_loss_soft_dice and not weighted_loss:
|
|
|
|
|
model = load_model(dir_of_start_model , compile=True)
|
|
|
|
|
elif backbone_type=='transformer':
|
|
|
|
|
if is_loss_soft_dice and task == "segmentation":
|
|
|
|
|
if is_loss_soft_dice and (task == "segmentation" or task == "binarization"):
|
|
|
|
|
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
|
|
|
|
|
if weighted_loss and task == "segmentation":
|
|
|
|
|
if weighted_loss and (task == "segmentation" or task == "binarization"):
|
|
|
|
|
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
|
|
|
|
if not is_loss_soft_dice and not weighted_loss:
|
|
|
|
|
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
|
|
|
@ -225,7 +225,8 @@ def run(_config, n_classes, n_epochs, input_height,
|
|
|
|
|
#if you want to see the model structure just uncomment model summary.
|
|
|
|
|
#model.summary()
|
|
|
|
|
|
|
|
|
|
if task == "segmentation":
|
|
|
|
|
|
|
|
|
|
if (task == "segmentation" or task == "binarization"):
|
|
|
|
|
if not is_loss_soft_dice and not weighted_loss:
|
|
|
|
|
model.compile(loss='categorical_crossentropy',
|
|
|
|
|
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
|
|
|
|