training.train: simplify config args for model builder

This commit is contained in:
Robert Sachunsky 2026-02-05 11:56:11 +01:00
parent 4a65ee0c67
commit 5c7801a1d6
2 changed files with 63 additions and 37 deletions

View file

@ -400,9 +400,21 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
return model
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
if mlp_head_units is None:
mlp_head_units = [128, 64]
def vit_resnet50_unet(num_patches,
n_classes,
transformer_patchsize_x,
transformer_patchsize_y,
transformer_mlp_head_units=None,
transformer_layers=8,
transformer_num_heads=4,
transformer_projection_dim=64,
input_height=224,
input_width=224,
task="segmentation",
weight_decay=1e-6,
pretraining=False):
if transformer_mlp_head_units is None:
transformer_mlp_head_units = [128, 64]
inputs = layers.Input(shape=(input_height, input_width, 3))
#transformer_units = [
@ -449,30 +461,30 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
#num_patches = x.shape[1]*x.shape[2]
#patch_size_y = input_height / x.shape[1]
#patch_size_x = input_width / x.shape[2]
#patch_size = patch_size_x * patch_size_y
patches = Patches(patch_size_x, patch_size_y)(x)
patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(x)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
encoded_patches = tf.reshape(encoded_patches,
[-1, x.shape[1], x.shape[2],
transformer_projection_dim // (transformer_patchsize_x *
transformer_patchsize_y)])
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
@ -524,9 +536,21 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
return model
def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
if mlp_head_units is None:
mlp_head_units = [128, 64]
def vit_resnet50_unet_transformer_before_cnn(num_patches,
n_classes,
transformer_patchsize_x,
transformer_patchsize_y,
transformer_mlp_head_units=None,
transformer_layers=8,
transformer_num_heads=4,
transformer_projection_dim=64,
input_height=224,
input_width=224,
task="segmentation",
weight_decay=1e-6,
pretraining=False):
if transformer_mlp_head_units is None:
transformer_mlp_head_units = [128, 64]
inputs = layers.Input(shape=(input_height, input_width, 3))
##transformer_units = [
@ -536,27 +560,32 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size
IMAGE_ORDERING = 'channels_last'
bn_axis=3
patches = Patches(patch_size_x, patch_size_y)(inputs)
patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
encoded_patches = tf.reshape(encoded_patches,
[-1,
input_height,
input_width,
transformer_projection_dim // (transformer_patchsize_x *
transformer_patchsize_y)])
encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)

View file

@ -38,6 +38,7 @@ from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from sacred import Experiment
from sacred.config import create_captured_function
from tqdm import tqdm
from sklearn.metrics import f1_score
@ -318,7 +319,7 @@ def run(_config,
task,
weight_decay,
pretraining)
elif backbone_type == 'transformer':
else:
num_patches_x = transformer_num_patches_xy[0]
num_patches_y = transformer_num_patches_xy[1]
num_patches = num_patches_x * num_patches_y
@ -330,35 +331,31 @@ def run(_config,
model_builder = vit_resnet50_unet_transformer_before_cnn
multiple_of_32 = False
assert input_height == num_patches_y * transformer_patchsize_y * (32 if multiple_of_32 else 1), \
assert input_height == (num_patches_y *
transformer_patchsize_y *
(32 if multiple_of_32 else 1)), \
"transformer_patchsize_y or transformer_num_patches_xy height value error: " \
"input_height should be equal to " \
"(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \
" * 32" if multiple_of_32 else ""
assert input_width == num_patches_x * transformer_patchsize_x * (32 if multiple_of_32 else 1), \
assert input_width == (num_patches_x *
transformer_patchsize_x *
(32 if multiple_of_32 else 1)), \
"transformer_patchsize_x or transformer_num_patches_xy width value error: " \
"input_width should be equal to " \
"(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \
" * 32" if multiple_of_32 else ""
assert 0 == transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x), \
assert 0 == (transformer_projection_dim %
(transformer_patchsize_y *
transformer_patchsize_x)), \
"transformer_projection_dim error: " \
"The remainder when parameter transformer_projection_dim is divided by " \
"(transformer_patchsize_y*transformer_patchsize_x) should be zero"
model = model_builder(
n_classes,
transformer_patchsize_x,
transformer_patchsize_y,
num_patches,
transformer_mlp_head_units,
transformer_layers,
transformer_num_heads,
transformer_projection_dim,
input_height,
input_width,
task,
weight_decay,
pretraining)
model_builder = create_captured_function(model_builder)
model_builder.config = _config
model_builder.logger = _log
model = model_builder(num_patches)
#if you want to see the model structure just uncomment model summary.
#model.summary()