diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 168884a..7ede551 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -290,30 +290,20 @@ def run(_config, weights = weights / float(np.min(weights)) weights = weights / float(np.sum(weights)) + if task == "enhancement": + assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply" + assert not weighted_dice, "for enhancement, weighted loss does not apply" if continue_training: - if backbone_type == 'nontransformer': - if is_loss_soft_dice and task in ["segmentation", "binarization"]: - model = load_model(dir_of_start_model, compile=True, - custom_objects={'soft_dice_loss': soft_dice_loss}) - elif weighted_loss and task in ["segmentation", "binarization"]: - model = load_model(dir_of_start_model, compile=True, - custom_objects={'loss': weighted_categorical_crossentropy(weights)}) - else: - model = load_model(dir_of_start_model , compile=True) - - elif backbone_type == 'transformer': - if is_loss_soft_dice and task in ["segmentation", "binarization"]: - model = load_model(dir_of_start_model, compile=True, - custom_objects={"PatchEncoder": PatchEncoder, - "Patches": Patches, - 'soft_dice_loss': soft_dice_loss}) - elif weighted_loss and task in ["segmentation", "binarization"]: - model = load_model(dir_of_start_model, compile=True, - custom_objects={'loss': weighted_categorical_crossentropy(weights)}) - else: - model = load_model(dir_of_start_model, compile=True, - custom_objects = {"PatchEncoder": PatchEncoder, - "Patches": Patches}) + custom_objects = dict() + if is_loss_soft_dice: + custom_objects.update(soft_dice_loss=soft_dice_loss) + elif weighted_loss: + custom_objects.update(loss=weighted_categorical_crossentropy(weights)) + if backbone_type == 'transformer': + custom_objects.update(PatchEncoder=PatchEncoder, + Patches=Patches) + model = load_model(dir_of_start_model, compile=False, + custom_objects=custom_objects) else: index_start = 0 if backbone_type == 'nontransformer':