diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 217ab35..ecf70b4 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -385,31 +385,26 @@ def run(_config, if transformer_cnn_first: model_builder = vit_resnet50_unet - multiple_of_32 = True + multiple = 32 else: model_builder = vit_resnet50_unet_transformer_before_cnn - multiple_of_32 = False + multiple = 1 - assert input_height == (num_patches_y * - transformer_patchsize_y * - (32 if multiple_of_32 else 1)), \ - "transformer_patchsize_y or transformer_num_patches_xy height value error: " \ - "input_height should be equal to " \ - "(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \ - " * 32" if multiple_of_32 else "" - assert input_width == (num_patches_x * - transformer_patchsize_x * - (32 if multiple_of_32 else 1)), \ - "transformer_patchsize_x or transformer_num_patches_xy width value error: " \ - "input_width should be equal to " \ - "(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \ - " * 32" if multiple_of_32 else "" + assert input_height == ( + num_patches_y * transformer_patchsize_y * multiple), ( + "transformer_patchsize_y or transformer_num_patches_xy height value error: " + "input_height should be equal to " + "(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple) + assert input_width == ( + num_patches_x * transformer_patchsize_x * multiple), ( + "transformer_patchsize_x or transformer_num_patches_xy width value error: " + "input_width should be equal to " + "(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple) assert 0 == (transformer_projection_dim % - (transformer_patchsize_y * - transformer_patchsize_x)), \ - "transformer_projection_dim error: " \ - "The remainder when parameter transformer_projection_dim is divided by " \ - "(transformer_patchsize_y*transformer_patchsize_x) should be zero" + (transformer_patchsize_y * transformer_patchsize_x)), ( + "transformer_projection_dim error: " + "The remainder when parameter transformer_projection_dim is divided by " + "(transformer_patchsize_y*transformer_patchsize_x) should be zero") model_builder = create_captured_function(model_builder) model_builder.config = _config