diff --git a/config_params.json b/config_params.json index 8a56de5..6b8b6ed 100644 --- a/config_params.json +++ b/config_params.json @@ -1,42 +1,44 @@ { - "backbone_type" : "nontransformer", - "task": "classification", + "backbone_type" : "transformer", + "task": "binarization", "n_classes" : 2, - "n_epochs" : 20, - "input_height" : 448, - "input_width" : 448, + "n_epochs" : 1, + "input_height" : 224, + "input_width" : 672, "weight_decay" : 1e-6, - "n_batch" : 6, + "n_batch" : 1, "learning_rate": 1e-4, - "f1_threshold_classification": 0.8, "patches" : true, "pretraining" : true, "augmentation" : false, "flip_aug" : false, "blur_aug" : false, "scaling" : true, + "degrading": false, + "brightening": false, "binarization" : false, "scaling_bluring" : false, "scaling_binarization" : false, "scaling_flip" : false, "rotation": false, "rotation_not_90": false, - "transformer_num_patches_xy": [28, 28], - "transformer_patchsize": 1, + "transformer_num_patches_xy": [7, 7], + "transformer_patchsize_x": 3, + "transformer_patchsize_y": 1, + "transformer_projection_dim": 192, "blur_k" : ["blur","guass","median"], "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4], "brightness" : [1.3, 1.5, 1.7, 2], "degrade_scales" : [0.2, 0.4], "flip_index" : [0, 1, -1], "thetha" : [10, -10], - "classification_classes_name" : {"0":"apple", "1":"orange"}, "continue_training": false, "index_start" : 0, "dir_of_start_model" : " ", "weighted_loss": false, "is_loss_soft_dice": false, "data_is_provided": false, - "dir_train": "./train", - "dir_eval": "./eval", - "dir_output": "./output" + "dir_train": "/home/vahid/Documents/test/training_data_sample_binarization", + "dir_eval": "/home/vahid/Documents/test/eval", + "dir_output": "/home/vahid/Documents/test/out" } diff --git a/models.py b/models.py index b8b0d27..1abf304 100644 --- a/models.py +++ b/models.py @@ -6,25 +6,49 @@ from tensorflow.keras import layers from tensorflow.keras.regularizers import l2 mlp_head_units = [2048, 1024] -projection_dim = 64 +#projection_dim = 64 transformer_layers = 8 num_heads = 4 resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' IMAGE_ORDERING = 'channels_last' MERGE_AXIS = -1 -transformer_units = [ - projection_dim * 2, - projection_dim, -] # Size of the transformer layers def mlp(x, hidden_units, dropout_rate): for units in hidden_units: x = layers.Dense(units, activation=tf.nn.gelu)(x) x = layers.Dropout(dropout_rate)(x) return x - class Patches(layers.Layer): + def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): + super(Patches, self).__init__() + self.patch_size_x = patch_size_x + self.patch_size_y = patch_size_y + + def call(self, images): + #print(tf.shape(images)[1],'images') + #print(self.patch_size,'self.patch_size') + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size_y, self.patch_size_x, 1], + strides=[1, self.patch_size_y, self.patch_size_x, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size_x': self.patch_size_x, + 'patch_size_y': self.patch_size_y, + }) + return config + +class Patches_old(layers.Layer): def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): super(Patches, self).__init__() self.patch_size = patch_size @@ -369,8 +393,13 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati return model -def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): +def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): inputs = layers.Input(shape=(input_height, input_width, 3)) + + transformer_units = [ + projection_dim * 2, + projection_dim, + ] # Size of the transformer layers IMAGE_ORDERING = 'channels_last' bn_axis=3 @@ -414,7 +443,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu #patch_size_y = input_height / x.shape[1] #patch_size_x = input_width / x.shape[2] #patch_size = patch_size_x * patch_size_y - patches = Patches(patch_size)(x) + patches = Patches(patch_size_x, patch_size_y)(x) # Encode patches. encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) @@ -434,7 +463,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu # Skip connection 2. encoded_patches = layers.Add()([x3, x2]) - encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], 64]) + encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )]) v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches) v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) diff --git a/train.py b/train.py index 9e06a66..bafcc9e 100644 --- a/train.py +++ b/train.py @@ -70,8 +70,10 @@ def config_params(): brightness = None # Brighten image for augmentation. flip_index = None # Flip image for augmentation. continue_training = False # Set to true if you would like to continue training an already trained a model. - transformer_patchsize = None # Patch size of vision transformer patches. + transformer_patchsize_x = None # Patch size of vision transformer patches. + transformer_patchsize_y = None transformer_num_patches_xy = None # Number of patches for vision transformer. + transformer_projection_dim = 64 # Transformer projection dimension index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. @@ -92,7 +94,7 @@ def run(_config, n_classes, n_epochs, input_height, brightening, binarization, blur_k, scales, degrade_scales, brightness, dir_train, data_is_provided, scaling_bluring, scaling_brightness, scaling_binarization, rotation, rotation_not_90, - thetha, scaling_flip, continue_training, transformer_patchsize, + thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): @@ -212,15 +214,27 @@ def run(_config, n_classes, n_epochs, input_height, if backbone_type=='nontransformer': model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining) elif backbone_type=='transformer': - num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1] + num_patches_x = transformer_num_patches_xy[0] + num_patches_y = transformer_num_patches_xy[1] + num_patches = num_patches_x * num_patches_y - if not (num_patches == (input_width / 32) * (input_height / 32)): - print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) ) + ##if not (num_patches == (input_width / 32) * (input_height / 32)): + ##print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) ) + ##sys.exit(1) + #if not (transformer_patchsize == 1): + #print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" ) + #sys.exit(1) + if (input_height != (num_patches_y * transformer_patchsize_y * 32) ): + print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)") sys.exit(1) - if not (transformer_patchsize == 1): - print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" ) + if (input_width != (num_patches_x * transformer_patchsize_x * 32) ): + print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)") sys.exit(1) - model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining) + if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0: + print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero") + sys.exit(1) + + model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) #if you want to see the model structure just uncomment model summary. #model.summary()