From ea285124ce11aa9c00d02d2e939803a067931a61 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Sun, 8 Feb 2026 01:06:57 +0100 Subject: [PATCH] fix Patches/PatchEncoder (make configurable again) --- src/eynollah/patch_encoder.py | 52 ++++++++++++++------------------- src/eynollah/training/models.py | 22 +++----------- 2 files changed, 26 insertions(+), 48 deletions(-) diff --git a/src/eynollah/patch_encoder.py b/src/eynollah/patch_encoder.py index dc0a291..07b843d 100644 --- a/src/eynollah/patch_encoder.py +++ b/src/eynollah/patch_encoder.py @@ -3,52 +3,44 @@ os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 import tensorflow as tf from tensorflow.keras import layers -projection_dim = 64 -patch_size = 1 -num_patches =21*21#14*14#28*28#14*14#28*28 - class PatchEncoder(layers.Layer): - def __init__(self): + # 441=21*21 # 14*14 # 28*28 + def __init__(self, num_patches=441, projection_dim=64): super().__init__() - self.projection = layers.Dense(units=projection_dim) - self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim) + self.num_patches = num_patches + self.projection_dim = projection_dim + self.projection = layers.Dense(self.projection_dim) + self.position_embedding = layers.Embedding(self.num_patches, self.projection_dim) def call(self, patch): - positions = tf.range(start=0, limit=num_patches, delta=1) - encoded = self.projection(patch) + self.position_embedding(positions) - return encoded + positions = tf.range(start=0, limit=self.num_patches, delta=1) + return self.projection(patch) + self.position_embedding(positions) def get_config(self): - config = super().get_config().copy() - config.update({ - 'num_patches': num_patches, - 'projection': self.projection, - 'position_embedding': self.position_embedding, - }) - return config + return dict(num_patches=self.num_patches, + projection_dim=self.projection_dim, + **super().get_config()) class Patches(layers.Layer): - def __init__(self, **kwargs): - super(Patches, self).__init__() - self.patch_size = patch_size + def __init__(self, patch_size_x=1, patch_size_y=1): + super().__init__() + self.patch_size_x = patch_size_x + self.patch_size_y = patch_size_y def call(self, images): batch_size = tf.shape(images)[0] patches = tf.image.extract_patches( images=images, - sizes=[1, self.patch_size, self.patch_size, 1], - strides=[1, self.patch_size, self.patch_size, 1], + sizes=[1, self.patch_size_y, self.patch_size_x, 1], + strides=[1, self.patch_size_y, self.patch_size_x, 1], rates=[1, 1, 1, 1], padding="VALID", ) patch_dims = patches.shape[-1] - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) - return patches - def get_config(self): + return tf.reshape(patches, [batch_size, -1, patch_dims]) - config = super().get_config().copy() - config.update({ - 'patch_size': self.patch_size, - }) - return config + def get_config(self): + return dict(patch_size_x=self.patch_size_x, + patch_size_y=self.patch_size_y, + **super().get_config()) diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index d1148f1..b0ad51c 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -423,16 +423,9 @@ def vit_resnet50_unet(num_patches, #num_patches = x.shape[1]*x.shape[2] - # rs: fixme patch size not configurable anymore... - #patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs) - patches = Patches()(x) - assert transformer_patchsize_x == transformer_patchsize_y == 1 + patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(x) # Encode patches. - # rs: fixme num patches and dim not configurable anymore... - #encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) - encoded_patches = PatchEncoder()(patches) - assert num_patches == 21 * 21 - assert transformer_projection_dim == 64 + encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) for _ in range(transformer_layers): # Layer normalization 1. @@ -530,16 +523,9 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches, IMAGE_ORDERING = 'channels_last' bn_axis=3 - # rs: fixme patch size not configurable anymore... - #patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs) - patches = Patches()(inputs) - assert transformer_patchsize_x == transformer_patchsize_y == 1 + patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs) # Encode patches. - # rs: fixme num patches and dim not configurable anymore... - #encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) - encoded_patches = PatchEncoder()(patches) - assert num_patches == 21 * 21 - assert transformer_projection_dim == 64 + encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) for _ in range(transformer_layers): # Layer normalization 1.