From 662aa67dfbc3a55639fc1fbd5f39038262a785fb Mon Sep 17 00:00:00 2001 From: kba Date: Thu, 16 Oct 2025 20:31:48 +0200 Subject: [PATCH 1/3] move models.py to root to cherry-pick 3098700 --- src/eynollah/training/models.py => models.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/eynollah/training/models.py => models.py (100%) diff --git a/src/eynollah/training/models.py b/models.py similarity index 100% rename from src/eynollah/training/models.py rename to models.py From b67a3c4ed4abbee13bada4ecb96c2303a83dff49 Mon Sep 17 00:00:00 2001 From: "H.T. Kruitbosch" Date: Thu, 11 Jan 2024 19:04:42 +0100 Subject: [PATCH 2/3] tf.keras version that allows any input resolution --- models.py | 779 ++++++++---------------------------------------------- 1 file changed, 116 insertions(+), 663 deletions(-) diff --git a/models.py b/models.py index fdc5437..aba310c 100644 --- a/models.py +++ b/models.py @@ -1,117 +1,12 @@ -import tensorflow as tf -from tensorflow import keras -from tensorflow.keras.models import * -from tensorflow.keras.layers import * +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Conv2D, Concatenate, ZeroPadding2D, BatchNormalization, Activation, MaxPooling2D, UpSampling2D, Input, Layer from tensorflow.keras import layers from tensorflow.keras.regularizers import l2 +import tensorflow as tf -##mlp_head_units = [512, 256]#[2048, 1024] -###projection_dim = 64 -##transformer_layers = 2#8 -##num_heads = 1#4 -resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' -IMAGE_ORDERING = 'channels_last' -MERGE_AXIS = -1 +resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' -def mlp(x, hidden_units, dropout_rate): - for units in hidden_units: - x = layers.Dense(units, activation=tf.nn.gelu)(x) - x = layers.Dropout(dropout_rate)(x) - return x - -class Patches(layers.Layer): - def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): - super(Patches, self).__init__() - self.patch_size_x = patch_size_x - self.patch_size_y = patch_size_y - - def call(self, images): - #print(tf.shape(images)[1],'images') - #print(self.patch_size,'self.patch_size') - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=[1, self.patch_size_y, self.patch_size_x, 1], - strides=[1, self.patch_size_y, self.patch_size_x, 1], - rates=[1, 1, 1, 1], - padding="VALID", - ) - #patch_dims = patches.shape[-1] - patch_dims = tf.shape(patches)[-1] - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) - return patches - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'patch_size_x': self.patch_size_x, - 'patch_size_y': self.patch_size_y, - }) - return config - -class Patches_old(layers.Layer): - def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs): - super(Patches, self).__init__() - self.patch_size = patch_size - - def call(self, images): - #print(tf.shape(images)[1],'images') - #print(self.patch_size,'self.patch_size') - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=[1, self.patch_size, self.patch_size, 1], - strides=[1, self.patch_size, self.patch_size, 1], - rates=[1, 1, 1, 1], - padding="VALID", - ) - patch_dims = patches.shape[-1] - #print(patches.shape,patch_dims,'patch_dims') - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) 
- return patches - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'patch_size': self.patch_size, - }) - return config - - -class PatchEncoder(layers.Layer): - def __init__(self, num_patches, projection_dim): - super(PatchEncoder, self).__init__() - self.num_patches = num_patches - self.projection = layers.Dense(units=projection_dim) - self.position_embedding = layers.Embedding( - input_dim=num_patches, output_dim=projection_dim - ) - - def call(self, patch): - positions = tf.range(start=0, limit=self.num_patches, delta=1) - encoded = self.projection(patch) + self.position_embedding(positions) - return encoded - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'num_patches': self.num_patches, - 'projection': self.projection, - 'position_embedding': self.position_embedding, - }) - return config - - -def one_side_pad(x): - x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) - if IMAGE_ORDERING == 'channels_first': - x = Lambda(lambda x: x[:, :, :-1, :-1])(x) - elif IMAGE_ORDERING == 'channels_last': - x = Lambda(lambda x: x[:, :-1, :-1, :])(x) - return x - - -def identity_block(input_tensor, kernel_size, filters, stage, block): +def identity_block(input_tensor, kernel_size, filters, stage, block, data_format='channels_last'): """The identity block is the block that has no conv layer at shortcut. # Arguments input_tensor: input tensor @@ -123,25 +18,22 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): Output tensor for the block. """ filters1, filters2, filters3 = filters - - if IMAGE_ORDERING == 'channels_last': - bn_axis = 3 - else: - bn_axis = 1 + + bn_axis = 3 if data_format == 'channels_last' else 1 conv_name_base = 'res' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch' - x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2a')(input_tensor) + x = Conv2D(filters1, (1, 1) , data_format=data_format , name=conv_name_base + '2a')(input_tensor) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) x = Activation('relu')(x) - x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, + x = Conv2D(filters2, kernel_size , data_format=data_format , padding='same', name=conv_name_base + '2b')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = Activation('relu')(x) - x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x) + x = Conv2D(filters3 , (1, 1), data_format=data_format , name=conv_name_base + '2c')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) x = layers.add([x, input_tensor]) @@ -149,7 +41,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): return x -def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): +def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2), data_format='channels_last'): """conv_block is the block that has a conv layer at shortcut # Arguments input_tensor: input tensor @@ -163,29 +55,26 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) And the shortcut should have strides=(2,2) as well """ filters1, filters2, filters3 = filters - - if IMAGE_ORDERING == 'channels_last': - bn_axis = 3 - else: - bn_axis = 1 + + bn_axis = 3 if data_format == 'channels_last' else 1 conv_name_base = 'res' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch' - x = Conv2D(filters1, (1, 1), 
data_format=IMAGE_ORDERING, strides=strides, + x = Conv2D(filters1, (1, 1) , data_format=data_format, strides=strides, name=conv_name_base + '2a')(input_tensor) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) x = Activation('relu')(x) - x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same', + x = Conv2D(filters2, kernel_size , data_format=data_format, padding='same', name=conv_name_base + '2b')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = Activation('relu')(x) - x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x) + x = Conv2D(filters3, (1, 1), data_format=data_format, name=conv_name_base + '2c')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) - shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, strides=strides, + shortcut = Conv2D(filters3, (1, 1), data_format=data_format, strides=strides, name=conv_name_base + '1')(input_tensor) shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) @@ -194,567 +83,131 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) return x -def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segmentation", weight_decay=1e-6, pretraining=False): - assert input_height % 32 == 0 - assert input_width % 32 == 0 +class PadMultiple(Layer): + def __init__(self, mods, data_format='channels_last'): + super().__init__() + self.mods = mods + self.data_format = data_format + + def call(self, x): + h, w = self.mods + padding = ( + [(0,0), (0, -tf.shape(x)[1] % h), (0, -tf.shape(x)[2] % w), (0,0)] if self.data_format == 'channels_last' + else [(0,0), (0,0), (0, -tf.shape(x)[1] % h), (0, -tf.shape(x)[2] % w)]) + return tf.pad(x, padding) - img_input = Input(shape=(input_height, input_width, 3)) - if IMAGE_ORDERING == 'channels_last': - bn_axis = 3 - else: - bn_axis = 1 +class CutTo(Layer): + def __init__(self, data_format='channels_last'): + super().__init__() + self.data_format = data_format + + def call(self, inputs): + h, w = (1, 2) if self.data_format == 'channels_last' else (2,4) + h, w = tf.shape(inputs[1])[h], tf.shape(inputs[1])[w] + return inputs[0][:, :h, :w] if self.data_format == 'channels_last' else inputs[0][:, :, :h, :w] - x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) - x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), - name='conv1')(x) + +def resnet50_unet(n_classes, input_height=None, input_width=None, weight_decay=1e-6, pretraining=False, last_activation='softmax', skip_last_batchnorm=False, light_version=False, data_format='channels_last'): + """ Returns a U-NET model using the keras functional API. 
""" + img_input = Input(shape=(input_height, input_width, 3 )) + padded_to_multiple = PadMultiple((32,32))(img_input) + + bn_axis = 3 if data_format == 'channels_last' else 1 + merge_axis = 3 if data_format == 'channels_last' else 1 + + x = ZeroPadding2D((3, 3), data_format=data_format)(padded_to_multiple) + x = Conv2D(64, (7, 7), data_format=data_format, strides=(2, 2), kernel_regularizer=l2(weight_decay), name='conv1')(x) f1 = x x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) x = Activation('relu')(x) - x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) - - x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) - x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') - x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') - f2 = one_side_pad(x) - - x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') - f3 = x - - x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') - f4 = x - - x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') - f5 = x - - if pretraining: - model = Model(img_input, x).load_weights(resnet50_Weights_path) - - v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5) - v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048) - v512_2048 = Activation('relu')(v512_2048) - - v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4) - v512_1024 = (BatchNormalization(axis=bn_axis))(v512_1024) - v512_1024 = Activation('relu')(v512_1024) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v512_2048) - o = (concatenate([o, v512_1024], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f3], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f2], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f1], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, 
kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, img_input], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) - if task == "segmentation": - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) - else: - o = (Activation('sigmoid'))(o) - - model = Model(img_input, o) - return model - - -def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): - assert input_height % 32 == 0 - assert input_width % 32 == 0 - - img_input = Input(shape=(input_height, input_width, 3)) - - if IMAGE_ORDERING == 'channels_last': - bn_axis = 3 - else: - bn_axis = 1 - - x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) - x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), - name='conv1')(x) - f1 = x - - x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) - x = Activation('relu')(x) - x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) - - x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) - x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') - x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') - f2 = one_side_pad(x) - - x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') - f3 = x - - x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') - f4 = x - - x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') - f5 = x - - if pretraining: - Model(img_input, x).load_weights(resnet50_Weights_path) - - v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))( - f5) - v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) - v1024_2048 = Activation('relu')(v1024_2048) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048) - o = (concatenate([o, f4], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f3], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = 
(BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f2], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f1], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, img_input], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) - if task == "segmentation": - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) - else: - o = (Activation('sigmoid'))(o) - - model = Model(img_input, o) - - return model - - -def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): - if mlp_head_units is None: - mlp_head_units = [128, 64] - inputs = layers.Input(shape=(input_height, input_width, 3)) + x = MaxPooling2D((3, 3) , data_format=data_format , strides=(2, 2))(x) - #transformer_units = [ - #projection_dim * 2, - #projection_dim, - #] # Size of the transformer layers - IMAGE_ORDERING = 'channels_last' - bn_axis=3 - x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs) - x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) - f1 = x + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), data_format=data_format) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', data_format=data_format) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', data_format=data_format) + f2 = ZeroPadding2D(((1,0), (1,0)), data_format=data_format)(x) - x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) - x = Activation('relu')(x) - x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) - - x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) - x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') - x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') - f2 = one_side_pad(x) - x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', data_format=data_format) + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', data_format=data_format) + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', data_format=data_format) + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', 
data_format=data_format) f3 = x - x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', data_format=data_format) + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', data_format=data_format) + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', data_format=data_format) + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', data_format=data_format) + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', data_format=data_format) + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', data_format=data_format) f4 = x - x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') - f5 = x - - if pretraining: - model = Model(inputs, x).load_weights(resnet50_Weights_path) - - #num_patches = x.shape[1]*x.shape[2] - - #patch_size_y = input_height / x.shape[1] - #patch_size_x = input_width / x.shape[2] - #patch_size = patch_size_x * patch_size_y - patches = Patches(patch_size_x, patch_size_y)(x) - # Encode patches. - encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) - - for _ in range(transformer_layers): - # Layer normalization 1. - x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) - # Create a multi-head attention layer. - attention_output = layers.MultiHeadAttention( - num_heads=num_heads, key_dim=projection_dim, dropout=0.1 - )(x1, x1) - # Skip connection 1. - x2 = layers.Add()([attention_output, encoded_patches]) - # Layer normalization 2. - x3 = layers.LayerNormalization(epsilon=1e-6)(x2) - # MLP. - x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) - # Skip connection 2. 
- encoded_patches = layers.Add()([x3, x2]) - - encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )]) - - v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches) - v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) - v1024_2048 = Activation('relu')(v1024_2048) - - o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) - o = (concatenate([o, f4],axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o ,f3], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f2], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f1], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, inputs],axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) - if task == "segmentation": - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) - else: - o = (Activation('sigmoid'))(o) - - model = Model(inputs=inputs, outputs=o) - - return model - -def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): - if mlp_head_units is None: - mlp_head_units = [128, 64] - inputs = layers.Input(shape=(input_height, input_width, 3)) - - ##transformer_units = [ - ##projection_dim * 2, - ##projection_dim, - ##] # Size of the transformer layers - IMAGE_ORDERING = 'channels_last' - bn_axis=3 - - patches = Patches(patch_size_x, patch_size_y)(inputs) - # Encode patches. - encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) - - for _ in range(transformer_layers): - # Layer normalization 1. - x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) - # Create a multi-head attention layer. - attention_output = layers.MultiHeadAttention( - num_heads=num_heads, key_dim=projection_dim, dropout=0.1 - )(x1, x1) - # Skip connection 1. 
- x2 = layers.Add()([attention_output, encoded_patches]) - # Layer normalization 2. - x3 = layers.LayerNormalization(epsilon=1e-6)(x2) - # MLP. - x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1) - # Skip connection 2. - encoded_patches = layers.Add()([x3, x2]) - - encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )]) - - encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches) - - x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(encoded_patches) - x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) - f1 = x - - x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) - x = Activation('relu')(x) - x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) - - x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) - x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') - x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') - f2 = one_side_pad(x) - - x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') - f3 = x - - x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') - f4 = x - - x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') - f5 = x - - if pretraining: - model = Model(encoded_patches, x).load_weights(resnet50_Weights_path) - - v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(x) - v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) - v1024_2048 = Activation('relu')(v1024_2048) - - o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) - o = (concatenate([o, f4],axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o ,f3], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f2], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f1], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), 
data_format=IMAGE_ORDERING))(o) - o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, inputs],axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) - if task == "segmentation": - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) - else: - o = (Activation('sigmoid'))(o) - - model = Model(inputs=inputs, outputs=o) - - return model - -def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): - include_top=True - assert input_height%32 == 0 - assert input_width%32 == 0 - - - img_input = Input(shape=(input_height,input_width , 3 )) - - if IMAGE_ORDERING == 'channels_last': - bn_axis = 3 - else: - bn_axis = 1 - - x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) - x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) - f1 = x - - x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) - x = Activation('relu')(x) - x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) - - - x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) - x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') - x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') - f2 = one_side_pad(x ) - - - x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') - f3 = x - - x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') - f4 = x - - x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', data_format=data_format) + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', data_format=data_format) + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', data_format=data_format) f5 = x if pretraining: Model(img_input, x).load_weights(resnet50_Weights_path) - x = AveragePooling2D((7, 7), name='avg_pool')(x) - x = Flatten()(x) - - ## - x = Dense(256, activation='relu', name='fc512')(x) - x=Dropout(0.2)(x) - ## - x = Dense(n_classes, activation='softmax', name='fc1000')(x) - model = Model(img_input, x) - - + if light_version: + v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=data_format, kernel_regularizer=l2(weight_decay))(f5) + v512_2048 = BatchNormalization(axis=bn_axis)(v512_2048) + v512_2048 = Activation('relu')(v512_2048) - - return model - -def 
machine_based_reading_order_model(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): - assert input_height%32 == 0 - assert input_width%32 == 0 - - img_input = Input(shape=(input_height,input_width , 3 )) - - if IMAGE_ORDERING == 'channels_last': - bn_axis = 3 + v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=data_format, kernel_regularizer=l2(weight_decay))(f4) + v512_1024 = BatchNormalization(axis=bn_axis)(v512_1024) + v512_1024 = Activation('relu')(v512_1024) + x, c = v512_2048, v512_1024 # continuation and concatenation layers else: - bn_axis = 1 - - x1 = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) - x1 = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x1) - - x1 = BatchNormalization(axis=bn_axis, name='bn_conv1')(x1) - x1 = Activation('relu')(x1) - x1 = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x1) + v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=data_format, kernel_regularizer=l2(weight_decay))(f5) + v1024_2048 = BatchNormalization(axis=bn_axis)(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + x, c = v1024_2048, f4 # continuation and concatenation layers - x1 = conv_block(x1, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) - x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='b') - x1 = identity_block(x1, 3, [64, 64, 256], stage=2, block='c') + o = UpSampling2D((2,2), data_format=data_format)(x) + o = Concatenate(axis=merge_axis)([o ,c]) + o = ZeroPadding2D( (1,1), data_format=data_format)(o) + o = Conv2D(512, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) - x1 = conv_block(x1, 3, [128, 128, 512], stage=3, block='a') - x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='b') - x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='c') - x1 = identity_block(x1, 3, [128, 128, 512], stage=3, block='d') + o = UpSampling2D( (2,2), data_format=data_format)(o) + o = Concatenate(axis=merge_axis)([ o ,f3]) + o = ZeroPadding2D( (1,1), data_format=data_format)(o) + o = Conv2D( 256, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) - x1 = conv_block(x1, 3, [256, 256, 1024], stage=4, block='a') - x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='b') - x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='c') - x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='d') - x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='e') - x1 = identity_block(x1, 3, [256, 256, 1024], stage=4, block='f') + o = UpSampling2D( (2,2), data_format=data_format)(o) + o = Concatenate(axis=merge_axis)([o,f2]) + o = ZeroPadding2D((1,1) , data_format=data_format)(o) + o = Conv2D( 128 , (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) - x1 = conv_block(x1, 3, [512, 512, 2048], stage=5, block='a') - x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='b') - x1 = identity_block(x1, 3, [512, 512, 2048], stage=5, block='c') + o = UpSampling2D( (2,2), data_format=data_format)(o) + o = Concatenate(axis=merge_axis)([o,f1]) + o = ZeroPadding2D((1,1) , data_format=data_format)(o) + o = Conv2D( 64 , (3, 3), padding='valid', data_format=data_format, 
kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) + + o = UpSampling2D( (2,2), data_format=data_format)(o) + o = Concatenate(axis=merge_axis)([o, padded_to_multiple]) + o = ZeroPadding2D((1,1) , data_format=data_format)(o) + o = Conv2D(32, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) - if pretraining: - Model(img_input , x1).load_weights(resnet50_Weights_path) + o = Conv2D(n_classes, (1, 1), padding='same', data_format=data_format, kernel_regularizer=l2(weight_decay))(o) + if not skip_last_batchnorm: + o = BatchNormalization(axis=bn_axis)(o) - x1 = AveragePooling2D((7, 7), name='avg_pool1')(x1) - flattened = Flatten()(x1) + o = Activation(last_activation)(o) + o = CutTo()([o, img_input]) - o = Dense(256, activation='relu', name='fc512')(flattened) - o=Dropout(0.2)(o) - - o = Dense(256, activation='relu', name='fc512a')(o) - o=Dropout(0.2)(o) - - o = Dense(n_classes, activation='sigmoid', name='fc1000')(o) - model = Model(img_input , o) - - return model + return Model(img_input , o) From 2e0c1868e0ffaad3080a99a70d8f7cee6743d0c7 Mon Sep 17 00:00:00 2001 From: kba Date: Thu, 16 Oct 2025 20:36:16 +0200 Subject: [PATCH 3/3] move models.py back to src/.../training --- models.py => src/eynollah/training/models.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename models.py => src/eynollah/training/models.py (100%) diff --git a/models.py b/src/eynollah/training/models.py similarity index 100% rename from models.py rename to src/eynollah/training/models.py
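
For context on patch 2/3: it collapses the fixed-size U-Net variants into a single resnet50_unet whose PadMultiple layer pads the input up to a multiple of 32 before the ResNet-50 encoder and whose CutTo layer crops the decoder output back to the original size, so input_height and input_width may be left as None. Below is a minimal usage sketch, not part of the patch series; it assumes TensorFlow 2.x is installed, that the patched module is importable as eynollah.training.models (matching the path restored in patch 3/3), and the 500x700 page size is an arbitrary example.

    import numpy as np
    from eynollah.training.models import resnet50_unet  # assumed import path after patch 3/3

    # Build the network without committing to a spatial resolution.
    model = resnet50_unet(n_classes=3, input_height=None, input_width=None,
                          pretraining=False, light_version=False)

    # Any resolution should work: PadMultiple pads 500x700 up to 512x704 internally,
    # and CutTo crops the prediction back to the input size.
    page = np.random.rand(1, 500, 700, 3).astype('float32')
    pred = model.predict(page)
    print(pred.shape)  # expected: (1, 500, 700, 3)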