training.models: fix glitch introduced in 3a73ccca

This commit is contained in:
Robert Sachunsky 2026-02-08 01:09:40 +01:00
parent ea285124ce
commit 53252a59c6

View file

@ -443,7 +443,6 @@ def vit_resnet50_unet(num_patches,
# Skip connection 2.
encoded_patches = Add()([x3, x2])
assert isinstance(x, Layer)
encoded_patches = tf.reshape(encoded_patches,
[-1, x.shape[1], x.shape[2],
transformer_projection_dim // (transformer_patchsize_x *