Mirror of https://github.com/qurator-spk/eynollah.git
Synced 2026-02-20 16:32:03 +01:00
fix Patches/PatchEncoder (make configurable again)
parent 2492c257c6
commit ea285124ce

2 changed files with 26 additions and 48 deletions
@@ -3,52 +3,44 @@ os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
 import tensorflow as tf
 from tensorflow.keras import layers
 
-projection_dim = 64
-patch_size = 1
-num_patches =21*21#14*14#28*28#14*14#28*28
 
 class PatchEncoder(layers.Layer):
 
-    def __init__(self):
+    # 441=21*21 # 14*14 # 28*28
+    def __init__(self, num_patches=441, projection_dim=64):
         super().__init__()
-        self.projection = layers.Dense(units=projection_dim)
-        self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
+        self.num_patches = num_patches
+        self.projection_dim = projection_dim
+        self.projection = layers.Dense(self.projection_dim)
+        self.position_embedding = layers.Embedding(self.num_patches, self.projection_dim)
 
     def call(self, patch):
-        positions = tf.range(start=0, limit=num_patches, delta=1)
-        encoded = self.projection(patch) + self.position_embedding(positions)
-        return encoded
+        positions = tf.range(start=0, limit=self.num_patches, delta=1)
+        return self.projection(patch) + self.position_embedding(positions)
 
     def get_config(self):
-        config = super().get_config().copy()
-        config.update({
-            'num_patches': num_patches,
-            'projection': self.projection,
-            'position_embedding': self.position_embedding,
-        })
-        return config
+        return dict(num_patches=self.num_patches,
+                    projection_dim=self.projection_dim,
+                    **super().get_config())
 
 class Patches(layers.Layer):
-    def __init__(self, **kwargs):
-        super(Patches, self).__init__()
-        self.patch_size = patch_size
+    def __init__(self, patch_size_x=1, patch_size_y=1):
+        super().__init__()
+        self.patch_size_x = patch_size_x
+        self.patch_size_y = patch_size_y
 
     def call(self, images):
         batch_size = tf.shape(images)[0]
         patches = tf.image.extract_patches(
             images=images,
-            sizes=[1, self.patch_size, self.patch_size, 1],
-            strides=[1, self.patch_size, self.patch_size, 1],
+            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
+            strides=[1, self.patch_size_y, self.patch_size_x, 1],
             rates=[1, 1, 1, 1],
             padding="VALID",
         )
         patch_dims = patches.shape[-1]
-        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
-        return patches
+        return tf.reshape(patches, [batch_size, -1, patch_dims])
 
-    def get_config(self):
-
-        config = super().get_config().copy()
-        config.update({
-            'patch_size': self.patch_size,
-        })
-        return config
+    def get_config(self):
+        return dict(patch_size_x=self.patch_size_x,
+                    patch_size_y=self.patch_size_y,
+                    **super().get_config())
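With num_patches and projection_dim restored as constructor arguments, the encoder no longer depends on module-level globals. A minimal sketch of what that enables, assuming eager TF 2.x with TF_USE_LEGACY_KERAS=1 as set at the top of the changed file (tensor shapes here are illustrative, not from the commit):

    import tensorflow as tf

    enc = PatchEncoder(num_patches=441, projection_dim=64)
    patch = tf.random.uniform([2, 441, 96])           # (batch, patches, features)
    print(enc(patch).shape)                           # (2, 441, 64): Dense projection + position embedding
    cfg = enc.get_config()
    print(cfg['num_patches'], cfg['projection_dim'])  # 441 64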
@@ -423,16 +423,9 @@ def vit_resnet50_unet(num_patches,
 
     #num_patches = x.shape[1]*x.shape[2]
 
-    # rs: fixme patch size not configurable anymore...
-    #patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
-    patches = Patches()(x)
-    assert transformer_patchsize_x == transformer_patchsize_y == 1
+    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(x)
     # Encode patches.
-    # rs: fixme num patches and dim not configurable anymore...
-    #encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
-    encoded_patches = PatchEncoder()(patches)
-    assert num_patches == 21 * 21
-    assert transformer_projection_dim == 64
+    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
 
     for _ in range(transformer_layers):
         # Layer normalization 1.
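Since patch width and height are separate arguments again, the layer also supports non-square patches. A quick shape check (the input size and patch sizes are assumptions for illustration, not values from the commit):

    import tensorflow as tf

    images = tf.random.uniform([2, 448, 448, 3])
    patches = Patches(patch_size_x=2, patch_size_y=4)(images)
    # 448/4 * 448/2 = 112 * 224 = 25088 patches, each 4*2*3 = 24 values
    print(patches.shape)                              # (2, 25088, 24)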
@@ -530,16 +523,9 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
     IMAGE_ORDERING = 'channels_last'
     bn_axis=3
 
-    # rs: fixme patch size not configurable anymore...
-    #patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
-    patches = Patches()(inputs)
-    assert transformer_patchsize_x == transformer_patchsize_y == 1
+    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
     # Encode patches.
-    # rs: fixme num patches and dim not configurable anymore...
-    #encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
-    encoded_patches = PatchEncoder()(patches)
-    assert num_patches == 21 * 21
-    assert transformer_projection_dim == 64
+    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
 
     for _ in range(transformer_layers):
         # Layer normalization 1.
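Both call sites now compose the two layers the same way. End to end, the pipeline looks like this (the feature-map shape is an assumption matching the 21*21 default, not taken from the commit):

    import tensorflow as tf

    x = tf.random.uniform([1, 21, 21, 256])           # stand-in for a CNN feature map
    patches = Patches(1, 1)(x)                        # -> (1, 441, 256)
    encoded = PatchEncoder(441, 64)(patches)          # -> (1, 441, 64)
    print(encoded.shape)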