adding enhancement training

This commit is contained in:
vahidrezanezhad 2024-05-06 18:31:48 +02:00
parent dbb84507ed
commit 38db3e9289
5 changed files with 119 additions and 68 deletions

View file

@ -168,7 +168,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
return x
def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segmentation", weight_decay=1e-6, pretraining=False):
assert input_height % 32 == 0
assert input_width % 32 == 0
@ -259,14 +259,17 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_dec
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
if task == "segmentation":
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
else:
o = (Activation('sigmoid'))(o)
model = Model(img_input, o)
return model
def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
assert input_height % 32 == 0
assert input_width % 32 == 0
@ -354,15 +357,18 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
if task == "segmentation":
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
else:
o = (Activation('sigmoid'))(o)
model = Model(img_input, o)
return model
def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
inputs = layers.Input(shape=(input_height, input_width, 3))
IMAGE_ORDERING = 'channels_last'
bn_axis=3
@ -465,8 +471,11 @@ def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_
o = Activation('relu')(o)
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
if task == "segmentation":
o = (BatchNormalization(axis=bn_axis))(o)
o = (Activation('softmax'))(o)
else:
o = (Activation('sigmoid'))(o)
model = Model(inputs=inputs, outputs=o)