training.models: fix daa084c3

This commit is contained in:
Robert Sachunsky 2026-02-27 12:47:59 +01:00
parent 7c3aeda65e
commit ba954d6314

View file

@ -345,9 +345,7 @@ def vit_resnet50_unet(num_patches,
transformer_num_heads, transformer_num_heads,
transformer_projection_dim) transformer_projection_dim)
o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay) return unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
return Model(inputs, o)
def vit_resnet50_unet_transformer_before_cnn(num_patches, def vit_resnet50_unet_transformer_before_cnn(num_patches,
n_classes, n_classes,
@ -380,9 +378,7 @@ def vit_resnet50_unet_transformer_before_cnn(num_patches,
features = resnet50(encoded_patches, weight_decay=weight_decay, pretraining=pretraining) features = resnet50(encoded_patches, weight_decay=weight_decay, pretraining=pretraining)
o = unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay) return unet_decoder(inputs, *features, n_classes, task=task, weight_decay=weight_decay)
return Model(inputs, o)
def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
include_top=True include_top=True