training.train: simplify transformer cfg checks

This commit is contained in:
Robert Sachunsky 2026-02-08 01:10:13 +01:00
parent 53252a59c6
commit ee4bffd81d

View file

@ -385,31 +385,26 @@ def run(_config,
if transformer_cnn_first:
model_builder = vit_resnet50_unet
multiple_of_32 = True
multiple = 32
else:
model_builder = vit_resnet50_unet_transformer_before_cnn
multiple_of_32 = False
multiple = 1
assert input_height == (num_patches_y *
transformer_patchsize_y *
(32 if multiple_of_32 else 1)), \
"transformer_patchsize_y or transformer_num_patches_xy height value error: " \
"input_height should be equal to " \
"(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \
" * 32" if multiple_of_32 else ""
assert input_width == (num_patches_x *
transformer_patchsize_x *
(32 if multiple_of_32 else 1)), \
"transformer_patchsize_x or transformer_num_patches_xy width value error: " \
"input_width should be equal to " \
"(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \
" * 32" if multiple_of_32 else ""
assert input_height == (
num_patches_y * transformer_patchsize_y * multiple), (
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
"input_height should be equal to "
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
assert input_width == (
num_patches_x * transformer_patchsize_x * multiple), (
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
"input_width should be equal to "
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
assert 0 == (transformer_projection_dim %
(transformer_patchsize_y *
transformer_patchsize_x)), \
"transformer_projection_dim error: " \
"The remainder when parameter transformer_projection_dim is divided by " \
"(transformer_patchsize_y*transformer_patchsize_x) should be zero"
(transformer_patchsize_y * transformer_patchsize_x)), (
"transformer_projection_dim error: "
"The remainder when parameter transformer_projection_dim is divided by "
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
model_builder = create_captured_function(model_builder)
model_builder.config = _config