diff --git a/models.py b/models.py index d852ac3..b8b0d27 100644 --- a/models.py +++ b/models.py @@ -30,8 +30,8 @@ class Patches(layers.Layer): self.patch_size = patch_size def call(self, images): - print(tf.shape(images)[1],'images') - print(self.patch_size,'self.patch_size') + #print(tf.shape(images)[1],'images') + #print(self.patch_size,'self.patch_size') batch_size = tf.shape(images)[0] patches = tf.image.extract_patches( images=images, @@ -41,7 +41,7 @@ class Patches(layers.Layer): padding="VALID", ) patch_dims = patches.shape[-1] - print(patches.shape,patch_dims,'patch_dims') + #print(patches.shape,patch_dims,'patch_dims') patches = tf.reshape(patches, [batch_size, -1, patch_dims]) return patches def get_config(self): @@ -51,6 +51,7 @@ class Patches(layers.Layer): 'patch_size': self.patch_size, }) return config + class PatchEncoder(layers.Layer): def __init__(self, num_patches, projection_dim): @@ -408,7 +409,11 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu if pretraining: model = Model(inputs, x).load_weights(resnet50_Weights_path) - num_patches = x.shape[1]*x.shape[2] + #num_patches = x.shape[1]*x.shape[2] + + #patch_size_y = input_height / x.shape[1] + #patch_size_x = input_width / x.shape[2] + #patch_size = patch_size_x * patch_size_y patches = Patches(patch_size)(x) # Encode patches. encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) diff --git a/train.py b/train.py index e16745f..84c9d3b 100644 --- a/train.py +++ b/train.py @@ -97,8 +97,6 @@ def run(_config, n_classes, n_epochs, input_height, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): if task == "segmentation" or task == "enhancement": - - num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1] if data_is_provided: dir_train_flowing = os.path.join(dir_output, 'train') dir_eval_flowing = os.path.join(dir_output, 'eval') @@ -213,7 +211,15 @@ def run(_config, n_classes, n_epochs, input_height, index_start = 0 if backbone_type=='nontransformer': model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining) - elif backbone_type=='nontransformer': + elif backbone_type=='transformer': + num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1] + + if not (num_patches == (input_width / 32) * (input_height / 32)): + print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) ) + sys.exit(1) + if not (transformer_patchsize == 1): + print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" ) + sys.exit(1) model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining) #if you want to see the model structure just uncomment model summary.