mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training: fix+simplify load_model logic for continue_training
- add missing combination `transformer` (w/ patch encoder and `weighted_loss`) - add assertion to prevent wrong loss type being configured
This commit is contained in:
parent
1581094141
commit
7562317da5
1 changed files with 13 additions and 23 deletions
|
|
@ -290,30 +290,20 @@ def run(_config,
|
|||
weights = weights / float(np.min(weights))
|
||||
weights = weights / float(np.sum(weights))
|
||||
|
||||
if task == "enhancement":
|
||||
assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply"
|
||||
assert not weighted_dice, "for enhancement, weighted loss does not apply"
|
||||
if continue_training:
|
||||
if backbone_type == 'nontransformer':
|
||||
if is_loss_soft_dice and task in ["segmentation", "binarization"]:
|
||||
model = load_model(dir_of_start_model, compile=True,
|
||||
custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||
elif weighted_loss and task in ["segmentation", "binarization"]:
|
||||
model = load_model(dir_of_start_model, compile=True,
|
||||
custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
else:
|
||||
model = load_model(dir_of_start_model , compile=True)
|
||||
|
||||
elif backbone_type == 'transformer':
|
||||
if is_loss_soft_dice and task in ["segmentation", "binarization"]:
|
||||
model = load_model(dir_of_start_model, compile=True,
|
||||
custom_objects={"PatchEncoder": PatchEncoder,
|
||||
"Patches": Patches,
|
||||
'soft_dice_loss': soft_dice_loss})
|
||||
elif weighted_loss and task in ["segmentation", "binarization"]:
|
||||
model = load_model(dir_of_start_model, compile=True,
|
||||
custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
else:
|
||||
model = load_model(dir_of_start_model, compile=True,
|
||||
custom_objects = {"PatchEncoder": PatchEncoder,
|
||||
"Patches": Patches})
|
||||
custom_objects = dict()
|
||||
if is_loss_soft_dice:
|
||||
custom_objects.update(soft_dice_loss=soft_dice_loss)
|
||||
elif weighted_loss:
|
||||
custom_objects.update(loss=weighted_categorical_crossentropy(weights))
|
||||
if backbone_type == 'transformer':
|
||||
custom_objects.update(PatchEncoder=PatchEncoder,
|
||||
Patches=Patches)
|
||||
model = load_model(dir_of_start_model, compile=False,
|
||||
custom_objects=custom_objects)
|
||||
else:
|
||||
index_start = 0
|
||||
if backbone_type == 'nontransformer':
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue