# sbb_pixelwise_segmentation/models.py


import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import (Conv2D, Concatenate, ZeroPadding2D, BatchNormalization,
                                     Activation, MaxPooling2D, UpSampling2D, Input, Layer)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'


def identity_block(input_tensor, kernel_size, filters, stage, block, data_format='channels_last'):
    """The identity block is the block that has no conv layer at shortcut.

    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of the middle conv layer at the main path
        filters: list of integers, the filters of the 3 conv layers at the main path
        stage: integer, current stage label, used for generating layer names
        block: 'a', 'b'..., current block label, used for generating layer names

    # Returns
        Output tensor for the block, with the same shape as `input_tensor`
        (see the illustrative `_demo_identity_block` below).
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3 if data_format == 'channels_last' else 1
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # Main path: 1x1 -> kxk -> 1x1 bottleneck, each conv followed by BN (+ ReLU).
    x = Conv2D(filters1, (1, 1), data_format=data_format, name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, data_format=data_format,
               padding='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), data_format=data_format, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    # Identity shortcut: add the unmodified input back onto the main path.
    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x
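

def _demo_identity_block():
    """Illustrative sketch only, not part of the original module: the identity
    block preserves its input shape, which is what makes the residual addition
    well-defined. The 56x56x256 input is an assumed example matching stage 2."""
    t = Input(shape=(56, 56, 256))
    y = identity_block(t, 3, [64, 64, 256], stage=2, block='demo')
    assert tuple(y.shape) == (None, 56, 56, 256)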


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2), data_format='channels_last'):
    """conv_block is the block that has a conv layer at shortcut.

    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of the middle conv layer at the main path
        filters: list of integers, the filters of the 3 conv layers at the main path
        stage: integer, current stage label, used for generating layer names
        block: 'a', 'b'..., current block label, used for generating layer names

    # Returns
        Output tensor for the block.

    Note that from stage 3 on, the first conv layer at the main path has
    strides=(2, 2), and the shortcut has strides=(2, 2) as well, so both
    branches are downsampled consistently (see the illustrative
    `_demo_conv_block` below).
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3 if data_format == 'channels_last' else 1
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # Main path: strided 1x1 -> kxk -> 1x1 bottleneck.
    x = Conv2D(filters1, (1, 1), data_format=data_format, strides=strides,
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, data_format=data_format, padding='same',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), data_format=data_format, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    # Projection shortcut: a strided 1x1 conv so the shortcut matches the main
    # path in both spatial size and channel count before the addition.
    shortcut = Conv2D(filters3, (1, 1), data_format=data_format, strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)

    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x
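

def _demo_conv_block():
    """Illustrative sketch only, not part of the original module: with the
    default strides=(2, 2) the conv block halves the spatial dimensions while
    changing the channel count. The input shape is an assumed example."""
    t = Input(shape=(56, 56, 256))
    y = conv_block(t, 3, [128, 128, 512], stage=3, block='demo')
    assert tuple(y.shape) == (None, 28, 28, 512)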


class PadMultiple(Layer):
    """Zero-pads the bottom/right of a batch of images so that height and
    width become multiples of `mods`."""

    def __init__(self, mods, data_format='channels_last'):
        super().__init__()
        self.mods = mods
        self.data_format = data_format

    def call(self, x):
        h, w = self.mods
        # `-dim % h` is the smallest non-negative padding that makes `dim` a
        # multiple of `h`. The spatial dims are indices 1/2 for channels_last
        # and 2/3 for channels_first (the original indexed 1/2 in both cases).
        if self.data_format == 'channels_last':
            padding = [(0, 0), (0, -tf.shape(x)[1] % h), (0, -tf.shape(x)[2] % w), (0, 0)]
        else:
            padding = [(0, 0), (0, 0), (0, -tf.shape(x)[2] % h), (0, -tf.shape(x)[3] % w)]
        return tf.pad(x, padding)
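
    def get_config(self):
        # Serialization hook sketched in under the standard Keras
        # `Layer.get_config` contract; an addition, not original code, but
        # needed if a model containing this custom layer is saved and reloaded.
        config = super().get_config()
        config.update({'mods': self.mods, 'data_format': self.data_format})
        return config


def _demo_pad_multiple():
    """Illustrative sketch only, not part of the original module: a 100x150
    image is padded up to 128x160 by PadMultiple((32, 32)). The sizes are
    assumed examples."""
    x = tf.zeros((1, 100, 150, 3))
    y = PadMultiple((32, 32))(x)
    assert tuple(y.shape) == (1, 128, 160, 3)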


class CutTo(Layer):
    """Crops the first input down to the spatial size of the second input;
    used at the end of the network to undo PadMultiple."""

    def __init__(self, data_format='channels_last'):
        super().__init__()
        self.data_format = data_format

    def call(self, inputs):
        # Spatial dims are indices 1/2 for channels_last and 2/3 for
        # channels_first (the original used (2, 4), which is out of range
        # for a 4-D tensor).
        h, w = (1, 2) if self.data_format == 'channels_last' else (2, 3)
        h, w = tf.shape(inputs[1])[h], tf.shape(inputs[1])[w]
        return inputs[0][:, :h, :w] if self.data_format == 'channels_last' else inputs[0][:, :, :h, :w]
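
    def get_config(self):
        # Same serialization sketch as in PadMultiple; an addition, not
        # original code.
        config = super().get_config()
        config.update({'data_format': self.data_format})
        return config


def _demo_cut_to():
    """Illustrative sketch only, not part of the original module: CutTo crops
    the first tensor to the second tensor's spatial size (here 128x160 back
    down to 100x150)."""
    big = tf.zeros((1, 128, 160, 3))
    ref = tf.zeros((1, 100, 150, 3))
    y = CutTo()([big, ref])
    assert tuple(y.shape) == (1, 100, 150, 3)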


def resnet50_unet(n_classes, input_height=None, input_width=None, weight_decay=1e-6,
                  pretraining=False, last_activation='softmax', skip_last_batchnorm=False,
                  light_version=False, data_format='channels_last'):
    """Returns a U-Net model with a ResNet-50 encoder, built with the Keras
    functional API (see the illustrative `__main__` sketch at the end of the
    file for usage)."""
    # The original hard-coded a channels-last input shape and did not forward
    # data_format to PadMultiple; both are made consistent here.
    input_shape = ((input_height, input_width, 3) if data_format == 'channels_last'
                   else (3, input_height, input_width))
    img_input = Input(shape=input_shape)
    # Pad to a multiple of 32 so the five 2x downsampling stages divide evenly.
    padded_to_multiple = PadMultiple((32, 32), data_format=data_format)(img_input)
    bn_axis = 3 if data_format == 'channels_last' else 1
    merge_axis = 3 if data_format == 'channels_last' else 1
    x = ZeroPadding2D((3, 3), data_format=data_format)(padded_to_multiple)
    x = Conv2D(64, (7, 7), data_format=data_format, strides=(2, 2),
               kernel_regularizer=l2(weight_decay), name='conv1')(x)
    f1 = x
    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), data_format=data_format, strides=(2, 2))(x)

    # ResNet-50 encoder, stages 2-5; f1..f5 are the skip connections.
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), data_format=data_format)
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', data_format=data_format)
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', data_format=data_format)
    # The stride-2 pool loses a pixel (for a 32k input, conv1 gives 16k and
    # the pool gives 8k - 1), so pad the skip back to 8k for the decoder.
    f2 = ZeroPadding2D(((1, 0), (1, 0)), data_format=data_format)(x)

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', data_format=data_format)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', data_format=data_format)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', data_format=data_format)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', data_format=data_format)
    f3 = x

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', data_format=data_format)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', data_format=data_format)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', data_format=data_format)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', data_format=data_format)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', data_format=data_format)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', data_format=data_format)
    f4 = x

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', data_format=data_format)
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', data_format=data_format)
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', data_format=data_format)
    f5 = x

    if pretraining:
        # Initialize the encoder from ImageNet-pretrained ResNet-50 (notop) weights.
        Model(img_input, x).load_weights(resnet50_Weights_path)
    if light_version:
        # Light decoder head: compress both f5 (2048 ch) and f4 (1024 ch) to
        # 512 channels before the first upsampling/concatenation.
        v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=data_format,
                           kernel_regularizer=l2(weight_decay))(f5)
        v512_2048 = BatchNormalization(axis=bn_axis)(v512_2048)
        v512_2048 = Activation('relu')(v512_2048)

        v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=data_format,
                           kernel_regularizer=l2(weight_decay))(f4)
        v512_1024 = BatchNormalization(axis=bn_axis)(v512_1024)
        v512_1024 = Activation('relu')(v512_1024)
        x, c = v512_2048, v512_1024  # continuation and concatenation layers
    else:
        # Full decoder head: compress f5 to 1024 channels and concatenate
        # with f4 unchanged.
        v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=data_format,
                            kernel_regularizer=l2(weight_decay))(f5)
        v1024_2048 = BatchNormalization(axis=bn_axis)(v1024_2048)
        v1024_2048 = Activation('relu')(v1024_2048)
        x, c = v1024_2048, f4  # continuation and concatenation layers
    # Decoder: five 2x upsampling steps, each concatenated with the matching
    # encoder skip and followed by a 3x3 conv (512 -> 256 -> 128 -> 64 -> 32).
    o = UpSampling2D((2, 2), data_format=data_format)(x)
    o = Concatenate(axis=merge_axis)([o, c])
    o = ZeroPadding2D((1, 1), data_format=data_format)(o)
    o = Conv2D(512, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=data_format)(o)
    o = Concatenate(axis=merge_axis)([o, f3])
    o = ZeroPadding2D((1, 1), data_format=data_format)(o)
    o = Conv2D(256, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=data_format)(o)
    o = Concatenate(axis=merge_axis)([o, f2])
    o = ZeroPadding2D((1, 1), data_format=data_format)(o)
    o = Conv2D(128, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=data_format)(o)
    o = Concatenate(axis=merge_axis)([o, f1])
    o = ZeroPadding2D((1, 1), data_format=data_format)(o)
    o = Conv2D(64, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=data_format)(o)
    o = Concatenate(axis=merge_axis)([o, padded_to_multiple])
    o = ZeroPadding2D((1, 1), data_format=data_format)(o)
    o = Conv2D(32, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    # Per-pixel classification head, then crop the PadMultiple padding away so
    # the output matches the input's spatial size.
    o = Conv2D(n_classes, (1, 1), padding='same', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
    if not skip_last_batchnorm:
        o = BatchNormalization(axis=bn_axis)(o)
    o = Activation(last_activation)(o)
    o = CutTo(data_format=data_format)([o, img_input])
    return Model(img_input, o)
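

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. The class count
    # and input size below are illustrative assumptions; pretraining is left
    # off so the weights file does not need to be present.
    import numpy as np

    model = resnet50_unet(n_classes=4, pretraining=False)
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    # Height/width need not be multiples of 32: PadMultiple pads internally
    # and CutTo crops the prediction back to the input size.
    dummy = np.zeros((1, 100, 150, 3), dtype='float32')
    pred = model.predict(dummy)
    assert pred.shape == (1, 100, 150, 4)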