diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index d0b24c0..13a35a1 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -345,9 +345,7 @@ def vit_resnet50_unet(num_patches, transformer_num_heads, transformer_projection_dim) - o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay) - - return Model(inputs, o) + return unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay) def vit_resnet50_unet_transformer_before_cnn(num_patches, n_classes, @@ -380,9 +378,7 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches, features = resnet50(encoded_patches, weight_decay=weight_decay, pretraining=pretraining) - o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay) - - return Model(inputs, o) + return unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay) def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): include_top=True