From 2d5de8e5957d5ab6540cad7bd350f6b99ca49cc5 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Fri, 27 Feb 2026 12:48:28 +0100 Subject: [PATCH] =?UTF-8?q?training.models:=20use=20bilinear=20instead=20o?= =?UTF-8?q?f=20nearest=20upsampling=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (to benefit from CUDA optimization) --- src/eynollah/training/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index 13a35a1..a95ba7e 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -212,7 +212,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta f4 = BatchNormalization(axis=bn_axis)(f4) f4 = Activation('relu')(f4) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, f4], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(512, (3, 3), padding='valid', @@ -220,7 +220,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta o = BatchNormalization(axis=bn_axis)(o) o = Activation('relu')(o) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, f3], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(256, (3, 3), padding='valid', @@ -228,7 +228,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta o = BatchNormalization(axis=bn_axis)(o) o = Activation('relu')(o) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, f2], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(128, (3, 3), padding='valid', @@ -236,7 +236,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta o = BatchNormalization(axis=bn_axis)(o) o = Activation('relu')(o) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, f1], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(64, (3, 3), padding='valid', @@ -244,7 +244,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta o = BatchNormalization(axis=bn_axis)(o) o = Activation('relu')(o) - o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) + o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o) o = concatenate([o, img], axis=MERGE_AXIS) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = Conv2D(32, (3, 3), padding='valid',