mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
The Patches class for the ViT encoder is corrected
This commit is contained in:
parent
fed005abd7
commit
7f7bdab208
1 changed files with 31 additions and 35 deletions
|
|
@ -1,52 +1,48 @@
|
||||||
from keras import layers
|
from keras import layers
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
projection_dim = 64
|
|
||||||
patch_size = 1
|
class Patches(layers.Layer):
    """Extract non-overlapping (patch_size_y x patch_size_x) patches from a
    batch of images and flatten each patch into a vector.

    Output shape: (batch, num_patches, patch_size_y * patch_size_x * channels).
    """

    def __init__(self, patch_size_x, patch_size_y, **kwargs):
        """Store the patch dimensions.

        Args:
            patch_size_x: patch width in pixels.
            patch_size_y: patch height in pixels.
            **kwargs: standard Keras Layer arguments (name=, dtype=, ...),
                forwarded so the layer can be rebuilt via ``from_config``.
        """
        super(Patches, self).__init__(**kwargs)
        self.patch_size_x = patch_size_x
        self.patch_size_y = patch_size_y

    def call(self, images):
        """Split ``images`` (batch, H, W, C) into flattened patches."""
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
            strides=[1, self.patch_size_y, self.patch_size_x, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Dynamic last-dim lookup keeps the layer usable when static
        # shapes are unknown (variable image sizes).
        patch_dims = tf.shape(patches)[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        # Added for parity with the other custom layers in this file so
        # that models containing this layer can be saved and reloaded.
        config = super().get_config().copy()
        config.update({
            'patch_size_x': self.patch_size_x,
            'patch_size_y': self.patch_size_y,
        })
        return config
|
||||||
|
|
||||||
class PatchEncoder(layers.Layer):
    """Linearly project flattened patches and add learned position embeddings.

    Standard ViT patch encoding: each patch vector is mapped to
    ``projection_dim`` dimensions with a Dense layer, then a trainable
    position embedding (one row per patch index) is added.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        """Build the projection and position-embedding sub-layers.

        Args:
            num_patches: number of patches per image (embedding table size).
            projection_dim: output dimensionality of the projection.
            **kwargs: standard Keras Layer arguments, forwarded to the base
                class so config round-tripping works.
        """
        super(PatchEncoder, self).__init__(**kwargs)
        self.num_patches = num_patches
        # Keep the scalar so get_config() can serialize it — the sub-layer
        # objects themselves are not valid config values.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return projection(patch) + embedding(0..num_patches-1)."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        # Bug fix: the previous config stored the Dense/Embedding layer
        # objects under keys __init__ does not accept, so
        # from_config(**config) raised TypeError and the layer could not be
        # deserialized. Store only the two scalars __init__ takes.
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection_dim': self.projection_dim,
        })
        return config
|
||||||
|
|
||||||
class Patches(layers.Layer):
    """Legacy square-patch extractor.

    NOTE(review): the patch edge length is read from the module-level
    ``patch_size`` constant rather than passed as an argument — confirm that
    global is set before instantiation. ``**kwargs`` are accepted but not
    forwarded to the base Layer (preserved as-is).
    """

    def __init__(self, **kwargs):
        super(Patches, self).__init__()
        # Binds the current module-level value of `patch_size`.
        self.patch_size = patch_size

    def call(self, images):
        """Split ``images`` into flattened, non-overlapping square patches."""
        n_batch = tf.shape(images)[0]
        raw = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Static last-dim (as in the original); requires a known channel count.
        flat_dim = raw.shape[-1]
        return tf.reshape(raw, [n_batch, -1, flat_dim])

    def get_config(self):
        """Serialize the layer; records the bound patch size."""
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue