diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py
index a03f028..4af4949 100644
--- a/src/eynollah/training/models.py
+++ b/src/eynollah/training/models.py
@@ -285,6 +285,41 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
     return unet_decoder(img_input, *features, n_classes, light=False, task=task, weight_decay=weight_decay)
 
 
+def transformer_block(img,
+                      num_patches,
+                      patchsize_x,
+                      patchsize_y,
+                      mlp_head_units,
+                      n_layers,
+                      num_heads,
+                      projection_dim):
+    patches = Patches(patchsize_x, patchsize_y)(img)
+    # Encode patches.
+    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+
+    for _ in range(n_layers):
+        # Layer normalization 1.
+        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
+        # Create a multi-head attention layer.
+        attention_output = MultiHeadAttention(num_heads=num_heads,
+                                              key_dim=projection_dim,
+                                              dropout=0.1)(x1, x1)
+        # Skip connection 1.
+        x2 = Add()([attention_output, encoded_patches])
+        # Layer normalization 2.
+        x3 = LayerNormalization(epsilon=1e-6)(x2)
+        # MLP.
+        x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
+        # Skip connection 2.
+        encoded_patches = Add()([x3, x2])
+
+    encoded_patches = tf.reshape(encoded_patches,
+                                 [-1,
+                                  img.shape[1],
+                                  img.shape[2],
+                                  projection_dim // (patchsize_x * patchsize_y)])
+    return encoded_patches
+
 def vit_resnet50_unet(num_patches,
                       n_classes,
                       transformer_patchsize_x,
@@ -304,33 +339,14 @@ def vit_resnet50_unet(num_patches,
 
     features = resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining)
 
-    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(features[-1])
-    # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
-
-    for _ in range(transformer_layers):
-        # Layer normalization 1.
-        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
-        # Create a multi-head attention layer.
-        attention_output = MultiHeadAttention(
-            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
-        )(x1, x1)
-        # Skip connection 1.
-        x2 = Add()([attention_output, encoded_patches])
-        # Layer normalization 2.
-        x3 = LayerNormalization(epsilon=1e-6)(x2)
-        # MLP.
-        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
-        # Skip connection 2.
-        encoded_patches = Add()([x3, x2])
-
-    encoded_patches = tf.reshape(encoded_patches,
-                                 [-1,
-                                  features[-1].shape[1],
-                                  features[-1].shape[2],
-                                  transformer_projection_dim // (transformer_patchsize_x *
-                                                                 transformer_patchsize_y)])
-    features[-1] = encoded_patches
+    features[-1] = transformer_block(features[-1],
+                                     num_patches,
+                                     transformer_patchsize_x,
+                                     transformer_patchsize_y,
+                                     transformer_mlp_head_units,
+                                     transformer_layers,
+                                     transformer_num_heads,
+                                     transformer_projection_dim)
 
     o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
 
@@ -352,38 +368,19 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
     if transformer_mlp_head_units is None:
         transformer_mlp_head_units = [128, 64]
     inputs = Input(shape=(input_height, input_width, 3))
-
-    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
-    # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
-
-    for _ in range(transformer_layers):
-        # Layer normalization 1.
-        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
-        # Create a multi-head attention layer.
-        attention_output = MultiHeadAttention(
-            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
-        )(x1, x1)
-        # Skip connection 1.
-        x2 = Add()([attention_output, encoded_patches])
-        # Layer normalization 2.
-        x3 = LayerNormalization(epsilon=1e-6)(x2)
-        # MLP.
-        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
-        # Skip connection 2.
-        encoded_patches = Add()([x3, x2])
-
-    encoded_patches = tf.reshape(encoded_patches,
-                                 [-1,
-                                  input_height,
-                                  input_width,
-                                  transformer_projection_dim // (transformer_patchsize_x *
-                                                                 transformer_patchsize_y)])
-
+
+    encoded_patches = transformer_block(inputs,
+                                        num_patches,
+                                        transformer_patchsize_x,
+                                        transformer_patchsize_y,
+                                        transformer_mlp_head_units,
+                                        transformer_layers,
+                                        transformer_num_heads,
+                                        transformer_projection_dim)
     encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING,
                              kernel_regularizer=l2(weight_decay),
                              name='convinput')(encoded_patches)
-
+
     features = resnet50(encoded_patches, weight_decay=weight_decay, pretraining=pretraining)
 
     o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
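A quick way to review the extracted helper is a shape check on a toy input. This is a minimal sketch, assuming `transformer_block` and its `Patches` / `PatchEncoder` / `mlp` dependencies are importable from `eynollah.training.models`; the sizes below are illustrative, not taken from the patch. The constraints it exercises follow from the final `tf.reshape`: `num_patches` must equal `(H / patchsize_y) * (W / patchsize_x)`, and `projection_dim` must be divisible by the patch area `patchsize_x * patchsize_y`.

```python
# Sketch: shape check for transformer_block with illustrative sizes.
from tensorflow.keras.layers import Input
from eynollah.training.models import transformer_block

# A 32x32 input with 4x4 patches gives (32 // 4) * (32 // 4) = 64 patches,
# and projection_dim = 64 is divisible by the patch area 4 * 4 = 16,
# as the final tf.reshape inside transformer_block requires.
img = Input(shape=(32, 32, 3))
out = transformer_block(img,
                        num_patches=64,
                        patchsize_x=4,
                        patchsize_y=4,
                        mlp_head_units=[128, 64],
                        n_layers=2,
                        num_heads=4,
                        projection_dim=64)

# Spatial size is preserved; channels = projection_dim // (4 * 4) = 4.
print(out.shape)  # (None, 32, 32, 4)
```

The same check covers both call sites, since `vit_resnet50_unet` feeds the block a ResNet feature map and `vit_resnet50_unet_transformer_before_cnn` feeds it the raw input, with only the argument names differing.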