diff --git a/config_params.json b/config_params.json
index 8a56de5..6b8b6ed 100644
--- a/config_params.json
+++ b/config_params.json
@@ -1,42 +1,44 @@
 {
-    "backbone_type" : "nontransformer",
-    "task": "classification",
+    "backbone_type" : "transformer",
+    "task": "binarization",
     "n_classes" : 2,
-    "n_epochs" : 20,
-    "input_height" : 448,
-    "input_width" : 448,
+    "n_epochs" : 1,
+    "input_height" : 224,
+    "input_width" : 672,
     "weight_decay" : 1e-6,
-    "n_batch" : 6,
+    "n_batch" : 1,
     "learning_rate": 1e-4,
-    "f1_threshold_classification": 0.8,
     "patches" : true,
     "pretraining" : true,
     "augmentation" : false,
     "flip_aug" : false,
     "blur_aug" : false,
     "scaling" : true,
+    "degrading": false,
+    "brightening": false,
     "binarization" : false,
     "scaling_bluring" : false,
     "scaling_binarization" : false,
     "scaling_flip" : false,
     "rotation": false,
     "rotation_not_90": false,
-    "transformer_num_patches_xy": [28, 28],
-    "transformer_patchsize": 1,
+    "transformer_num_patches_xy": [7, 7],
+    "transformer_patchsize_x": 3,
+    "transformer_patchsize_y": 1,
+    "transformer_projection_dim": 192,
     "blur_k" : ["blur","guass","median"],
     "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
     "brightness" : [1.3, 1.5, 1.7, 2],
     "degrade_scales" : [0.2, 0.4],
     "flip_index" : [0, 1, -1],
     "thetha" : [10, -10],
-    "classification_classes_name" : {"0":"apple",  "1":"orange"},
     "continue_training": false,
     "index_start" : 0,
     "dir_of_start_model" : " ", 
     "weighted_loss": false,
     "is_loss_soft_dice": false,
     "data_is_provided": false,
-    "dir_train": "./train",
-    "dir_eval": "./eval",
-    "dir_output": "./output"
+    "dir_train": "/home/vahid/Documents/test/training_data_sample_binarization",
+    "dir_eval": "/home/vahid/Documents/test/eval",
+    "dir_output": "/home/vahid/Documents/test/out"
 }
diff --git a/models.py b/models.py
index b8b0d27..1abf304 100644
--- a/models.py
+++ b/models.py
@@ -6,25 +6,49 @@ from tensorflow.keras import layers
 from tensorflow.keras.regularizers import l2
 
 mlp_head_units = [2048, 1024]
-projection_dim = 64
 transformer_layers = 8
 num_heads = 4
 resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
 IMAGE_ORDERING = 'channels_last'
 MERGE_AXIS = -1
 
-transformer_units = [
-    projection_dim * 2,
-    projection_dim,
-]  # Size of the transformer layers
 def mlp(x, hidden_units, dropout_rate):
     for units in hidden_units:
         x = layers.Dense(units, activation=tf.nn.gelu)(x)
         x = layers.Dropout(dropout_rate)(x)
     return x
 
-
 class Patches(layers.Layer):
+    def __init__(self, patch_size_x, patch_size_y, **kwargs):
+        super(Patches, self).__init__(**kwargs)
+        self.patch_size_x = patch_size_x
+        self.patch_size_y = patch_size_y
+
+    def call(self, images):
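+        # Split the feature map into non-overlapping patch_size_y x patch_size_x patches.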
+        batch_size = tf.shape(images)[0]
+        patches = tf.image.extract_patches(
+            images=images,
+            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
+            strides=[1, self.patch_size_y, self.patch_size_x, 1],
+            rates=[1, 1, 1, 1],
+            padding="VALID",
+        )
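+        # Flatten the (rows, cols) grid of patches into a sequence of shape (batch, num_patches, patch_dims).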
+        patch_dims = patches.shape[-1]
+        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
+        return patches
+
+    def get_config(self):
+        # Serialize the patch sizes so the layer can be re-created when a saved model is loaded.
+        config = super().get_config().copy()
+        config.update({
+            'patch_size_x': self.patch_size_x,
+            'patch_size_y': self.patch_size_y,
+        })
+        return config
+
+class Patches_old(layers.Layer):
     def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
-        super(Patches, self).__init__()
+        super(Patches_old, self).__init__()
         self.patch_size = patch_size
@@ -369,8 +393,13 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
     return model
 
 
-def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
+def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projection_dim=64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
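+    # patch_size_x / patch_size_y give the transformer patch size along the width / height of the
+    # 1/32-resolution ResNet50 feature map; num_patches must equal
+    # (input_height / 32 / patch_size_y) * (input_width / 32 / patch_size_x).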
     inputs = layers.Input(shape=(input_height, input_width, 3))
+
+    transformer_units = [
+        projection_dim * 2,
+        projection_dim,
+    ]  # Size of the transformer layers
     IMAGE_ORDERING = 'channels_last'
     bn_axis=3
 
@@ -414,7 +443,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu
     #patch_size_y = input_height / x.shape[1]
     #patch_size_x = input_width / x.shape[2]
     #patch_size = patch_size_x * patch_size_y
-    patches = Patches(patch_size)(x)
+    patches = Patches(patch_size_x, patch_size_y)(x)
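+    # patches: (batch, num_patches, patch_size_y * patch_size_x * feature channels of x)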
     # Encode patches.
     encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
     
@@ -434,7 +463,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu
         # Skip connection 2.
         encoded_patches = layers.Add()([x3, x2])
     
-    encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], 64])
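+    # Fold the patch sequence back into a (H/32, W/32) feature map; each patch covers
+    # patch_size_y * patch_size_x spatial positions, so the channel depth becomes
+    # projection_dim // (patch_size_x * patch_size_y).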
+    encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], projection_dim // (patch_size_x * patch_size_y)])
 
     v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
     v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
diff --git a/train.py b/train.py
index 9e06a66..bafcc9e 100644
--- a/train.py
+++ b/train.py
@@ -70,8 +70,10 @@ def config_params():
     brightness = None #  Brighten image for augmentation.
     flip_index = None  #  Flip image for augmentation.
-    continue_training = False  # Set to true if you would like to continue training an already trained a model.
+    continue_training = False  # Set to true if you would like to continue training an already trained model.
-    transformer_patchsize = None  # Patch size of vision transformer patches.
+    transformer_patchsize_x = None  # Patch size of vision transformer patches (x direction).
+    transformer_patchsize_y = None  # Patch size of vision transformer patches (y direction).
     transformer_num_patches_xy = None  # Number of patches for vision transformer.
+    transformer_projection_dim = 64  # Transformer projection dimension. Must be divisible by (transformer_patchsize_x * transformer_patchsize_y).
     index_start = 0  #  Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
     dir_of_start_model = ''  # Directory containing pretrained encoder to continue training the model.
     is_loss_soft_dice = False  # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
@@ -92,7 +94,7 @@ def run(_config, n_classes, n_epochs, input_height,
         brightening, binarization, blur_k, scales, degrade_scales,
         brightness, dir_train, data_is_provided, scaling_bluring,
         scaling_brightness, scaling_binarization, rotation, rotation_not_90,
-        thetha, scaling_flip, continue_training, transformer_patchsize,
+        thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_patchsize_x, transformer_patchsize_y,
         transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
         pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
     
@@ -212,15 +214,27 @@ def run(_config, n_classes, n_epochs, input_height,
             if backbone_type=='nontransformer':
                 model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
             elif backbone_type=='transformer':
-                num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
+                num_patches_x = transformer_num_patches_xy[0]
+                num_patches_y = transformer_num_patches_xy[1]
+                num_patches = num_patches_x * num_patches_y
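+                # The ResNet50 encoder reduces each spatial dimension by a factor of 32, so the checks
+                # below require input_height == transformer_num_patches_xy[1] * transformer_patchsize_y * 32
+                # and input_width == transformer_num_patches_xy[0] * transformer_patchsize_x * 32.
+                # With the sample config: 224 == 7 * 1 * 32 and 672 == 7 * 3 * 32.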
                 
-                if not (num_patches == (input_width / 32) * (input_height / 32)):
-                    print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) =  {}".format(int(input_width / 32), int(input_height / 32)) )
+                if input_height != (num_patches_y * transformer_patchsize_y * 32):
+                    print("Error: invalid transformer_patchsize_y or transformer_num_patches_xy height value. input_height must equal transformer_num_patches_xy[1] * transformer_patchsize_y * 32.")
                     sys.exit(1)
-                if not (transformer_patchsize == 1):
-                    print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" )
+                if input_width != (num_patches_x * transformer_patchsize_x * 32):
+                    print("Error: invalid transformer_patchsize_x or transformer_num_patches_xy width value. input_width must equal transformer_num_patches_xy[0] * transformer_patchsize_x * 32.")
                     sys.exit(1)
-                model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
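+                # The decoder reshapes the transformer output to a channel depth of
+                # transformer_projection_dim / (transformer_patchsize_x * transformer_patchsize_y),
+                # e.g. 192 / (3 * 1) = 64 with the sample config, so the division must be exact.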
+                if transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x) != 0:
+                    print("Error: invalid transformer_projection_dim. transformer_projection_dim must be divisible by (transformer_patchsize_x * transformer_patchsize_y).")
+                    sys.exit(1)
+
+                model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
         
         #if you want to see the model structure just uncomment model summary.
         #model.summary()