training.models: use bilinear instead of nearest upsampling…

(to benefit from CUDA optimization)
This commit is contained in:
Robert Sachunsky 2026-02-27 12:48:28 +01:00
parent ba954d6314
commit 2d5de8e595

View file

@ -212,7 +212,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta
f4 = BatchNormalization(axis=bn_axis)(f4) f4 = BatchNormalization(axis=bn_axis)(f4)
f4 = Activation('relu')(f4) f4 = Activation('relu')(f4)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
o = concatenate([o, f4], axis=MERGE_AXIS) o = concatenate([o, f4], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(512, (3, 3), padding='valid', o = Conv2D(512, (3, 3), padding='valid',
@ -220,7 +220,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta
o = BatchNormalization(axis=bn_axis)(o) o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o) o = Activation('relu')(o)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
o = concatenate([o, f3], axis=MERGE_AXIS) o = concatenate([o, f3], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(256, (3, 3), padding='valid', o = Conv2D(256, (3, 3), padding='valid',
@ -228,7 +228,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta
o = BatchNormalization(axis=bn_axis)(o) o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o) o = Activation('relu')(o)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
o = concatenate([o, f2], axis=MERGE_AXIS) o = concatenate([o, f2], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(128, (3, 3), padding='valid', o = Conv2D(128, (3, 3), padding='valid',
@ -236,7 +236,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta
o = BatchNormalization(axis=bn_axis)(o) o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o) o = Activation('relu')(o)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
o = concatenate([o, f1], axis=MERGE_AXIS) o = concatenate([o, f1], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(64, (3, 3), padding='valid', o = Conv2D(64, (3, 3), padding='valid',
@ -244,7 +244,7 @@ def unet_decoder(img, f1, f2, f3, f4, f5, n_classes, light=False, task="segmenta
o = BatchNormalization(axis=bn_axis)(o) o = BatchNormalization(axis=bn_axis)(o)
o = Activation('relu')(o) o = Activation('relu')(o)
o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o) o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING, interpolation="bilinear")(o)
o = concatenate([o, img], axis=MERGE_AXIS) o = concatenate([o, img], axis=MERGE_AXIS)
o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o) o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
o = Conv2D(32, (3, 3), padding='valid', o = Conv2D(32, (3, 3), padding='valid',