training: fix+simplify load_model logic for continue_training

- add missing combination: `transformer` backbone with patch encoder and
  `weighted_loss`
- add assertions to prevent a wrong loss type from being configured
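
For orientation, a minimal sketch of the consolidated loading pattern (not the committed code itself): `custom_objects` is assembled once from the configured loss and backbone, then passed to Keras `load_model` with `compile=False`. The helper name `load_start_model` is hypothetical; `soft_dice_loss`, `weighted_categorical_crossentropy`, `PatchEncoder`, and `Patches` are assumed to be the training module's own definitions, as in the diff below.

```python
from tensorflow.keras.models import load_model

def load_start_model(path, backbone_type, is_loss_soft_dice, weighted_loss, weights):
    # hypothetical helper mirroring the simplified branch in the diff below
    custom_objects = dict()
    if is_loss_soft_dice:
        # custom loss referenced by name in the saved model config
        custom_objects.update(soft_dice_loss=soft_dice_loss)
    elif weighted_loss:
        custom_objects.update(loss=weighted_categorical_crossentropy(weights))
    if backbone_type == 'transformer':
        # transformer checkpoints also contain custom layers that Keras
        # cannot deserialize unless they are supplied explicitly
        custom_objects.update(PatchEncoder=PatchEncoder, Patches=Patches)
    # compile=False defers optimizer/loss setup to the later compile step,
    # so only deserialization-relevant objects need to be passed here
    return load_model(path, compile=False, custom_objects=custom_objects)
```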
Robert Sachunsky 2026-02-04 17:35:38 +01:00
parent 1581094141
commit 7562317da5


@@ -290,30 +290,20 @@ def run(_config,
weights = weights / float(np.min(weights))
weights = weights / float(np.sum(weights))
if task == "enhancement":
assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply"
assert not weighted_loss, "for enhancement, weighted loss does not apply"
if continue_training:
if backbone_type == 'nontransformer':
if is_loss_soft_dice and task in ["segmentation", "binarization"]:
model = load_model(dir_of_start_model, compile=True,
custom_objects={'soft_dice_loss': soft_dice_loss})
elif weighted_loss and task in ["segmentation", "binarization"]:
model = load_model(dir_of_start_model, compile=True,
custom_objects={'loss': weighted_categorical_crossentropy(weights)})
else:
model = load_model(dir_of_start_model , compile=True)
elif backbone_type == 'transformer':
if is_loss_soft_dice and task in ["segmentation", "binarization"]:
model = load_model(dir_of_start_model, compile=True,
custom_objects={"PatchEncoder": PatchEncoder,
"Patches": Patches,
'soft_dice_loss': soft_dice_loss})
elif weighted_loss and task in ["segmentation", "binarization"]:
model = load_model(dir_of_start_model, compile=True,
custom_objects={'loss': weighted_categorical_crossentropy(weights)})
else:
model = load_model(dir_of_start_model, compile=True,
custom_objects = {"PatchEncoder": PatchEncoder,
"Patches": Patches})
custom_objects = dict()
if is_loss_soft_dice:
custom_objects.update(soft_dice_loss=soft_dice_loss)
elif weighted_loss:
custom_objects.update(loss=weighted_categorical_crossentropy(weights))
if backbone_type == 'transformer':
custom_objects.update(PatchEncoder=PatchEncoder,
Patches=Patches)
model = load_model(dir_of_start_model, compile=False,
custom_objects=custom_objects)
else:
index_start = 0
if backbone_type == 'nontransformer':