From 4414f7b89b4e1488a6955bb40342709ab05c0414 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Tue, 17 Feb 2026 14:18:32 +0100 Subject: [PATCH] training.models.vit_resnet50_unet: re-use `IMAGE_ORDERING` --- src/eynollah/training/models.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index 6182c9e..0dc78d2 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -372,12 +372,10 @@ def vit_resnet50_unet(num_patches, transformer_mlp_head_units = [128, 64] inputs = Input(shape=(input_height, input_width, 3)) - #transformer_units = [ - #projection_dim * 2, - #projection_dim, - #] # Size of the transformer layers - IMAGE_ORDERING = 'channels_last' - bn_axis=3 + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs) x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) @@ -508,12 +506,10 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches, transformer_mlp_head_units = [128, 64] inputs = Input(shape=(input_height, input_width, 3)) - ##transformer_units = [ - ##projection_dim * 2, - ##projection_dim, - ##] # Size of the transformer layers - IMAGE_ORDERING = 'channels_last' - bn_axis=3 + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs) # Encode patches.