From 5c7801a1d6273cd88b64548edf41507e5c0235d6 Mon Sep 17 00:00:00 2001
From: Robert Sachunsky
Date: Thu, 5 Feb 2026 11:56:11 +0100
Subject: [PATCH] training.train: simplify config args for model builder

---
 src/eynollah/training/models.py | 67 +++++++++++++++++++++++----------
 src/eynollah/training/train.py  | 33 ++++++++--------
 2 files changed, 63 insertions(+), 37 deletions(-)

diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py
index 011c614..f053447 100644
--- a/src/eynollah/training/models.py
+++ b/src/eynollah/training/models.py
@@ -400,9 +400,21 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
     return model
 
 
-def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
-    if mlp_head_units is None:
-        mlp_head_units = [128, 64]
+def vit_resnet50_unet(num_patches,
+                      n_classes,
+                      transformer_patchsize_x,
+                      transformer_patchsize_y,
+                      transformer_mlp_head_units=None,
+                      transformer_layers=8,
+                      transformer_num_heads=4,
+                      transformer_projection_dim=64,
+                      input_height=224,
+                      input_width=224,
+                      task="segmentation",
+                      weight_decay=1e-6,
+                      pretraining=False):
+    if transformer_mlp_head_units is None:
+        transformer_mlp_head_units = [128, 64]
     inputs = layers.Input(shape=(input_height, input_width, 3))
 
     #transformer_units = [
@@ -449,30 +461,30 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
 
 
     #num_patches = x.shape[1]*x.shape[2]
-    #patch_size_y = input_height / x.shape[1]
-    #patch_size_x = input_width / x.shape[2]
-    #patch_size = patch_size_x * patch_size_y
-    patches = Patches(patch_size_x, patch_size_y)(x)
+    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(x)
     # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
 
     for _ in range(transformer_layers):
         # Layer normalization 1.
         x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
         # Create a multi-head attention layer.
         attention_output = layers.MultiHeadAttention(
-            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
         )(x1, x1)
         # Skip connection 1.
         x2 = layers.Add()([attention_output, encoded_patches])
         # Layer normalization 2.
         x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
         # MLP.
-        x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
+        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
         # Skip connection 2.
         encoded_patches = layers.Add()([x3, x2])
 
-    encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
+    encoded_patches = tf.reshape(encoded_patches,
+                                 [-1, x.shape[1], x.shape[2],
+                                  transformer_projection_dim // (transformer_patchsize_x *
+                                                                 transformer_patchsize_y)])
 
     v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
     v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
@@ -524,9 +536,21 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
     return model
 
 
-def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
-    if mlp_head_units is None:
-        mlp_head_units = [128, 64]
+def vit_resnet50_unet_transformer_before_cnn(num_patches,
+                                             n_classes,
+                                             transformer_patchsize_x,
+                                             transformer_patchsize_y,
+                                             transformer_mlp_head_units=None,
+                                             transformer_layers=8,
+                                             transformer_num_heads=4,
+                                             transformer_projection_dim=64,
+                                             input_height=224,
+                                             input_width=224,
+                                             task="segmentation",
+                                             weight_decay=1e-6,
+                                             pretraining=False):
+    if transformer_mlp_head_units is None:
+        transformer_mlp_head_units = [128, 64]
     inputs = layers.Input(shape=(input_height, input_width, 3))
 
     ##transformer_units = [
@@ -536,27 +560,32 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size
     IMAGE_ORDERING = 'channels_last'
     bn_axis=3
 
-    patches = Patches(patch_size_x, patch_size_y)(inputs)
+    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
     # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
 
     for _ in range(transformer_layers):
         # Layer normalization 1.
         x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
         # Create a multi-head attention layer.
         attention_output = layers.MultiHeadAttention(
-            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
        )(x1, x1)
         # Skip connection 1.
         x2 = layers.Add()([attention_output, encoded_patches])
         # Layer normalization 2.
         x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
         # MLP.
-        x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
+        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
         # Skip connection 2.
         encoded_patches = layers.Add()([x3, x2])
 
-    encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
+    encoded_patches = tf.reshape(encoded_patches,
+                                 [-1,
+                                  input_height,
+                                  input_width,
+                                  transformer_projection_dim // (transformer_patchsize_x *
+                                                                 transformer_patchsize_y)])
 
     encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)
 
diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index a21a34d..4aafcf2 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -38,6 +38,7 @@ from tensorflow.keras.metrics import MeanIoU
 from tensorflow.keras.models import load_model
 from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
 from sacred import Experiment
+from sacred.config import create_captured_function
 from tqdm import tqdm
 from sklearn.metrics import f1_score
 
@@ -318,7 +319,7 @@ def run(_config,
                               task,
                               weight_decay,
                               pretraining)
-    elif backbone_type == 'transformer':
+    else:
         num_patches_x = transformer_num_patches_xy[0]
         num_patches_y = transformer_num_patches_xy[1]
         num_patches = num_patches_x * num_patches_y
@@ -330,35 +331,31 @@ def run(_config,
             model_builder = vit_resnet50_unet_transformer_before_cnn
             multiple_of_32 = False
 
-        assert input_height == num_patches_y * transformer_patchsize_y * (32 if multiple_of_32 else 1), \
+        assert input_height == (num_patches_y *
+                                transformer_patchsize_y *
+                                (32 if multiple_of_32 else 1)), \
             "transformer_patchsize_y or transformer_num_patches_xy height value error: " \
             "input_height should be equal to " \
             "(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \
             " * 32" if multiple_of_32 else ""
-        assert input_width == num_patches_x * transformer_patchsize_x * (32 if multiple_of_32 else 1), \
+        assert input_width == (num_patches_x *
+                               transformer_patchsize_x *
+                               (32 if multiple_of_32 else 1)), \
             "transformer_patchsize_x or transformer_num_patches_xy width value error: " \
             "input_width should be equal to " \
            "(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \
             " * 32" if multiple_of_32 else ""
-        assert 0 == transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x), \
+        assert 0 == (transformer_projection_dim %
+                     (transformer_patchsize_y *
+                      transformer_patchsize_x)), \
             "transformer_projection_dim error: " \
             "The remainder when parameter transformer_projection_dim is divided by " \
            "(transformer_patchsize_y*transformer_patchsize_x) should be zero"
 
-        model = model_builder(
-            n_classes,
-            transformer_patchsize_x,
-            transformer_patchsize_y,
-            num_patches,
-            transformer_mlp_head_units,
-            transformer_layers,
-            transformer_num_heads,
-            transformer_projection_dim,
-            input_height,
-            input_width,
-            task,
-            weight_decay,
-            pretraining)
+        model_builder = create_captured_function(model_builder)
+        model_builder.config = _config
+        model_builder.logger = _log
+        model = model_builder(num_patches)
 
     #if you want to see the model structure just uncomment model summary.
     #model.summary()
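
Note (illustration, not part of the patch): the new call site works because sacred's
create_captured_function resolves every builder argument that is not passed explicitly
by looking it up by name in the attached config dict; renaming the builder parameters
in models.py to the config keys (transformer_patchsize_x, transformer_projection_dim,
and so on) is what makes that lookup possible. A minimal, self-contained sketch with a
toy builder and made-up config values (only the parameter names mirror the real ones):

# Sketch only: a stand-in builder whose keyword arguments are filled from the
# attached config by name, mirroring how train.py now calls the real builders.
import logging

from sacred.config import create_captured_function


def build_model(num_patches,                 # passed explicitly, like in train.py
                n_classes,                   # the rest is resolved from .config
                transformer_patchsize_x,
                transformer_patchsize_y,
                transformer_projection_dim=64,
                input_height=224,
                input_width=224):
    return (num_patches, n_classes, transformer_patchsize_x,
            transformer_patchsize_y, transformer_projection_dim,
            input_height, input_width)


builder = create_captured_function(build_model)
builder.config = {                           # stands in for the sacred _config dict
    "n_classes": 4,
    "transformer_patchsize_x": 1,
    "transformer_patchsize_y": 1,
    "transformer_projection_dim": 64,
    "input_height": 448,
    "input_width": 448,
}
builder.logger = logging.getLogger("model_builder")

print(builder(14 * 14))
# expected: (196, 4, 1, 1, 64, 448, 448)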
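
Also illustrative only: the reformatted assertions encode the geometry constraint
between input size, patch grid and patch size. Made-up numbers (not taken from the
patch) that satisfy them for both builder variants:

# CNN-first (vit_resnet50_unet): the ResNet encoder downsamples by 32 before
# patch extraction, so multiple_of_32 is True.
input_height = input_width = 448
num_patches_y = num_patches_x = 14           # transformer_num_patches_xy
transformer_patchsize_y = transformer_patchsize_x = 1
transformer_projection_dim = 64

assert input_height == num_patches_y * transformer_patchsize_y * 32
assert input_width == num_patches_x * transformer_patchsize_x * 32
assert transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x) == 0

# Transformer-first (vit_resnet50_unet_transformer_before_cnn): patches are taken
# from the full-resolution input, so multiple_of_32 is False and the patch grid
# must cover the image directly.
input_height = input_width = 448
num_patches_y = num_patches_x = 28
transformer_patchsize_y = transformer_patchsize_x = 16
transformer_projection_dim = 256

assert input_height == num_patches_y * transformer_patchsize_y
assert input_width == num_patches_x * transformer_patchsize_x
assert transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x) == 0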