Mirror of https://github.com/qurator-spk/eynollah.git (synced 2026-02-20 16:32:03 +01:00)
training.models: re-use transformer builder code

commit 9b66867c21 (parent daa084c367)
1 changed file with 53 additions and 56 deletions
@@ -285,6 +285,41 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation",
     return unet_decoder(img_input, *features, n_classes, light=False, task=task, weight_decay=weight_decay)
 
 
+def transformer_block(img,
+                      num_patches,
+                      patchsize_x,
+                      patchsize_y,
+                      mlp_head_units,
+                      n_layers,
+                      num_heads,
+                      projection_dim):
+    patches = Patches(patchsize_x, patchsize_y)(img)
+    # Encode patches.
+    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+
+    for _ in range(n_layers):
+        # Layer normalization 1.
+        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
+        # Create a multi-head attention layer.
+        attention_output = MultiHeadAttention(num_heads=num_heads,
+                                              key_dim=projection_dim,
+                                              dropout=0.1)(x1, x1)
+        # Skip connection 1.
+        x2 = Add()([attention_output, encoded_patches])
+        # Layer normalization 2.
+        x3 = LayerNormalization(epsilon=1e-6)(x2)
+        # MLP.
+        x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
+        # Skip connection 2.
+        encoded_patches = Add()([x3, x2])
+
+    encoded_patches = tf.reshape(encoded_patches,
+                                 [-1,
+                                  img.shape[1],
+                                  img.shape[2],
+                                  projection_dim // (patchsize_x * patchsize_y)])
+    return encoded_patches
+
+
 def vit_resnet50_unet(num_patches,
                       n_classes,
                       transformer_patchsize_x,
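The new transformer_block above factors out the ViT-style encoder loop that both model builders previously duplicated. Below is a minimal standalone sketch of that loop (pre-norm multi-head self-attention and an MLP, each wrapped in a residual Add), built only from stock Keras layers so it runs on its own; the repo's Patches, PatchEncoder and mlp helpers are replaced by a plain token input and an inline Dense/Dropout MLP purely for illustration, and all sizes are made-up example values rather than eynollah defaults.

import tensorflow as tf
from tensorflow.keras.layers import (Input, Dense, Dropout, Add,
                                     LayerNormalization, MultiHeadAttention)

# Illustrative sizes only -- assumptions for this sketch, not the repo's defaults.
num_patches, projection_dim, num_heads, n_layers = 196, 64, 4, 2
mlp_head_units = [128, 64]          # last entry matches projection_dim for the residual Add

tokens = Input(shape=(num_patches, projection_dim))
encoded = tokens
for _ in range(n_layers):
    # Layer normalization 1 + multi-head self-attention + skip connection 1.
    x1 = LayerNormalization(epsilon=1e-6)(encoded)
    attn = MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim,
                              dropout=0.1)(x1, x1)
    x2 = Add()([attn, encoded])
    # Layer normalization 2 + MLP (inline stand-in for the repo's mlp()) + skip connection 2.
    x3 = LayerNormalization(epsilon=1e-6)(x2)
    for units in mlp_head_units:
        x3 = Dense(units, activation=tf.nn.gelu)(x3)
        x3 = Dropout(0.1)(x3)
    encoded = Add()([x3, x2])

encoder = tf.keras.Model(tokens, encoded)
encoder.summary()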
@@ -304,33 +339,14 @@ def vit_resnet50_unet(num_patches,
 
     features = resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining)
 
-    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(features[-1])
-    # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
-
-    for _ in range(transformer_layers):
-        # Layer normalization 1.
-        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
-        # Create a multi-head attention layer.
-        attention_output = MultiHeadAttention(
-            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
-        )(x1, x1)
-        # Skip connection 1.
-        x2 = Add()([attention_output, encoded_patches])
-        # Layer normalization 2.
-        x3 = LayerNormalization(epsilon=1e-6)(x2)
-        # MLP.
-        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
-        # Skip connection 2.
-        encoded_patches = Add()([x3, x2])
-
-    encoded_patches = tf.reshape(encoded_patches,
-                                 [-1,
-                                  features[-1].shape[1],
-                                  features[-1].shape[2],
-                                  transformer_projection_dim // (transformer_patchsize_x *
-                                                                 transformer_patchsize_y)])
-    features[-1] = encoded_patches
+    features[-1] = transformer_block(features[-1],
+                                     num_patches,
+                                     transformer_patchsize_x,
+                                     transformer_patchsize_y,
+                                     transformer_mlp_head_units,
+                                     transformer_layers,
+                                     transformer_num_heads,
+                                     transformer_projection_dim)
 
     o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
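For this refactored call to yield a valid feature map, the reshape at the end of transformer_block has to map the token sequence back onto the spatial grid of features[-1]: num_patches must tile that grid exactly, and projection_dim must be divisible by patchsize_x * patchsize_y. A small worked check of that bookkeeping, using assumed example sizes rather than the repo's defaults:

# Illustrative shape check for the reshape at the end of transformer_block.
# All values are assumptions for the example, not taken from the repository.
h, w = 14, 14                       # spatial size of features[-1]
patchsize_x, patchsize_y = 1, 1     # transformer patch size on the feature map
projection_dim = 64

num_patches = (h // patchsize_x) * (w // patchsize_y)          # 196 tokens
out_channels = projection_dim // (patchsize_x * patchsize_y)   # 64 channels

# The flattened token tensor and the reshaped feature map must hold the same
# number of elements per sample, otherwise tf.reshape raises an error.
assert num_patches * projection_dim == h * w * out_channels
print(num_patches, out_channels)    # 196 64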
@@ -352,38 +368,19 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
     if transformer_mlp_head_units is None:
         transformer_mlp_head_units = [128, 64]
     inputs = Input(shape=(input_height, input_width, 3))
 
-    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
-    # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
-
-    for _ in range(transformer_layers):
-        # Layer normalization 1.
-        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
-        # Create a multi-head attention layer.
-        attention_output = MultiHeadAttention(
-            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
-        )(x1, x1)
-        # Skip connection 1.
-        x2 = Add()([attention_output, encoded_patches])
-        # Layer normalization 2.
-        x3 = LayerNormalization(epsilon=1e-6)(x2)
-        # MLP.
-        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
-        # Skip connection 2.
-        encoded_patches = Add()([x3, x2])
-
-    encoded_patches = tf.reshape(encoded_patches,
-                                 [-1,
-                                  input_height,
-                                  input_width,
-                                  transformer_projection_dim // (transformer_patchsize_x *
-                                                                 transformer_patchsize_y)])
+    encoded_patches = transformer_block(inputs,
+                                        num_patches,
+                                        transformer_patchsize_x,
+                                        transformer_patchsize_y,
+                                        transformer_mlp_head_units,
+                                        transformer_layers,
+                                        transformer_num_heads,
+                                        transformer_projection_dim)
 
     encoded_patches = Conv2D(3, (1, 1), padding='same',
                              data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay),
                              name='convinput')(encoded_patches)
 
     features = resnet50(encoded_patches, weight_decay=weight_decay, pretraining=pretraining)
 
     o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
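In this transformer-before-CNN variant, the reshaped encoder output typically does not have 3 channels, which is why the unchanged convinput 1x1 convolution projects it back to 3 channels before the ResNet50 backbone. A tiny hedged sketch of that projection with assumed example sizes; the data_format and kernel_regularizer arguments from the diff are omitted here because they depend on repo-level constants (IMAGE_ORDERING, weight_decay).

from tensorflow.keras.layers import Input, Conv2D

# Example sizes only -- assumptions for illustration, not the repo's defaults.
input_height, input_width = 224, 224
transformer_projection_dim, patchsize_x, patchsize_y = 64, 2, 2
channels = transformer_projection_dim // (patchsize_x * patchsize_y)   # 16

# Reshaped transformer output (batch, H, W, 16) -> 1x1 conv back to 3 channels
# so a standard (optionally ImageNet-pretrained) ResNet50 stem can consume it.
encoded = Input(shape=(input_height, input_width, channels))
rgb_like = Conv2D(3, (1, 1), padding='same', name='convinput')(encoded)
print(rgb_like.shape)   # -> (None, 224, 224, 3)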