mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 03:40:24 +02:00
first working update of branch
This commit is contained in:
parent
02b1436f39
commit
d27647a0f1
4 changed files with 452 additions and 151 deletions
179
models.py
179
models.py
|
@ -1,13 +1,81 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras.models import *
|
||||
from tensorflow.keras.layers import *
|
||||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.regularizers import l2
|
||||
|
||||
mlp_head_units = [2048, 1024]
|
||||
projection_dim = 64
|
||||
transformer_layers = 8
|
||||
num_heads = 4
|
||||
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
MERGE_AXIS = -1
|
||||
|
||||
transformer_units = [
|
||||
projection_dim * 2,
|
||||
projection_dim,
|
||||
] # Size of the transformer layers
|
||||
def mlp(x, hidden_units, dropout_rate):
|
||||
for units in hidden_units:
|
||||
x = layers.Dense(units, activation=tf.nn.gelu)(x)
|
||||
x = layers.Dropout(dropout_rate)(x)
|
||||
return x
|
||||
|
||||
|
||||
class Patches(layers.Layer):
|
||||
def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
|
||||
super(Patches, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
def call(self, images):
|
||||
print(tf.shape(images)[1],'images')
|
||||
print(self.patch_size,'self.patch_size')
|
||||
batch_size = tf.shape(images)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=images,
|
||||
sizes=[1, self.patch_size, self.patch_size, 1],
|
||||
strides=[1, self.patch_size, self.patch_size, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
)
|
||||
patch_dims = patches.shape[-1]
|
||||
print(patches.shape,patch_dims,'patch_dims')
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'patch_size': self.patch_size,
|
||||
})
|
||||
return config
|
||||
|
||||
class PatchEncoder(layers.Layer):
|
||||
def __init__(self, num_patches, projection_dim):
|
||||
super(PatchEncoder, self).__init__()
|
||||
self.num_patches = num_patches
|
||||
self.projection = layers.Dense(units=projection_dim)
|
||||
self.position_embedding = layers.Embedding(
|
||||
input_dim=num_patches, output_dim=projection_dim
|
||||
)
|
||||
|
||||
def call(self, patch):
|
||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||
return encoded
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'num_patches': self.num_patches,
|
||||
'projection': self.projection,
|
||||
'position_embedding': self.position_embedding,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
def one_side_pad(x):
|
||||
x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
|
||||
if IMAGE_ORDERING == 'channels_first':
|
||||
|
@ -292,3 +360,114 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-
|
|||
model = Model(img_input, o)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||
inputs = layers.Input(shape=(input_height, input_width, 3))
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
bn_axis=3
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = keras.Model(inputs, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
num_patches = x.shape[1]*x.shape[2]
|
||||
patches = Patches(patch_size)(x)
|
||||
# Encode patches.
|
||||
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
||||
|
||||
for _ in range(transformer_layers):
|
||||
# Layer normalization 1.
|
||||
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
|
||||
# Create a multi-head attention layer.
|
||||
attention_output = layers.MultiHeadAttention(
|
||||
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
|
||||
)(x1, x1)
|
||||
# Skip connection 1.
|
||||
x2 = layers.Add()([attention_output, encoded_patches])
|
||||
# Layer normalization 2.
|
||||
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
|
||||
# MLP.
|
||||
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
|
||||
# Skip connection 2.
|
||||
encoded_patches = layers.Add()([x3, x2])
|
||||
|
||||
encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], 64])
|
||||
|
||||
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
|
||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||
v1024_2048 = Activation('relu')(v1024_2048)
|
||||
|
||||
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||
o = (concatenate([o, f4],axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o ,f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, inputs],axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
|
||||
model = keras.Model(inputs=inputs, outputs=o)
|
||||
|
||||
return model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue