mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training.train: simplify transformer cfg checks
This commit is contained in:
parent
53252a59c6
commit
ee4bffd81d
1 changed files with 16 additions and 21 deletions
|
|
@ -385,31 +385,26 @@ def run(_config,
|
|||
|
||||
if transformer_cnn_first:
|
||||
model_builder = vit_resnet50_unet
|
||||
multiple_of_32 = True
|
||||
multiple = 32
|
||||
else:
|
||||
model_builder = vit_resnet50_unet_transformer_before_cnn
|
||||
multiple_of_32 = False
|
||||
multiple = 1
|
||||
|
||||
assert input_height == (num_patches_y *
|
||||
transformer_patchsize_y *
|
||||
(32 if multiple_of_32 else 1)), \
|
||||
"transformer_patchsize_y or transformer_num_patches_xy height value error: " \
|
||||
"input_height should be equal to " \
|
||||
"(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \
|
||||
" * 32" if multiple_of_32 else ""
|
||||
assert input_width == (num_patches_x *
|
||||
transformer_patchsize_x *
|
||||
(32 if multiple_of_32 else 1)), \
|
||||
"transformer_patchsize_x or transformer_num_patches_xy width value error: " \
|
||||
"input_width should be equal to " \
|
||||
"(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \
|
||||
" * 32" if multiple_of_32 else ""
|
||||
assert input_height == (
|
||||
num_patches_y * transformer_patchsize_y * multiple), (
|
||||
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
|
||||
"input_height should be equal to "
|
||||
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
|
||||
assert input_width == (
|
||||
num_patches_x * transformer_patchsize_x * multiple), (
|
||||
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
|
||||
"input_width should be equal to "
|
||||
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
|
||||
assert 0 == (transformer_projection_dim %
|
||||
(transformer_patchsize_y *
|
||||
transformer_patchsize_x)), \
|
||||
"transformer_projection_dim error: " \
|
||||
"The remainder when parameter transformer_projection_dim is divided by " \
|
||||
"(transformer_patchsize_y*transformer_patchsize_x) should be zero"
|
||||
(transformer_patchsize_y * transformer_patchsize_x)), (
|
||||
"transformer_projection_dim error: "
|
||||
"The remainder when parameter transformer_projection_dim is divided by "
|
||||
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
||||
|
||||
model_builder = create_captured_function(model_builder)
|
||||
model_builder.config = _config
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue