training.train: simplify transformer cfg checks

This commit is contained in:
Robert Sachunsky 2026-02-08 01:10:13 +01:00
parent 53252a59c6
commit ee4bffd81d

View file

@ -385,31 +385,26 @@ def run(_config,
if transformer_cnn_first: if transformer_cnn_first:
model_builder = vit_resnet50_unet model_builder = vit_resnet50_unet
multiple_of_32 = True multiple = 32
else: else:
model_builder = vit_resnet50_unet_transformer_before_cnn model_builder = vit_resnet50_unet_transformer_before_cnn
multiple_of_32 = False multiple = 1
assert input_height == (num_patches_y * assert input_height == (
transformer_patchsize_y * num_patches_y * transformer_patchsize_y * multiple), (
(32 if multiple_of_32 else 1)), \ "transformer_patchsize_y or transformer_num_patches_xy height value error: "
"transformer_patchsize_y or transformer_num_patches_xy height value error: " \ "input_height should be equal to "
"input_height should be equal to " \ "(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
"(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \ assert input_width == (
" * 32" if multiple_of_32 else "" num_patches_x * transformer_patchsize_x * multiple), (
assert input_width == (num_patches_x * "transformer_patchsize_x or transformer_num_patches_xy width value error: "
transformer_patchsize_x * "input_width should be equal to "
(32 if multiple_of_32 else 1)), \ "(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
"transformer_patchsize_x or transformer_num_patches_xy width value error: " \
"input_width should be equal to " \
"(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \
" * 32" if multiple_of_32 else ""
assert 0 == (transformer_projection_dim % assert 0 == (transformer_projection_dim %
(transformer_patchsize_y * (transformer_patchsize_y * transformer_patchsize_x)), (
transformer_patchsize_x)), \ "transformer_projection_dim error: "
"transformer_projection_dim error: " \ "The remainder when parameter transformer_projection_dim is divided by "
"The remainder when parameter transformer_projection_dim is divided by " \ "(transformer_patchsize_y*transformer_patchsize_x) should be zero")
"(transformer_patchsize_y*transformer_patchsize_x) should be zero"
model_builder = create_captured_function(model_builder) model_builder = create_captured_function(model_builder)
model_builder.config = _config model_builder.config = _config