diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py
index 6182c9e..0dc78d2 100644
--- a/src/eynollah/training/models.py
+++ b/src/eynollah/training/models.py
@@ -372,12 +372,10 @@ def vit_resnet50_unet(num_patches,
     transformer_mlp_head_units = [128, 64]
     inputs = Input(shape=(input_height, input_width, 3))
 
-    #transformer_units = [
-    #projection_dim * 2,
-    #projection_dim,
-    #] # Size of the transformer layers
-    IMAGE_ORDERING = 'channels_last'
-    bn_axis=3
+    if IMAGE_ORDERING == 'channels_last':
+        bn_axis = 3
+    else:
+        bn_axis = 1
 
     x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
     x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
@@ -508,12 +506,10 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
     transformer_mlp_head_units = [128, 64]
     inputs = Input(shape=(input_height, input_width, 3))
 
-    ##transformer_units = [
-    ##projection_dim * 2,
-    ##projection_dim,
-    ##] # Size of the transformer layers
-    IMAGE_ORDERING = 'channels_last'
-    bn_axis=3
+    if IMAGE_ORDERING == 'channels_last':
+        bn_axis = 3
+    else:
+        bn_axis = 1
 
     patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
     # Encode patches.