Mirror of https://github.com/qurator-spk/eynollah.git (synced 2025-10-07 06:59:58 +02:00)
remove redundant parentheses

commit 91d2a74ac9, parent f2f93e0251
10 changed files with 29 additions and 29 deletions
@@ -269,10 +269,10 @@ def run(_config, n_classes, n_epochs, input_height,
     num_patches = num_patches_x * num_patches_y

     if transformer_cnn_first:
-        if (input_height != (num_patches_y * transformer_patchsize_y * 32) ):
+        if input_height != (num_patches_y * transformer_patchsize_y * 32):
             print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)")
             sys.exit(1)
-        if (input_width != (num_patches_x * transformer_patchsize_x * 32) ):
+        if input_width != (num_patches_x * transformer_patchsize_x * 32):
             print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)")
             sys.exit(1)
         if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0:
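The checks in this hunk pin down the CNN-first geometry: the input must equal the patch grid times the patch size times 32, presumably because the ResNet50 encoder in vit_resnet50_unet downsamples by a factor of 32. A minimal sketch of the same check as a standalone helper; the function name, message text, and the downsample default are illustrative assumptions, not eynollah's API:

```python
# Sketch of the CNN-first shape check above; names and messages are
# illustrative, and the downsample factor of 32 is an assumption
# (the usual ResNet50 output stride).
import sys

def check_cnn_first_geometry(input_height, input_width,
                             num_patches_y, num_patches_x,
                             patchsize_y, patchsize_x,
                             downsample=32):
    expected_h = num_patches_y * patchsize_y * downsample
    expected_w = num_patches_x * patchsize_x * downsample
    if input_height != expected_h:
        sys.exit(f"input_height should be {expected_h}, got {input_height}")
    if input_width != expected_w:
        sys.exit(f"input_width should be {expected_w}, got {input_width}")

# e.g. a 7x7 patch grid of 2x2 patches requires a 448x448 input
check_cnn_first_geometry(448, 448, 7, 7, 2, 2)
```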
@@ -282,10 +282,10 @@ def run(_config, n_classes, n_epochs, input_height,

         model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
     else:
-        if (input_height != (num_patches_y * transformer_patchsize_y) ):
+        if input_height != (num_patches_y * transformer_patchsize_y):
             print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y)")
             sys.exit(1)
-        if (input_width != (num_patches_x * transformer_patchsize_x) ):
+        if input_width != (num_patches_x * transformer_patchsize_x):
             print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x)")
             sys.exit(1)
         if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0:
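In the else branch (transformer first), the factor of 32 drops out: patches are cut directly from the input, so the input size is just the patch grid times the patch size. Both hunks also keep, as context, the requirement that transformer_projection_dim divide evenly by the patch area. A worked example of values satisfying both constraints, with illustrative numbers not taken from an eynollah config:

```python
# Illustrative values only, not from an eynollah config.
patchsize_y = patchsize_x = 2
num_patches_y = num_patches_x = 224
projection_dim = 64

# transformer-first branch: no x32 factor
input_height = num_patches_y * patchsize_y   # 448
input_width = num_patches_x * patchsize_x    # 448

# divisibility check kept as context in both hunks
assert projection_dim % (patchsize_y * patchsize_x) == 0
```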
@@ -297,7 +297,7 @@ def run(_config, n_classes, n_epochs, input_height,
     model.summary()


-    if (task == "segmentation" or task == "binarization"):
+    if task == "segmentation" or task == "binarization":
         if not is_loss_soft_dice and not weighted_loss:
             model.compile(loss='categorical_crossentropy',
                           optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
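The compile call in the final hunk is standard Keras. A self-contained sketch of the same pattern, with a toy model and an assumed learning rate standing in for eynollah's vit_resnet50_unet and its configured rate:

```python
# Toy stand-in for the compile pattern above; the model architecture
# and the learning rate (1e-4) are assumptions for illustration only.
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(448, 448, 3)),
    tf.keras.layers.Conv2D(4, 1, activation='softmax'),  # per-pixel class scores
])
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(learning_rate=1e-4),
              metrics=['accuracy'])
model.summary()
```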