@ -30,8 +30,8 @@ class Patches(layers.Layer):
self . patch_size = patch_size
self . patch_size = patch_size
def call ( self , images ) :
def call ( self , images ) :
print ( tf . shape ( images ) [ 1 ] , ' images ' )
#print(tf.shape(images)[1],'images' )
print ( self . patch_size , ' self.patch_size ' )
#print(self.patch_size,'self.patch_size' )
batch_size = tf . shape ( images ) [ 0 ]
batch_size = tf . shape ( images ) [ 0 ]
patches = tf . image . extract_patches (
patches = tf . image . extract_patches (
images = images ,
images = images ,
@ -41,7 +41,7 @@ class Patches(layers.Layer):
padding = " VALID " ,
padding = " VALID " ,
)
)
patch_dims = patches . shape [ - 1 ]
patch_dims = patches . shape [ - 1 ]
print ( patches . shape , patch_dims , ' patch_dims ' )
#print(patches.shape,patch_dims,'patch_dims' )
patches = tf . reshape ( patches , [ batch_size , - 1 , patch_dims ] )
patches = tf . reshape ( patches , [ batch_size , - 1 , patch_dims ] )
return patches
return patches
def get_config ( self ) :
def get_config ( self ) :
@ -51,6 +51,7 @@ class Patches(layers.Layer):
' patch_size ' : self . patch_size ,
' patch_size ' : self . patch_size ,
} )
} )
return config
return config
class PatchEncoder ( layers . Layer ) :
class PatchEncoder ( layers . Layer ) :
def __init__ ( self , num_patches , projection_dim ) :
def __init__ ( self , num_patches , projection_dim ) :
@ -408,7 +409,11 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu
if pretraining :
if pretraining :
model = Model ( inputs , x ) . load_weights ( resnet50_Weights_path )
model = Model ( inputs , x ) . load_weights ( resnet50_Weights_path )
num_patches = x . shape [ 1 ] * x . shape [ 2 ]
#num_patches = x.shape[1]*x.shape[2]
#patch_size_y = input_height / x.shape[1]
#patch_size_x = input_width / x.shape[2]
#patch_size = patch_size_x * patch_size_y
patches = Patches ( patch_size ) ( x )
patches = Patches ( patch_size ) ( x )
# Encode patches.
# Encode patches.
encoded_patches = PatchEncoder ( num_patches , projection_dim ) ( patches )
encoded_patches = PatchEncoder ( num_patches , projection_dim ) ( patches )