diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index 406e937..a03f028 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -198,67 +198,82 @@ def resnet50(inputs, weight_decay=1e-6, pretraining=False): return f1, f2, f3, f4, f5 +def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmentation", weight_decay=1e-6): + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + o = Conv2D(512 if light else 1024, (1, 1), padding='same', + data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) + + if light: + f4 = Conv2D(512, (1, 1), padding='same', + data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4) + f4 = BatchNormalization(axis=bn_axis)(f4) + f4 = Activation('relu')(f4) + + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = concatenate([o, f4], axis=MERGE_AXIS) + o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) + o = Conv2D(512, (3, 3), padding='valid', + data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) + + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = concatenate([o, f3], axis=MERGE_AXIS) + o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) + o = Conv2D(256, (3, 3), padding='valid', + data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) + + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = concatenate([o, f2], axis=MERGE_AXIS) + o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) + o = Conv2D(128, (3, 3), padding='valid', + data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) + + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = concatenate([o, f1], axis=MERGE_AXIS) + o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) + o = Conv2D(64, (3, 3), padding='valid', + data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) + + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = concatenate([o, img], axis=MERGE_AXIS) + o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) + o = Conv2D(32, (3, 3), padding='valid', + data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('relu')(o) + + o = Conv2D(n_classes, (1, 1), padding='same', + data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) + if task == "segmentation": + o = BatchNormalization(axis=bn_axis)(o) + o = Activation('softmax')(o) + else: + o = Activation('sigmoid')(o) + + return Model(img, o) + def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): assert input_height % 32 == 0 assert input_width % 32 == 0 img_input = Input(shape=(input_height, input_width, 3)) - f1, f2, f3, f4, f5 = resnet50(img_input, weight_decay, pretraining) - - v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5) - v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048) - v512_2048 = Activation('relu')(v512_2048) - - v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4) - v512_1024 = (BatchNormalization(axis=bn_axis))(v512_1024) - v512_1024 = Activation('relu')(v512_1024) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v512_2048) - o = (concatenate([o, v512_1024], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f3], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f2], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f1], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, img_input], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) - if task == "segmentation": - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) - else: - o = (Activation('sigmoid'))(o) - - model = Model(img_input, o) - return model + features = resnet50(img_input, weight_decay=weight_decay, pretraining=pretraining) + return unet_decoder(img_input, *features, n_classes, light=True, task=task, weight_decay=weight_decay) def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False): assert input_height % 32 == 0 @@ -266,59 +281,9 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati img_input = Input(shape=(input_height, input_width, 3)) - f1, f2, f3, f4, f5 = resnet50(img_input, weight_decay, pretraining) - - v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))( - f5) - v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) - v1024_2048 = Activation('relu')(v1024_2048) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048) - o = (concatenate([o, f4], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f3], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f2], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f1], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, img_input], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o) - if task == "segmentation": - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) - else: - o = (Activation('sigmoid'))(o) - - model = Model(img_input, o) - - return model + features = resnet50(img_input, weight_decay=weight_decay, pretraining=pretraining) + return unet_decoder(img_input, *features, n_classes, light=False, task=task, weight_decay=weight_decay) def vit_resnet50_unet(num_patches, n_classes, @@ -337,9 +302,9 @@ def vit_resnet50_unet(num_patches, transformer_mlp_head_units = [128, 64] inputs = Input(shape=(input_height, input_width, 3)) - f1, f2, f3, f4, f5 = resnet50(inputs, weight_decay, pretraining) + features = resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining) - patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(x) + patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(features[-1]) # Encode patches. encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) @@ -360,59 +325,16 @@ def vit_resnet50_unet(num_patches, encoded_patches = Add()([x3, x2]) encoded_patches = tf.reshape(encoded_patches, - [-1, x.shape[1], x.shape[2], + [-1, + features[-1].shape[1], + features[-1].shape[2], transformer_projection_dim // (transformer_patchsize_x * transformer_patchsize_y)]) + features[-1] = encoded_patches - v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches) - v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) - v1024_2048 = Activation('relu')(v1024_2048) - - o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) - o = (concatenate([o, f4],axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o ,f3], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f2], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f1], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, inputs],axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) - if task == "segmentation": - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) - else: - o = (Activation('sigmoid'))(o) + o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay) - model = Model(inputs=inputs, outputs=o) - - return model + return Model(inputs, o) def vit_resnet50_unet_transformer_before_cnn(num_patches, n_classes, @@ -431,11 +353,6 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches, transformer_mlp_head_units = [128, 64] inputs = Input(shape=(input_height, input_width, 3)) - if IMAGE_ORDERING == 'channels_last': - bn_axis = 3 - else: - bn_axis = 1 - patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs) # Encode patches. encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches) @@ -463,59 +380,15 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches, transformer_projection_dim // (transformer_patchsize_x * transformer_patchsize_y)]) - encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches) + encoded_patches = Conv2D(3, (1, 1), padding='same', + data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), + name='convinput')(encoded_patches) - f1, f2, f3, f4, f5 = resnet50(encoded_patches, weight_decay, pretraining) + features = resnet50(encoded_patches, weight_decay=weight_decay, pretraining=pretraining) - v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(f5) - v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048) - v1024_2048 = Activation('relu')(v1024_2048) + o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay) - o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048) - o = (concatenate([o, f4],axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o ,f3], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f2], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, f1], axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) - o = (concatenate([o, inputs],axis=MERGE_AXIS)) - o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) - o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) - o = (BatchNormalization(axis=bn_axis))(o) - o = Activation('relu')(o) - - o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o) - if task == "segmentation": - o = (BatchNormalization(axis=bn_axis))(o) - o = (Activation('softmax'))(o) - else: - o = (Activation('sigmoid'))(o) - - model = Model(inputs=inputs, outputs=o) - - return model + return Model(inputs, o) def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): include_top=True