diff --git a/train/models.py b/train/models.py index 8841bd3..fdc5437 100644 --- a/train/models.py +++ b/train/models.py @@ -394,7 +394,9 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati return model -def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): +def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): + if mlp_head_units is None: + mlp_head_units = [128, 64] inputs = layers.Input(shape=(input_height, input_width, 3)) #transformer_units = [ @@ -516,7 +518,9 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he return model -def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): +def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): + if mlp_head_units is None: + mlp_head_units = [128, 64] inputs = layers.Input(shape=(input_height, input_width, 3)) ##transformer_units = [