Mirror of https://github.com/qurator-spk/eynollah.git
training.models: re-use transformer builder code
This commit is contained in:
parent daa084c367
commit 9b66867c21
1 changed file with 53 additions and 56 deletions
@@ -285,6 +285,41 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
     return unet_decoder(img_input, *features, n_classes, light=False, task=task, weight_decay=weight_decay)
 
 
+def transformer_block(img,
+                      num_patches,
+                      patchsize_x,
+                      patchsize_y,
+                      mlp_head_units,
+                      n_layers,
+                      num_heads,
+                      projection_dim):
+    patches = Patches(patchsize_x, patchsize_y)(img)
+    # Encode patches.
+    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+
+    for _ in range(n_layers):
+        # Layer normalization 1.
+        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
+        # Create a multi-head attention layer.
+        attention_output = MultiHeadAttention(num_heads=num_heads,
+                                              key_dim=projection_dim,
+                                              dropout=0.1)(x1, x1)
+        # Skip connection 1.
+        x2 = Add()([attention_output, encoded_patches])
+        # Layer normalization 2.
+        x3 = LayerNormalization(epsilon=1e-6)(x2)
+        # MLP.
+        x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
+        # Skip connection 2.
+        encoded_patches = Add()([x3, x2])
+
+    encoded_patches = tf.reshape(encoded_patches,
+                                 [-1,
+                                  img.shape[1],
+                                  img.shape[2],
+                                  projection_dim // (patchsize_x * patchsize_y)])
+    return encoded_patches
+
+
 def vit_resnet50_unet(num_patches,
                       n_classes,
                       transformer_patchsize_x,
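For context, the loop that the new transformer_block helper factors out is the standard pre-norm ViT encoder (LayerNormalization, MultiHeadAttention, residual add, LayerNormalization, MLP, residual add). The sketch below reproduces just that loop with stock tf.keras layers. Patches, PatchEncoder and mlp are eynollah helpers that are not part of this diff, so the patch embedding is replaced by a ready-made input tensor and the MLP is written inline with an extra Dense back to projection_dim so the residual shapes match; names and sizes are illustrative only, not eynollah defaults.

# Minimal, self-contained sketch of the encoder loop wrapped by transformer_block.
# Assumption: plain Keras layers only; the project's Patches/PatchEncoder/mlp
# helpers are stubbed out.
import tensorflow as tf
from tensorflow.keras.layers import (Add, Dense, Dropout, Input,
                                     LayerNormalization, MultiHeadAttention)


def toy_transformer_encoder(x, n_layers=2, num_heads=4, projection_dim=64,
                            mlp_hidden_units=(128, 64)):
    """x: (batch, num_patches, projection_dim) tensor of encoded patches."""
    for _ in range(n_layers):
        # Pre-norm multi-head self-attention with a residual connection.
        x1 = LayerNormalization(epsilon=1e-6)(x)
        attention_output = MultiHeadAttention(num_heads=num_heads,
                                              key_dim=projection_dim,
                                              dropout=0.1)(x1, x1)
        x2 = Add()([attention_output, x])
        # Pre-norm MLP with a second residual connection.
        x3 = LayerNormalization(epsilon=1e-6)(x2)
        for units in mlp_hidden_units:
            x3 = Dense(units, activation=tf.nn.gelu)(x3)
            x3 = Dropout(0.1)(x3)
        x3 = Dense(projection_dim)(x3)  # back to projection_dim so Add() shapes match
        x = Add()([x3, x2])
    return x


# e.g. a 14x14 bottleneck feature map flattened into 196 patch embeddings of size 64
inp = Input(shape=(196, 64))
model = tf.keras.Model(inp, toy_transformer_encoder(inp))
model.summary()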
@@ -304,33 +339,14 @@ def vit_resnet50_unet(num_patches,
     features = resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining)
 
-    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(features[-1])
-    # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
-
-    for _ in range(transformer_layers):
-        # Layer normalization 1.
-        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
-        # Create a multi-head attention layer.
-        attention_output = MultiHeadAttention(
-            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
-        )(x1, x1)
-        # Skip connection 1.
-        x2 = Add()([attention_output, encoded_patches])
-        # Layer normalization 2.
-        x3 = LayerNormalization(epsilon=1e-6)(x2)
-        # MLP.
-        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
-        # Skip connection 2.
-        encoded_patches = Add()([x3, x2])
-
-    encoded_patches = tf.reshape(encoded_patches,
-                                 [-1,
-                                  features[-1].shape[1],
-                                  features[-1].shape[2],
-                                  transformer_projection_dim // (transformer_patchsize_x *
-                                                                 transformer_patchsize_y)])
-    features[-1] = encoded_patches
+    features[-1] = transformer_block(features[-1],
+                                     num_patches,
+                                     transformer_patchsize_x,
+                                     transformer_patchsize_y,
+                                     transformer_mlp_head_units,
+                                     transformer_layers,
+                                     transformer_num_heads,
+                                     transformer_projection_dim)
 
     o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
 
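One detail worth spelling out in the code this hunk now shares via transformer_block: the final tf.reshape folds the encoder output of shape (batch, num_patches, projection_dim) back into a feature map with the original spatial size, so each patch embedding is spread over its patchsize_x * patchsize_y positions and the channel count shrinks by that factor. A quick shape check with made-up numbers (a 14x14 feature map, 2x2 patches, projection_dim 64; none of these are eynollah defaults):

# Shape bookkeeping for the reshape inside transformer_block (illustrative numbers).
import tensorflow as tf

h, w = 14, 14                   # spatial size of features[-1] (assumed)
patchsize_x = patchsize_y = 2   # transformer patch size (assumed)
projection_dim = 64             # embedding dimension per patch (assumed)
num_patches = (h // patchsize_x) * (w // patchsize_y)   # 49

encoded_patches = tf.zeros([8, num_patches, projection_dim])  # (batch, patches, dim)
feature_map = tf.reshape(encoded_patches,
                         [-1, h, w, projection_dim // (patchsize_x * patchsize_y)])
print(feature_map.shape)  # (8, 14, 14, 16): same element count, fewer channels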
@@ -352,38 +368,19 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
     if transformer_mlp_head_units is None:
         transformer_mlp_head_units = [128, 64]
     inputs = Input(shape=(input_height, input_width, 3))
 
-    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
-    # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
-
-    for _ in range(transformer_layers):
-        # Layer normalization 1.
-        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
-        # Create a multi-head attention layer.
-        attention_output = MultiHeadAttention(
-            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
-        )(x1, x1)
-        # Skip connection 1.
-        x2 = Add()([attention_output, encoded_patches])
-        # Layer normalization 2.
-        x3 = LayerNormalization(epsilon=1e-6)(x2)
-        # MLP.
-        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
-        # Skip connection 2.
-        encoded_patches = Add()([x3, x2])
-
-    encoded_patches = tf.reshape(encoded_patches,
-                                 [-1,
-                                  input_height,
-                                  input_width,
-                                  transformer_projection_dim // (transformer_patchsize_x *
-                                                                 transformer_patchsize_y)])
+    encoded_patches = transformer_block(inputs,
+                                        num_patches,
+                                        transformer_patchsize_x,
+                                        transformer_patchsize_y,
+                                        transformer_mlp_head_units,
+                                        transformer_layers,
+                                        transformer_num_heads,
+                                        transformer_projection_dim)
     encoded_patches = Conv2D(3, (1, 1), padding='same',
                              data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay),
                              name='convinput')(encoded_patches)
 
     features = resnet50(encoded_patches, weight_decay=weight_decay, pretraining=pretraining)
 
     o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
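In this "transformer before CNN" variant the encoder runs on the raw image, so the 1x1 'convinput' convolution is what makes the refactored call compatible with the backbone: after transformer_block and its reshape, the tensor has projection_dim // (patchsize_x * patchsize_y) channels, while ResNet50's first convolution expects 3. A minimal sketch of that projection, with the stock Keras ResNet50 standing in for the project's own resnet50 builder and all sizes assumed:

# 1x1 projection back to 3 channels so a ResNet50 backbone can consume the
# transformer output. Stock Keras ResNet50 is used here only as a stand-in.
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Input
from tensorflow.keras.regularizers import l2

x = Input(shape=(448, 448, 16))           # e.g. 16 channels after the reshape (assumed)
rgb_like = Conv2D(3, (1, 1), padding='same',
                  kernel_regularizer=l2(1e-6),
                  name='convinput')(x)     # 3 channels, spatial size unchanged
backbone = tf.keras.applications.ResNet50(include_top=False, weights=None,
                                           input_tensor=rgb_like)
print(backbone.output_shape)              # (None, 14, 14, 2048)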