diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index 4652b07..d0b24c0 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -334,7 +334,7 @@ def vit_resnet50_unet(num_patches, transformer_mlp_head_units = [128, 64] inputs = Input(shape=(input_height, input_width, 3)) - features = resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining) + features = list(resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining)) features[-1] = transformer_block(features[-1], num_patches,