diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index 13a35a1..a95ba7e 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -212,7 +212,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta f4 = BatchNormalization(axis=bn_axis)(f4) f4 = Activation('relu')(f4) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, f4], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(512, (3, 3), padding='valid', @@ -220,7 +220,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta o = BatchNormalization(axis=bn_axis)(o) o = Activation('relu')(o) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, f3], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(256, (3, 3), padding='valid', @@ -228,7 +228,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta o = BatchNormalization(axis=bn_axis)(o) o = Activation('relu')(o) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, f2], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(128, (3, 3), padding='valid', @@ -236,7 +236,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta o = BatchNormalization(axis=bn_axis)(o) o = Activation('relu')(o) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, f1], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(64, (3, 3), padding='valid', @@ -244,7 +244,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta o = BatchNormalization(axis=bn_axis)(o) o = Activation('relu')(o) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, img], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(32, (3, 3), padding='valid',