mutable defaults are the source of all evil

This commit is contained in:
cneud 2025-10-01 00:20:18 +02:00
parent 1d0616eb69
commit 70af00182b

View file

@ -394,7 +394,9 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
return model return model
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
if mlp_head_units is None:
mlp_head_units = [128, 64]
inputs = layers.Input(shape=(input_height, input_width, 3)) inputs = layers.Input(shape=(input_height, input_width, 3))
#transformer_units = [ #transformer_units = [
@ -516,7 +518,9 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
return model return model
def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
if mlp_head_units is None:
mlp_head_units = [128, 64]
inputs = layers.Input(shape=(input_height, input_width, 3)) inputs = layers.Input(shape=(input_height, input_width, 3))
##transformer_units = [ ##transformer_units = [