transformer patch size is dynamic now.

This commit is contained in:
vahidrezanezhad 2024-06-12 13:26:27 +02:00
parent 2aa216e388
commit f1fd74c7eb
3 changed files with 75 additions and 30 deletions

View file

@ -6,25 +6,49 @@ from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2
mlp_head_units = [2048, 1024]
projection_dim = 64
#projection_dim = 64
transformer_layers = 8
num_heads = 4
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1
transformer_units = [
projection_dim * 2,
projection_dim,
] # Size of the transformer layers
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=tf.nn.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
class Patches(layers.Layer):
def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size_x = patch_size_x
self.patch_size_y = patch_size_y
def call(self, images):
#print(tf.shape(images)[1],'images')
#print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
strides=[1, self.patch_size_y, self.patch_size_x, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size_x': self.patch_size_x,
'patch_size_y': self.patch_size_y,
})
return config
class Patches_old(layers.Layer):
def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
@ -369,8 +393,13 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
return model
def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
inputs = layers.Input(shape=(input_height, input_width, 3))
transformer_units = [
projection_dim * 2,
projection_dim,
] # Size of the transformer layers
IMAGE_ORDERING = 'channels_last'
bn_axis=3
@ -414,7 +443,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu
#patch_size_y = input_height / x.shape[1]
#patch_size_x = input_width / x.shape[2]
#patch_size = patch_size_x * patch_size_y
patches = Patches(patch_size)(x)
patches = Patches(patch_size_x, patch_size_y)(x)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
@ -434,7 +463,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], 64])
encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)