diff --git a/config_params.json b/config_params.json index 6b8b6ed..d72530e 100644 --- a/config_params.json +++ b/config_params.json @@ -2,9 +2,9 @@ "backbone_type" : "transformer", "task": "binarization", "n_classes" : 2, - "n_epochs" : 1, + "n_epochs" : 2, "input_height" : 224, - "input_width" : 672, + "input_width" : 224, "weight_decay" : 1e-6, "n_batch" : 1, "learning_rate": 1e-4, @@ -22,10 +22,14 @@ "scaling_flip" : false, "rotation": false, "rotation_not_90": false, - "transformer_num_patches_xy": [7, 7], - "transformer_patchsize_x": 3, - "transformer_patchsize_y": 1, - "transformer_projection_dim": 192, + "transformer_num_patches_xy": [56, 56], + "transformer_patchsize_x": 4, + "transformer_patchsize_y": 4, + "transformer_projection_dim": 64, + "transformer_mlp_head_units": [128, 64], + "transformer_layers": 1, + "transformer_num_heads": 1, + "transformer_cnn_first": false, "blur_k" : ["blur","guass","median"], "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], "brightness" : [1.3, 1.5, 1.7, 2], diff --git a/models.py b/models.py index 1abf304..8841bd3 100644 --- a/models.py +++ b/models.py @@ -5,10 +5,10 @@ from tensorflow.keras.layers import * from tensorflow.keras import layers from tensorflow.keras.regularizers import l2 -mlp_head_units = [2048, 1024] -#projection_dim = 64 -transformer_layers = 8 -num_heads = 4 +##mlp_head_units = [512, 256]#[2048, 1024] +###projection_dim = 64 +##transformer_layers = 2#8 +##num_heads = 1#4 resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' IMAGE_ORDERING = 'channels_last' MERGE_AXIS = -1 @@ -36,7 +36,8 @@ class Patches(layers.Layer): rates=[1, 1, 1, 1], padding="VALID", ) - patch_dims = patches.shape[-1] + #patch_dims = patches.shape[-1] + patch_dims = tf.shape(patches)[-1] patches = tf.reshape(patches, [batch_size, -1, patch_dims]) return patches def get_config(self): @@ -393,13 +394,13 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati return model -def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): +def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): inputs = layers.Input(shape=(input_height, input_width, 3)) - transformer_units = [ - projection_dim * 2, - projection_dim, - ] # Size of the transformer layers + #transformer_units = [ + #projection_dim * 2, + #projection_dim, + #] # Size of the transformer layers IMAGE_ORDERING = 'channels_last' bn_axis=3 @@ -459,7 +460,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projec # Layer normalization 2. x3 = layers.LayerNormalization(epsilon=1e-6)(x2) # MLP. - x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1) + x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) # Skip connection 2. encoded_patches = layers.Add()([x3, x2]) @@ -515,6 +516,125 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projec return model +def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=[128, 64], transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): + inputs = layers.Input(shape=(input_height, input_width, 3)) + + ##transformer_units = [ + ##projection_dim * 2, + ##projection_dim, + ##] # Size of the transformer layers + IMAGE_ORDERING = 'channels_last' + bn_axis=3 + + patches = Patches(patch_size_x, patch_size_y)(inputs) + # Encode patches. + encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) + + for _ in range(transformer_layers): + # Layer normalization 1. + x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + # Create a multi-head attention layer. + attention_output = layers.MultiHeadAttention( + num_heads=num_heads, key_dim=projection_dim, dropout=0.1 + )(x1, x1) + # Skip connection 1. + x2 = layers.Add()([attention_output, encoded_patches]) + # Layer normalization 2. + x3 = layers.LayerNormalization(epsilon=1e-6)(x2) + # MLP. + x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) + # Skip connection 2. + encoded_patches = layers.Add()([x3, x2]) + + encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )]) + + encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches) + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(encoded_patches) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x) + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + model = Model(encoded_patches, x).load_weights(resnet50_Weights_path) + + v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x) + v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + + o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) + o = (concatenate([o, f4],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o ,f3], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f2], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, f1], axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) + o = (concatenate([o, inputs],axis=MERGE_AXIS)) + o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) + o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = (BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) + if task == "segmentation": + o = (BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + else: + o = (Activation('sigmoid'))(o) + + model = Model(inputs=inputs, outputs=o) + + return model + def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): include_top=True assert input_height%32 == 0 diff --git a/train.py b/train.py index bafcc9e..71f31f3 100644 --- a/train.py +++ b/train.py @@ -70,10 +70,14 @@ def config_params(): brightness = None # Brighten image for augmentation. flip_index = None # Flip image for augmentation. continue_training = False # Set to true if you would like to continue training an already trained a model. - transformer_patchsize_x = None # Patch size of vision transformer patches. - transformer_patchsize_y = None - transformer_num_patches_xy = None # Number of patches for vision transformer. - transformer_projection_dim = 64 # Transformer projection dimension + transformer_patchsize_x = None # Patch size of vision transformer patches in x direction. + transformer_patchsize_y = None # Patch size of vision transformer patches in y direction. + transformer_num_patches_xy = None # Number of patches for vision transformer in x and y direction respectively. + transformer_projection_dim = 64 # Transformer projection dimension. Default value is 64. + transformer_mlp_head_units = [128, 64] # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64] + transformer_layers = 8 # transformer layers. Default value is 8. + transformer_num_heads = 4 # Transformer number of heads. Default value is 4. + transformer_cnn_first = True # We have two types of vision transformers. In one type, a CNN is applied first, followed by a transformer. In the other type, this order is reversed. If transformer_cnn_first is true, it means the CNN will be applied before the transformer. Default value is true. index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. @@ -94,7 +98,9 @@ def run(_config, n_classes, n_epochs, input_height, brightening, binarization, blur_k, scales, degrade_scales, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, - thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_patchsize_x, transformer_patchsize_y, + thetha, scaling_flip, continue_training, transformer_projection_dim, + transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first, + transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): @@ -218,26 +224,33 @@ def run(_config, n_classes, n_epochs, input_height, num_patches_y = transformer_num_patches_xy[1] num_patches = num_patches_x * num_patches_y - ##if not (num_patches == (input_width / 32) * (input_height / 32)): - ##print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) ) - ##sys.exit(1) - #if not (transformer_patchsize == 1): - #print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" ) - #sys.exit(1) - if (input_height != (num_patches_y * transformer_patchsize_y * 32) ): - print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)") - sys.exit(1) - if (input_width != (num_patches_x * transformer_patchsize_x * 32) ): - print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)") - sys.exit(1) - if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: - print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") - sys.exit(1) + if transformer_cnn_first: + if (input_height != (num_patches_y * transformer_patchsize_y * 32) ): + print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)") + sys.exit(1) + if (input_width != (num_patches_x * transformer_patchsize_x * 32) ): + print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)") + sys.exit(1) + if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: + print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") + sys.exit(1) + - model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) + model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) + else: + if (input_height != (num_patches_y * transformer_patchsize_y) ): + print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y)") + sys.exit(1) + if (input_width != (num_patches_x * transformer_patchsize_x) ): + print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x)") + sys.exit(1) + if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: + print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") + sys.exit(1) + model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) #if you want to see the model structure just uncomment model summary. - #model.summary() + model.summary() if (task == "segmentation" or task == "binarization"):