training.models.vit_resnet50_unet: re-use IMAGE_ORDERING

Robert Sachunsky 2026-02-17 14:18:32 +01:00
parent 7888fa5968
commit 4414f7b89b


@@ -372,12 +372,10 @@ def vit_resnet50_unet(num_patches,
     transformer_mlp_head_units = [128, 64]
     inputs = Input(shape=(input_height, input_width, 3))
-    #transformer_units = [
-    #projection_dim * 2,
-    #projection_dim,
-    #] # Size of the transformer layers
-    IMAGE_ORDERING = 'channels_last'
-    bn_axis=3
+    if IMAGE_ORDERING == 'channels_last':
+        bn_axis = 3
+    else:
+        bn_axis = 1
     x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
     x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
@@ -508,12 +506,10 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
     transformer_mlp_head_units = [128, 64]
     inputs = Input(shape=(input_height, input_width, 3))
-    ##transformer_units = [
-    ##projection_dim * 2,
-    ##projection_dim,
-    ##] # Size of the transformer layers
-    IMAGE_ORDERING = 'channels_last'
-    bn_axis=3
+    if IMAGE_ORDERING == 'channels_last':
+        bn_axis = 3
+    else:
+        bn_axis = 1
     patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
     # Encode patches.
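
Both hunks apply the same fix: instead of shadowing the module-level IMAGE_ORDERING with a local 'channels_last' and hard-coding bn_axis=3 (and carrying along commented-out transformer_units lines), the batch-normalization axis is now derived from IMAGE_ORDERING. Below is a minimal sketch of that pattern outside the diff context, assuming IMAGE_ORDERING is a module-level constant (here taken from the Keras backend's image_data_format() for illustration); the input size and layer names are illustrative, not the repository's code.

# Minimal sketch, not the repository code: IMAGE_ORDERING is assumed to be a
# module-level constant, derived here from the Keras backend for illustration.
from tensorflow import keras
from tensorflow.keras.layers import BatchNormalization, Conv2D, Input, ZeroPadding2D

IMAGE_ORDERING = keras.backend.image_data_format()  # 'channels_last' or 'channels_first'

# Choose the BatchNormalization axis from the data format instead of hard-coding 3.
if IMAGE_ORDERING == 'channels_last':
    bn_axis = 3  # NHWC: channels are axis 3
else:
    bn_axis = 1  # NCHW: channels are axis 1

# Illustrative input shape, consistent with the chosen data format.
input_shape = (224, 224, 3) if IMAGE_ORDERING == 'channels_last' else (3, 224, 224)
inputs = Input(shape=input_shape)
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), name='conv1')(x)
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)  # axis follows IMAGE_ORDERING

Deriving bn_axis this way keeps the BatchNormalization layers consistent with the data_format passed to the convolutional layers, so the models remain correct if IMAGE_ORDERING is ever switched to 'channels_first'.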