Correct the Patches class for the ViT encoder

This commit is contained in:
vahidrezanezhad 2026-03-01 18:26:29 +01:00
parent fed005abd7
commit 7f7bdab208

View file

@ -1,52 +1,48 @@
from keras import layers from keras import layers
import tensorflow as tf import tensorflow as tf
# Defaults from the previous revision. The legacy `Patches` definition later
# in this file still reads `patch_size`, so these remain defined.
projection_dim = 64
patch_size = 1
num_patches = 21 * 21


class Patches(layers.Layer):
    """Cut a batch of images into non-overlapping rectangular patches.

    Args:
        patch_size_x: Patch width in pixels (column extent).
        patch_size_y: Patch height in pixels (row extent).
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``,
            ``trainable``, ...), forwarded to ``layers.Layer`` so that the
            config produced by ``get_config`` can be fed back to the
            constructor.
    """

    def __init__(self, patch_size_x, patch_size_y, **kwargs):
        super().__init__(**kwargs)
        self.patch_size_x = patch_size_x
        self.patch_size_y = patch_size_y

    def call(self, images):
        """Return the flattened patches of ``images``.

        ``tf.image.extract_patches`` expects a 4-D tensor; with stride equal
        to the patch size the patches do not overlap, and ``padding="VALID"``
        keeps only patches that fit entirely inside the image.

        Returns:
            Tensor reshaped to ``[batch, -1, patch_dims]``: one row per
            patch, each row holding that patch's flattened values.
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
            strides=[1, self.patch_size_y, self.patch_size_x, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Dynamic shape: works even when the static last dimension is unknown.
        patch_dims = tf.shape(patches)[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config().copy()
        config.update({
            'patch_size_x': self.patch_size_x,
            'patch_size_y': self.patch_size_y,
        })
        return config
class PatchEncoder(layers.Layer):
    """Project flattened patches and add a learned position embedding.

    Args:
        num_patches: Number of patches per image; also the size of the
            position-embedding table.
        projection_dim: Output width of both the dense projection and the
            position embedding.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``,
            ``trainable``, ...), forwarded to ``layers.Layer``.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Stored so get_config can report it; the sub-layers alone are not
        # valid constructor arguments.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return ``projection(patch) + position_embedding(0..num_patches-1)``.

        NOTE(review): the addition relies on broadcasting, so ``patch`` is
        presumably ``(batch, num_patches, patch_dims)`` - confirm with callers.
        """
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return self.projection(patch) + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        Only plain values are emitted. The ``Dense`` and ``Embedding``
        sub-layers are recreated by ``__init__`` and must not appear here,
        otherwise ``from_config`` would pass them back as unknown keywords.
        """
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection_dim': self.projection_dim,
        })
        return config
class Patches(layers.Layer):
    """Cut a batch of images into non-overlapping square patches.

    Args:
        **kwargs: May contain ``patch_size`` (side length of a patch, as
            emitted by ``get_config``); when absent, the module-level
            ``patch_size`` is used. All remaining keywords (``name``,
            ``dtype``, ``trainable``, ...) are forwarded to ``layers.Layer``.
    """

    def __init__(self, **kwargs):
        # Previously every keyword was discarded, so a layer rebuilt from its
        # own config silently fell back to the module global and lost its name.
        configured_size = kwargs.pop('patch_size', None)
        super().__init__(**kwargs)
        self.patch_size = patch_size if configured_size is None else configured_size

    def call(self, images):
        """Return the flattened patches of ``images``.

        With stride equal to the patch size the patches do not overlap, and
        ``padding="VALID"`` keeps only patches that fit inside the image.

        Returns:
            Tensor reshaped to ``[batch, -1, patch_dims]``.
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Static last dimension, as in the original implementation.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the base layer config extended with ``patch_size``."""
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config