mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training: fix+simplify load_model logic for continue_training
- add missing combination `transformer` (w/ patch encoder and `weighted_loss`) - add assertion to prevent wrong loss type being configured
This commit is contained in:
parent
1581094141
commit
7562317da5
1 changed files with 13 additions and 23 deletions
|
|
@ -290,30 +290,20 @@ def run(_config,
|
||||||
weights = weights / float(np.min(weights))
|
weights = weights / float(np.min(weights))
|
||||||
weights = weights / float(np.sum(weights))
|
weights = weights / float(np.sum(weights))
|
||||||
|
|
||||||
|
if task == "enhancement":
|
||||||
|
assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply"
|
||||||
|
assert not weighted_dice, "for enhancement, weighted loss does not apply"
|
||||||
if continue_training:
|
if continue_training:
|
||||||
if backbone_type == 'nontransformer':
|
custom_objects = dict()
|
||||||
if is_loss_soft_dice and task in ["segmentation", "binarization"]:
|
if is_loss_soft_dice:
|
||||||
model = load_model(dir_of_start_model, compile=True,
|
custom_objects.update(soft_dice_loss=soft_dice_loss)
|
||||||
custom_objects={'soft_dice_loss': soft_dice_loss})
|
elif weighted_loss:
|
||||||
elif weighted_loss and task in ["segmentation", "binarization"]:
|
custom_objects.update(loss=weighted_categorical_crossentropy(weights))
|
||||||
model = load_model(dir_of_start_model, compile=True,
|
if backbone_type == 'transformer':
|
||||||
custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
custom_objects.update(PatchEncoder=PatchEncoder,
|
||||||
else:
|
Patches=Patches)
|
||||||
model = load_model(dir_of_start_model , compile=True)
|
model = load_model(dir_of_start_model, compile=False,
|
||||||
|
custom_objects=custom_objects)
|
||||||
elif backbone_type == 'transformer':
|
|
||||||
if is_loss_soft_dice and task in ["segmentation", "binarization"]:
|
|
||||||
model = load_model(dir_of_start_model, compile=True,
|
|
||||||
custom_objects={"PatchEncoder": PatchEncoder,
|
|
||||||
"Patches": Patches,
|
|
||||||
'soft_dice_loss': soft_dice_loss})
|
|
||||||
elif weighted_loss and task in ["segmentation", "binarization"]:
|
|
||||||
model = load_model(dir_of_start_model, compile=True,
|
|
||||||
custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
|
||||||
else:
|
|
||||||
model = load_model(dir_of_start_model, compile=True,
|
|
||||||
custom_objects = {"PatchEncoder": PatchEncoder,
|
|
||||||
"Patches": Patches})
|
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
if backbone_type == 'nontransformer':
|
if backbone_type == 'nontransformer':
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue