mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-08 19:30:07 +02:00
adding enhancement training
This commit is contained in:
parent
dbb84507ed
commit
38db3e9289
5 changed files with 119 additions and 68 deletions
27
models.py
27
models.py
|
@ -168,7 +168,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
|
|||
return x
|
||||
|
||||
|
||||
def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
|
||||
def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segmentation", weight_decay=1e-6, pretraining=False):
|
||||
assert input_height % 32 == 0
|
||||
assert input_width % 32 == 0
|
||||
|
||||
|
@ -259,14 +259,17 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_dec
|
|||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
if task == "segmentation":
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
else:
|
||||
o = (Activation('sigmoid'))(o)
|
||||
|
||||
model = Model(img_input, o)
|
||||
return model
|
||||
|
||||
|
||||
def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
|
||||
def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
||||
assert input_height % 32 == 0
|
||||
assert input_width % 32 == 0
|
||||
|
||||
|
@ -354,15 +357,18 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-
|
|||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
if task == "segmentation":
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
else:
|
||||
o = (Activation('sigmoid'))(o)
|
||||
|
||||
model = Model(img_input, o)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||
def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
||||
inputs = layers.Input(shape=(input_height, input_width, 3))
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
bn_axis=3
|
||||
|
@ -465,8 +471,11 @@ def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_
|
|||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
if task == "segmentation":
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
else:
|
||||
o = (Activation('sigmoid'))(o)
|
||||
|
||||
model = Model(inputs=inputs, outputs=o)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue