training.models: fix 9b66867c

This commit is contained in:
Robert Sachunsky 2026-02-27 12:40:56 +01:00
parent 439ca350dd
commit 7c3aeda65e

View file

@ -334,7 +334,7 @@ def vit_resnet50_unet(num_patches,
transformer_mlp_head_units = [128, 64] transformer_mlp_head_units = [128, 64]
inputs = Input(shape=(input_height, input_width, 3)) inputs = Input(shape=(input_height, input_width, 3))
features = resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining) features = list(resnet50(inputs, weight_decay=weight_decay, pretraining=pretraining))
features[-1] = transformer_block(features[-1], features[-1] = transformer_block(features[-1],
num_patches, num_patches,