Transformer+CNN structure is added to vision transformer type

pull/18/head
vahidrezanezhad 6 months ago
parent f1fd74c7eb
commit 743f2e97d6

@ -2,9 +2,9 @@
"backbone_type" : "transformer", "backbone_type" : "transformer",
"task": "binarization", "task": "binarization",
"n_classes" : 2, "n_classes" : 2,
"n_epochs" : 1, "n_epochs" : 2,
"input_height" : 224, "input_height" : 224,
"input_width" : 672, "input_width" : 224,
"weight_decay" : 1e-6, "weight_decay" : 1e-6,
"n_batch" : 1, "n_batch" : 1,
"learning_rate": 1e-4, "learning_rate": 1e-4,
@ -22,10 +22,14 @@
"scaling_flip" : false, "scaling_flip" : false,
"rotation": false, "rotation": false,
"rotation_not_90": false, "rotation_not_90": false,
"transformer_num_patches_xy": [7, 7], "transformer_num_patches_xy": [56, 56],
"transformer_patchsize_x": 3, "transformer_patchsize_x": 4,
"transformer_patchsize_y": 1, "transformer_patchsize_y": 4,
"transformer_projection_dim": 192, "transformer_projection_dim": 64,
"transformer_mlp_head_units": [128, 64],
"transformer_layers": 1,
"transformer_num_heads": 1,
"transformer_cnn_first": false,
"blur_k" : ["blur","guass","median"], "blur_k" : ["blur","guass","median"],
"scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
"brightness" : [1.3, 1.5, 1.7, 2], "brightness" : [1.3, 1.5, 1.7, 2],

@ -5,10 +5,10 @@ from tensorflow.keras.layers import *
from tensorflow.keras import layers from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2 from tensorflow.keras.regularizers import l2
mlp_head_units = [2048, 1024] ##mlp_head_units = [512, 256]#[2048, 1024]
#projection_dim = 64 ###projection_dim = 64
transformer_layers = 8 ##transformer_layers = 2#8
num_heads = 4 ##num_heads = 1#4
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
IMAGE_ORDERING = 'channels_last' IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1 MERGE_AXIS = -1
@ -36,7 +36,8 @@ class Patches(layers.Layer):
rates=[1, 1, 1, 1], rates=[1, 1, 1, 1],
padding="VALID", padding="VALID",
) )
patch_dims = patches.shape[-1] #patch_dims = patches.shape[-1]
patch_dims = tf.shape(patches)[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims]) patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches return patches
def get_config(self): def get_config(self):
@ -393,13 +394,13 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
return model return model
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
inputs = layers.Input(shape=(input_height, input_width, 3)) inputs = layers.Input(shape=(input_height, input_width, 3))
transformer_units = [ #transformer_units = [
projection_dim * 2, #projection_dim * 2,
projection_dim, #projection_dim,
] # Size of the transformer layers #] # Size of the transformer layers
IMAGE_ORDERING = 'channels_last' IMAGE_ORDERING = 'channels_last'
bn_axis=3 bn_axis=3
@ -459,7 +460,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projec
# Layer normalization 2. # Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2) x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP. # MLP.
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1) x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
# Skip connection 2. # Skip connection 2.
encoded_patches = layers.Add()([x3, x2]) encoded_patches = layers.Add()([x3, x2])
@ -515,6 +516,125 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projec
return model return model
def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
inputs = layers.Input(shape=(input_height, input_width, 3))
##transformer_units = [
##projection_dim * 2,
##projection_dim,
##] # Size of the transformer layers
IMAGE_ORDERING = 'channels_last'
bn_axis=3
patches = Patches(patch_size_x, patch_size_y)(inputs)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(encoded_patches)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
f1 = x
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
f2 = one_side_pad(x)
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
f3 = x
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
f4 = x
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
f5 = x
if pretraining:
model = Model(encoded_patches, x).load_weights(resnet50_Weights_path)
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x)
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
v1024_2048 = Activation('relu')(v1024_2048)
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
o = (concatenate([o, f4],axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o ,f3], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f2], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, f1], axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
o = (concatenate([o, inputs],axis=MERGE_AXIS))
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
if task == "segmentation":
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
else:
o = (Activation('sigmoid'))(o)
model = Model(inputs=inputs, outputs=o)
return model
def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
include_top=True include_top=True
assert input_height%32 == 0 assert input_height%32 == 0

@ -70,10 +70,14 @@ def config_params():
brightness = None # Brighten image for augmentation. brightness = None # Brighten image for augmentation.
flip_index = None # Flip image for augmentation. flip_index = None # Flip image for augmentation.
continue_training = False # Set to true if you would like to continue training an already trained a model. continue_training = False # Set to true if you would like to continue training an already trained a model.
transformer_patchsize_x = None # Patch size of vision transformer patches. transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
transformer_patchsize_y = None transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
transformer_num_patches_xy = None # Number of patches for vision transformer. transformer_num_patches_xy = None # Number of patches for vision transformer in x and y direction respectively.
transformer_projection_dim = 64 # Transformer projection dimension transformer_projection_dim = 64 # Transformer projection dimension. Default value is 64.
transformer_mlp_head_units = [128, 64] # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64]
transformer_layers = 8 # transformer layers. Default value is 8.
transformer_num_heads = 4 # Transformer number of heads. Default value is 4.
transformer_cnn_first = True # We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. Default value is true.
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
@ -94,7 +98,9 @@ def run(_config, n_classes, n_epochs, input_height,
brightening, binarization, blur_k, scales, degrade_scales, brightening, binarization, blur_k, scales, degrade_scales,
brightness, dir_train, data_is_provided, scaling_bluring, brightness, dir_train, data_is_provided, scaling_bluring,
scaling_brightness, scaling_binarization, rotation, rotation_not_90, scaling_brightness, scaling_binarization, rotation, rotation_not_90,
thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_patchsize_x, transformer_patchsize_y, thetha, scaling_flip, continue_training, transformer_projection_dim,
transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
transformer_patchsize_x, transformer_patchsize_y,
transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
@ -218,26 +224,33 @@ def run(_config, n_classes, n_epochs, input_height,
num_patches_y = transformer_num_patches_xy[1] num_patches_y = transformer_num_patches_xy[1]
num_patches = num_patches_x * num_patches_y num_patches = num_patches_x * num_patches_y
##if not (num_patches == (input_width / 32) * (input_height / 32)): if transformer_cnn_first:
##print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) ) if (input_height != (num_patches_y * transformer_patchsize_y * 32) ):
##sys.exit(1) print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)")
#if not (transformer_patchsize == 1): sys.exit(1)
#print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" ) if (input_width != (num_patches_x * transformer_patchsize_x * 32) ):
#sys.exit(1) print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)")
if (input_height != (num_patches_y * transformer_patchsize_y * 32) ): sys.exit(1)
print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)") if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0:
sys.exit(1) print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero")
if (input_width != (num_patches_x * transformer_patchsize_x * 32) ): sys.exit(1)
print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)")
sys.exit(1)
if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") else:
sys.exit(1) if (input_height != (num_patches_y * transformer_patchsize_y) ):
print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y)")
model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) sys.exit(1)
if (input_width != (num_patches_x * transformer_patchsize_x) ):
print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x)")
sys.exit(1)
if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0:
print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero")
sys.exit(1)
model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
#if you want to see the model structure just uncomment model summary. #if you want to see the model structure just uncomment model summary.
#model.summary() model.summary()
if (task == "segmentation" or task == "binarization"): if (task == "segmentation" or task == "binarization"):

Loading…
Cancel
Save