mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training.train: simplify transformer cfg checks
This commit is contained in:
parent
53252a59c6
commit
ee4bffd81d
1 changed files with 16 additions and 21 deletions
|
|
@ -385,31 +385,26 @@ def run(_config,
|
||||||
|
|
||||||
if transformer_cnn_first:
|
if transformer_cnn_first:
|
||||||
model_builder = vit_resnet50_unet
|
model_builder = vit_resnet50_unet
|
||||||
multiple_of_32 = True
|
multiple = 32
|
||||||
else:
|
else:
|
||||||
model_builder = vit_resnet50_unet_transformer_before_cnn
|
model_builder = vit_resnet50_unet_transformer_before_cnn
|
||||||
multiple_of_32 = False
|
multiple = 1
|
||||||
|
|
||||||
assert input_height == (num_patches_y *
|
assert input_height == (
|
||||||
transformer_patchsize_y *
|
num_patches_y * transformer_patchsize_y * multiple), (
|
||||||
(32 if multiple_of_32 else 1)), \
|
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
|
||||||
"transformer_patchsize_y or transformer_num_patches_xy height value error: " \
|
"input_height should be equal to "
|
||||||
"input_height should be equal to " \
|
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
|
||||||
"(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \
|
assert input_width == (
|
||||||
" * 32" if multiple_of_32 else ""
|
num_patches_x * transformer_patchsize_x * multiple), (
|
||||||
assert input_width == (num_patches_x *
|
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
|
||||||
transformer_patchsize_x *
|
"input_width should be equal to "
|
||||||
(32 if multiple_of_32 else 1)), \
|
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
|
||||||
"transformer_patchsize_x or transformer_num_patches_xy width value error: " \
|
|
||||||
"input_width should be equal to " \
|
|
||||||
"(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \
|
|
||||||
" * 32" if multiple_of_32 else ""
|
|
||||||
assert 0 == (transformer_projection_dim %
|
assert 0 == (transformer_projection_dim %
|
||||||
(transformer_patchsize_y *
|
(transformer_patchsize_y * transformer_patchsize_x)), (
|
||||||
transformer_patchsize_x)), \
|
"transformer_projection_dim error: "
|
||||||
"transformer_projection_dim error: " \
|
"The remainder when parameter transformer_projection_dim is divided by "
|
||||||
"The remainder when parameter transformer_projection_dim is divided by " \
|
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
||||||
"(transformer_patchsize_y*transformer_patchsize_x) should be zero"
|
|
||||||
|
|
||||||
model_builder = create_captured_function(model_builder)
|
model_builder = create_captured_function(model_builder)
|
||||||
model_builder.config = _config
|
model_builder.config = _config
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue