Transformer+CNN structure is added to vision transformer type

This commit is contained in:
vahidrezanezhad 2024-06-12 17:39:57 +02:00
parent f1fd74c7eb
commit 743f2e97d6
3 changed files with 176 additions and 39 deletions

142
models.py
View file

@ -5,10 +5,10 @@ from tensorflow.keras.layers import *
from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2
mlp_head_units = [2048, 1024]
#projection_dim = 64
transformer_layers = 8
num_heads = 4
##mlp_head_units = [512, 256]#[2048, 1024]
###projection_dim = 64
##transformer_layers = 2#8
##num_heads = 1#4
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1
@ -36,7 +36,8 @@ class Patches(layers.Layer):
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
#patch_dims = patches.shape[-1]
patch_dims = tf.shape(patches)[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
@ -393,13 +394,13 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
return model
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
inputs = layers.Input(shape=(input_height, input_width, 3))
transformer_units = [
projection_dim * 2,
projection_dim,
] # Size of the transformer layers
#transformer_units = [
#projection_dim * 2,
#projection_dim,
#] # Size of the transformer layers
IMAGE_ORDERING = 'channels_last'
bn_axis=3
@ -459,7 +460,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projec
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
@ -515,6 +516,125 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projec
return model
def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
inputs = layers.Input(shape=(input_height, input_width, 3))
##transformer_units = [
##projection_dim * 2,
##projection_dim,
##] # Size of the transformer layers
IMAGE_ORDERING = 'channels_last'
bn_axis=3
patches = Patches(patch_size_x, patch_size_y)(inputs)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(encoded_patches)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
f1 = x
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
f2 = one_side_pad(x)
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
f3 = x
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
f4 = x
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
f5 = x
if pretraining:
model = Model(encoded_patches, x).load_weights(resnet50_Weights_path)
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x)
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
v1024_2048 = Activation('relu')(v1024_2048)
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
o = (concatenate([o, f4],axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o ,f3], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f2], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f1], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, inputs],axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
if task == "segmentation":
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
else:
o = (Activation('sigmoid'))(o)
model = Model(inputs=inputs, outputs=o)
return model
def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
include_top=True
assert input_height%32 == 0