mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-08 19:30:07 +02:00
code formatting with black; typos
This commit is contained in:
parent
5f84938839
commit
02b1436f39
8 changed files with 741 additions and 768 deletions
237
models.py
237
models.py
|
@ -3,19 +3,20 @@ from tensorflow.keras.layers import *
|
|||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.regularizers import l2
|
||||
|
||||
resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||
IMAGE_ORDERING ='channels_last'
|
||||
MERGE_AXIS=-1
|
||||
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
MERGE_AXIS = -1
|
||||
|
||||
|
||||
def one_side_pad( x ):
|
||||
def one_side_pad(x):
|
||||
x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
|
||||
if IMAGE_ORDERING == 'channels_first':
|
||||
x = Lambda(lambda x : x[: , : , :-1 , :-1 ] )(x)
|
||||
x = Lambda(lambda x: x[:, :, :-1, :-1])(x)
|
||||
elif IMAGE_ORDERING == 'channels_last':
|
||||
x = Lambda(lambda x : x[: , :-1 , :-1 , : ] )(x)
|
||||
x = Lambda(lambda x: x[:, :-1, :-1, :])(x)
|
||||
return x
|
||||
|
||||
|
||||
def identity_block(input_tensor, kernel_size, filters, stage, block):
|
||||
"""The identity block is the block that has no conv layer at shortcut.
|
||||
# Arguments
|
||||
|
@ -28,7 +29,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
|
|||
Output tensor for the block.
|
||||
"""
|
||||
filters1, filters2, filters3 = filters
|
||||
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
|
@ -37,16 +38,16 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
|
|||
conv_name_base = 'res' + str(stage) + block + '_branch'
|
||||
bn_name_base = 'bn' + str(stage) + block + '_branch'
|
||||
|
||||
x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2a')(input_tensor)
|
||||
x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2a')(input_tensor)
|
||||
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING ,
|
||||
x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING,
|
||||
padding='same', name=conv_name_base + '2b')(x)
|
||||
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
x = Conv2D(filters3 , (1, 1), data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x)
|
||||
x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
|
||||
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
|
||||
|
||||
x = layers.add([x, input_tensor])
|
||||
|
@ -68,7 +69,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
|
|||
And the shortcut should have strides=(2,2) as well
|
||||
"""
|
||||
filters1, filters2, filters3 = filters
|
||||
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
|
@ -77,20 +78,20 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
|
|||
conv_name_base = 'res' + str(stage) + block + '_branch'
|
||||
bn_name_base = 'bn' + str(stage) + block + '_branch'
|
||||
|
||||
x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , strides=strides,
|
||||
x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
|
||||
name=conv_name_base + '2a')(input_tensor)
|
||||
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , padding='same',
|
||||
x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same',
|
||||
name=conv_name_base + '2b')(x)
|
||||
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
x = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x)
|
||||
x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
|
||||
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
|
||||
|
||||
shortcut = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , strides=strides,
|
||||
shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
|
||||
name=conv_name_base + '1')(input_tensor)
|
||||
shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
|
||||
|
||||
|
@ -99,12 +100,11 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
|
|||
return x
|
||||
|
||||
|
||||
def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||
assert input_height%32 == 0
|
||||
assert input_width%32 == 0
|
||||
def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
|
||||
assert input_height % 32 == 0
|
||||
assert input_width % 32 == 0
|
||||
|
||||
|
||||
img_input = Input(shape=(input_height,input_width , 3 ))
|
||||
img_input = Input(shape=(input_height, input_width, 3))
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
|
@ -112,25 +112,24 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
bn_axis = 1
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
|
||||
name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x)
|
||||
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x )
|
||||
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
|
@ -138,85 +137,72 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model=Model( img_input , x ).load_weights(resnet50_Weights_path)
|
||||
model = Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
|
||||
v512_2048 = Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 )
|
||||
v512_2048 = ( BatchNormalization(axis=bn_axis))(v512_2048)
|
||||
v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
|
||||
v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
|
||||
v512_2048 = Activation('relu')(v512_2048)
|
||||
|
||||
|
||||
|
||||
v512_1024=Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f4 )
|
||||
v512_1024 = ( BatchNormalization(axis=bn_axis))(v512_1024)
|
||||
v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4)
|
||||
v512_1024 = (BatchNormalization(axis=bn_axis))(v512_1024)
|
||||
v512_1024 = Activation('relu')(v512_1024)
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v512_2048)
|
||||
o = ( concatenate([ o ,v512_1024],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v512_2048)
|
||||
o = (concatenate([o, v512_1024], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( concatenate([o,f2],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o)
|
||||
o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( concatenate([o,f1],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o)
|
||||
o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( concatenate([o,img_input],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o)
|
||||
o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, img_input], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
|
||||
o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o )
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
|
||||
|
||||
model = Model( img_input , o )
|
||||
model = Model(img_input, o)
|
||||
return model
|
||||
|
||||
def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||
assert input_height%32 == 0
|
||||
assert input_width%32 == 0
|
||||
|
||||
|
||||
img_input = Input(shape=(input_height,input_width , 3 ))
|
||||
def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
|
||||
assert input_height % 32 == 0
|
||||
assert input_width % 32 == 0
|
||||
|
||||
img_input = Input(shape=(input_height, input_width, 3))
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
|
@ -224,25 +210,24 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p
|
|||
bn_axis = 1
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
|
||||
name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x)
|
||||
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x )
|
||||
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
|
@ -250,68 +235,60 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p
|
|||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
Model( img_input , x ).load_weights(resnet50_Weights_path)
|
||||
Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
v1024_2048 = Conv2D( 1024 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 )
|
||||
v1024_2048 = ( BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
|
||||
f5)
|
||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||
v1024_2048 = Activation('relu')(v1024_2048)
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||
o = ( concatenate([ o ,f4],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||
o = (concatenate([o, f4], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( concatenate([o,f2],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o)
|
||||
o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( concatenate([o,f1],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o)
|
||||
o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
|
||||
o = ( concatenate([o,img_input],axis=MERGE_AXIS ) )
|
||||
o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o)
|
||||
o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o)
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, img_input], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o )
|
||||
o = ( BatchNormalization(axis=bn_axis))(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
|
||||
model = Model( img_input , o )
|
||||
|
||||
|
||||
|
||||
model = Model(img_input, o)
|
||||
|
||||
return model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue