training.models: re-use transformer builder code

This commit is contained in:
Robert Sachunsky 2026-02-17 17:35:20 +01:00
parent daa084c367
commit 9b66867c21

View file

@ -285,6 +285,41 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
return unet_decoder(img_input, *features, n_classes, light=False, task=task, weight_decay=weight_decay)
def transformer_block(img,
                      num_patches,
                      patchsize_x,
                      patchsize_y,
                      mlp_head_units,
                      n_layers,
                      num_heads,
                      projection_dim):
    """Run a ViT-style transformer encoder over `img` and return a spatial map.

    The input is split into (patchsize_x, patchsize_y) patches, each patch is
    embedded into `projection_dim` dimensions, and `n_layers` pre-norm
    attention + MLP blocks (each with residual connections) are applied.
    The token sequence is finally reshaped back to the input's spatial
    extent, with the channel count reduced by the patch area so the total
    element count matches.

    NOTE(review): assumes img has static spatial dims at indices 1 and 2
    (channels-last) — confirm against the callers' Input layers.
    """
    # Patchify, then embed each patch (with positional encoding).
    tokens = PatchEncoder(num_patches, projection_dim)(
        Patches(patchsize_x, patchsize_y)(img))

    for _layer in range(n_layers):
        # Pre-norm multi-head self-attention, residual added back in.
        normed = LayerNormalization(epsilon=1e-6)(tokens)
        attended = MultiHeadAttention(num_heads=num_heads,
                                      key_dim=projection_dim,
                                      dropout=0.1)(normed, normed)
        residual = Add()([attended, normed := None] if False else [attended, tokens])
        # Pre-norm MLP, second residual connection.
        hidden = mlp(LayerNormalization(epsilon=1e-6)(residual),
                     hidden_units=mlp_head_units,
                     dropout_rate=0.1)
        tokens = Add()([hidden, residual])

    # Fold tokens back into a (batch, H, W, C) feature map; C shrinks by the
    # patch area so the reshape conserves the element count.
    return tf.reshape(tokens,
                      [-1,
                       img.shape[1],
                       img.shape[2],
                       projection_dim // (patchsize_x * patchsize_y)])
def vit_resnet50_unet(num_patches,
n_classes,
transformer_patchsize_x,
@ -304,33 +339,14 @@ def vit_resnet50_unet(num_patches,
features = resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining)
patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(features[-1])
# Encode patches.
encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = MultiHeadAttention(
num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches,
[-1,
features[-1].shape[1],
features[-1].shape[2],
transformer_projection_dim // (transformer_patchsize_x *
transformer_patchsize_y)])
features[-1] = encoded_patches
features[-1] = transformer_block(features[-1],
num_patches,
transformer_patchsize_x,
transformer_patchsize_y,
transformer_mlp_head_units,
transformer_layers,
transformer_num_heads,
transformer_projection_dim)
o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
@ -353,33 +369,14 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
transformer_mlp_head_units = [128, 64]
inputs = Input(shape=(input_height, input_width, 3))
patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = MultiHeadAttention(
num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches,
[-1,
input_height,
input_width,
transformer_projection_dim // (transformer_patchsize_x *
transformer_patchsize_y)])
encoded_patches = transformer_block(inputs,
num_patches,
transformer_patchsize_x,
transformer_patchsize_y,
transformer_mlp_head_units,
transformer_layers,
transformer_num_heads,
transformer_projection_dim)
encoded_patches = Conv2D(3, (1, 1), padding='same',
data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay),
name='convinput')(encoded_patches)