From 70af00182b2332f33e7872b6abc1af9bbba787bc Mon Sep 17 00:00:00 2001 From: cneud <952378+cneud@users.noreply.github.com> Date: Wed, 1 Oct 2025 00:20:18 +0200 Subject: [PATCH] mutable defaults are the source of all evil --- train/models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train/models.py b/train/models.py index 8841bd3..fdc5437 100644 --- a/train/models.py +++ b/train/models.py @@ -394,7 +394,9 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati return model -def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): +def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): + if mlp_head_units is None: + mlp_head_units = [128, 64] inputs = layers.Input(shape=(input_height, input_width, 3)) #transformer_units = [ @@ -516,7 +518,9 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he return model -def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): +def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): + if mlp_head_units is None: + mlp_head_units = [128, 64] inputs = layers.Input(shape=(input_height, input_width, 3)) ##transformer_units = [