tf.keras version that allows any input resolution

pull/16/head
H.T. Kruitbosch 12 months ago
parent dbb404030e
commit 3098700dc9

@ -1,22 +1,12 @@
from keras.models import * from tensorflow.keras.models import Model
from keras.layers import * from tensorflow.keras.layers import Conv2D, Concatenate, ZeroPadding2D, BatchNormalization, Activation, MaxPooling2D, UpSampling2D, Input, Layer
from keras import layers from tensorflow.keras import layers
from keras.regularizers import l2 from tensorflow.keras.regularizers import l2
import tensorflow as tf
resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
IMAGE_ORDERING ='channels_last'
MERGE_AXIS=-1
def identity_block(input_tensor, kernel_size, filters, stage, block, data_format='channels_last'):
def one_side_pad( x ):
x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
if IMAGE_ORDERING == 'channels_first':
x = Lambda(lambda x : x[: , : , :-1 , :-1 ] )(x)
elif IMAGE_ORDERING == 'channels_last':
x = Lambda(lambda x : x[: , :-1 , :-1 , : ] )(x)
return x
def identity_block(input_tensor, kernel_size, filters, stage, block):
"""The identity block is the block that has no conv layer at shortcut. """The identity block is the block that has no conv layer at shortcut.
# Arguments # Arguments
input_tensor: input tensor input_tensor: input tensor
@ -29,24 +19,21 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
""" """
filters1, filters2, filters3 = filters filters1, filters2, filters3 = filters
if IMAGE_ORDERING == 'channels_last': bn_axis = 3 if data_format == 'channels_last' else 1
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2a')(input_tensor) x = Conv2D(filters1, (1, 1) , data_format=data_format , name=conv_name_base + '2a')(input_tensor)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
x = Activation('relu')(x) x = Activation('relu')(x)
x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , x = Conv2D(filters2, kernel_size , data_format=data_format ,
padding='same', name=conv_name_base + '2b')(x) padding='same', name=conv_name_base + '2b')(x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = Activation('relu')(x) x = Activation('relu')(x)
x = Conv2D(filters3 , (1, 1), data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) x = Conv2D(filters3 , (1, 1), data_format=data_format , name=conv_name_base + '2c')(x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
x = layers.add([x, input_tensor]) x = layers.add([x, input_tensor])
@ -54,7 +41,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
return x return x
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2), data_format='channels_last'):
"""conv_block is the block that has a conv layer at shortcut """conv_block is the block that has a conv layer at shortcut
# Arguments # Arguments
input_tensor: input tensor input_tensor: input tensor
@ -69,28 +56,25 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
""" """
filters1, filters2, filters3 = filters filters1, filters2, filters3 = filters
if IMAGE_ORDERING == 'channels_last': bn_axis = 3 if data_format == 'channels_last' else 1
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch' bn_name_base = 'bn' + str(stage) + block + '_branch'
x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, x = Conv2D(filters1, (1, 1) , data_format=data_format, strides=strides,
name=conv_name_base + '2a')(input_tensor) name=conv_name_base + '2a')(input_tensor)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
x = Activation('relu')(x) x = Activation('relu')(x)
x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , padding='same', x = Conv2D(filters2, kernel_size , data_format=data_format, padding='same',
name=conv_name_base + '2b')(x) name=conv_name_base + '2b')(x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = Activation('relu')(x) x = Activation('relu')(x)
x = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) x = Conv2D(filters3, (1, 1), data_format=data_format, name=conv_name_base + '2c')(x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
shortcut = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, shortcut = Conv2D(filters3, (1, 1), data_format=data_format, strides=strides,
name=conv_name_base + '1')(input_tensor) name=conv_name_base + '1')(input_tensor)
shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
@ -99,219 +83,131 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
return x return x
def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): class PadMultiple(Layer):
assert input_height%32 == 0 def __init__(self, mods, data_format='channels_last'):
assert input_width%32 == 0 super().__init__()
self.mods = mods
self.data_format = data_format
img_input = Input(shape=(input_height,input_width , 3 ))
if IMAGE_ORDERING == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
f1 = x
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
f2 = one_side_pad(x )
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
f3 = x
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
f4 = x
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
f5 = x
if pretraining:
model=Model( img_input , x ).load_weights(resnet50_Weights_path)
v512_2048 = Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 )
v512_2048 = ( BatchNormalization(axis=bn_axis))(v512_2048)
v512_2048 = Activation('relu')(v512_2048)
v512_1024=Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f4 )
v512_1024 = ( BatchNormalization(axis=bn_axis))(v512_1024)
v512_1024 = Activation('relu')(v512_1024)
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v512_2048)
o = ( concatenate([ o ,v512_1024],axis=MERGE_AXIS ) )
o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o)
o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = ( BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) )
o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o)
o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
o = ( BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
o = ( concatenate([o,f2],axis=MERGE_AXIS ) )
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o)
o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o)
o = ( BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
o = ( concatenate([o,f1],axis=MERGE_AXIS ) )
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o)
o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o)
o = ( BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
o = ( concatenate([o,img_input],axis=MERGE_AXIS ) )
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o)
o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o)
o = ( BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o)
def call(self, x):
h, w = self.mods
padding = (
[(0,0), (0, -tf.shape(x)[1] % h), (0, -tf.shape(x)[2] % w), (0,0)] if self.data_format == 'channels_last'
else [(0,0), (0,0), (0, -tf.shape(x)[1] % h), (0, -tf.shape(x)[2] % w)])
return tf.pad(x, padding)
o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) class CutTo(Layer):
o = ( BatchNormalization(axis=bn_axis))(o) def __init__(self, data_format='channels_last'):
o = (Activation('softmax'))(o) super().__init__()
self.data_format = data_format
def call(self, inputs):
model = Model( img_input , o ) h, w = (1, 2) if self.data_format == 'channels_last' else (2,4)
return model h, w = tf.shape(inputs[1])[h], tf.shape(inputs[1])[w]
return inputs[0][:, :h, :w] if self.data_format == 'channels_last' else inputs[0][:, :, :h, :w]
def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
assert input_height%32 == 0
assert input_width%32 == 0
def resnet50_unet(n_classes, input_height=None, input_width=None, weight_decay=1e-6, pretraining=False, last_activation='softmax', skip_last_batchnorm=False, light_version=False, data_format='channels_last'):
""" Returns a U-NET model using the keras functional API. """
img_input = Input(shape=(input_height, input_width, 3 )) img_input = Input(shape=(input_height, input_width, 3 ))
padded_to_multiple = PadMultiple((32,32))(img_input)
if IMAGE_ORDERING == 'channels_last': bn_axis = 3 if data_format == 'channels_last' else 1
bn_axis = 3 merge_axis = 3 if data_format == 'channels_last' else 1
else:
bn_axis = 1
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) x = ZeroPadding2D((3, 3), data_format=data_format)(padded_to_multiple)
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) x = Conv2D(64, (7, 7), data_format=data_format, strides=(2, 2), kernel_regularizer=l2(weight_decay), name='conv1')(x)
f1 = x f1 = x
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x) x = Activation('relu')(x)
x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) x = MaxPooling2D((3, 3) , data_format=data_format , strides=(2, 2))(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), data_format=data_format)
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', data_format=data_format)
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', data_format=data_format)
f2 = one_side_pad(x ) f2 = ZeroPadding2D(((1,0), (1,0)), data_format=data_format)(x)
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', data_format=data_format)
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', data_format=data_format)
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', data_format=data_format)
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', data_format=data_format)
f3 = x f3 = x
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', data_format=data_format)
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', data_format=data_format)
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', data_format=data_format)
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', data_format=data_format)
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', data_format=data_format)
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', data_format=data_format)
f4 = x f4 = x
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', data_format=data_format)
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', data_format=data_format)
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', data_format=data_format)
f5 = x f5 = x
if pretraining: if pretraining:
Model(img_input, x).load_weights(resnet50_Weights_path) Model(img_input, x).load_weights(resnet50_Weights_path)
v1024_2048 = Conv2D( 1024 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 ) if light_version:
v1024_2048 = ( BatchNormalization(axis=bn_axis))(v1024_2048) v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=data_format, kernel_regularizer=l2(weight_decay))(f5)
v1024_2048 = Activation('relu')(v1024_2048) v512_2048 = BatchNormalization(axis=bn_axis)(v512_2048)
v512_2048 = Activation('relu')(v512_2048)
v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=data_format, kernel_regularizer=l2(weight_decay))(f4)
v512_1024 = BatchNormalization(axis=bn_axis)(v512_1024)
v512_1024 = Activation('relu')(v512_1024)
x, c = v512_2048, v512_1024 # continuation and concatenation layers
else:
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=data_format, kernel_regularizer=l2(weight_decay))(f5)
v1024_2048 = BatchNormalization(axis=bn_axis)(v1024_2048)
v1024_2048 = Activation('relu')(v1024_2048)
x, c = v1024_2048, f4 # continuation and concatenation layers
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v1024_2048) o = UpSampling2D((2,2), data_format=data_format)(x)
o = ( concatenate([ o ,f4],axis=MERGE_AXIS ) ) o = Concatenate(axis=merge_axis)([o ,c])
o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) o = ZeroPadding2D( (1,1), data_format=data_format)(o)
o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) o = Conv2D(512, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
o = ( BatchNormalization(axis=bn_axis))(o) o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o) o = Activation('relu')(o)
o = UpSampling2D( (2,2), data_format=data_format)(o)
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) o = Concatenate(axis=merge_axis)([ o ,f3])
o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) ) o = ZeroPadding2D( (1,1), data_format=data_format)(o)
o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) o = Conv2D( 256, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) o = BatchNormalization(axis=bn_axis)(o)
o = ( BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o) o = Activation('relu')(o)
o = UpSampling2D( (2,2), data_format=data_format)(o)
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) o = Concatenate(axis=merge_axis)([o,f2])
o = ( concatenate([o,f2],axis=MERGE_AXIS ) ) o = ZeroPadding2D((1,1) , data_format=data_format)(o)
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o) o = Conv2D( 128 , (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o) o = BatchNormalization(axis=bn_axis)(o)
o = ( BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o) o = Activation('relu')(o)
o = UpSampling2D( (2,2), data_format=data_format)(o)
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) o = Concatenate(axis=merge_axis)([o,f1])
o = ( concatenate([o,f1],axis=MERGE_AXIS ) ) o = ZeroPadding2D((1,1) , data_format=data_format)(o)
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) o = Conv2D( 64 , (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) o = BatchNormalization(axis=bn_axis)(o)
o = ( BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o) o = Activation('relu')(o)
o = UpSampling2D( (2,2), data_format=data_format)(o)
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) o = Concatenate(axis=merge_axis)([o, padded_to_multiple])
o = ( concatenate([o,img_input],axis=MERGE_AXIS ) ) o = ZeroPadding2D((1,1) , data_format=data_format)(o)
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) o = Conv2D(32, (3, 3), padding='valid', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) o = BatchNormalization(axis=bn_axis)(o)
o = ( BatchNormalization(axis=bn_axis))(o)
o = Activation('relu')(o) o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=data_format, kernel_regularizer=l2(weight_decay))(o)
if not skip_last_batchnorm:
o = BatchNormalization(axis=bn_axis)(o)
o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) o = Activation(last_activation)(o)
o = ( BatchNormalization(axis=bn_axis))(o) o = CutTo()([o, img_input])
o = (Activation('softmax'))(o)
model = Model( img_input , o )
return model return Model(img_input , o)

Loading…
Cancel
Save