mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
transformer patch size is dynamic now.
This commit is contained in:
parent
2aa216e388
commit
f1fd74c7eb
3 changed files with 75 additions and 30 deletions
30
train.py
30
train.py
|
@ -70,8 +70,10 @@ def config_params():
|
|||
brightness = None # Brighten image for augmentation.
|
||||
flip_index = None # Flip image for augmentation.
|
||||
continue_training = False # Set to true if you would like to continue training an already trained a model.
|
||||
transformer_patchsize = None # Patch size of vision transformer patches.
|
||||
transformer_patchsize_x = None # Patch size of vision transformer patches.
|
||||
transformer_patchsize_y = None
|
||||
transformer_num_patches_xy = None # Number of patches for vision transformer.
|
||||
transformer_projection_dim = 64 # Transformer projection dimension
|
||||
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
|
||||
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
|
||||
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
||||
|
@ -92,7 +94,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
brightening, binarization, blur_k, scales, degrade_scales,
|
||||
brightness, dir_train, data_is_provided, scaling_bluring,
|
||||
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
|
||||
thetha, scaling_flip, continue_training, transformer_patchsize,
|
||||
thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_patchsize_x, transformer_patchsize_y,
|
||||
transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
|
||||
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
|
||||
|
||||
|
@ -212,15 +214,27 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
if backbone_type=='nontransformer':
|
||||
model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
|
||||
elif backbone_type=='transformer':
|
||||
num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
|
||||
num_patches_x = transformer_num_patches_xy[0]
|
||||
num_patches_y = transformer_num_patches_xy[1]
|
||||
num_patches = num_patches_x * num_patches_y
|
||||
|
||||
if not (num_patches == (input_width / 32) * (input_height / 32)):
|
||||
print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) )
|
||||
##if not (num_patches == (input_width / 32) * (input_height / 32)):
|
||||
##print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)) )
|
||||
##sys.exit(1)
|
||||
#if not (transformer_patchsize == 1):
|
||||
#print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" )
|
||||
#sys.exit(1)
|
||||
if (input_height != (num_patches_y * transformer_patchsize_y * 32) ):
|
||||
print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error . input_height should be equal to ( transformer_num_patches_xy height value * transformer_patchsize_y * 32)")
|
||||
sys.exit(1)
|
||||
if not (transformer_patchsize == 1):
|
||||
print("Error: transformer patchsize error. Parameter transformer_patchsizeshould set to 1" )
|
||||
if (input_width != (num_patches_x * transformer_patchsize_x * 32) ):
|
||||
print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error . input_width should be equal to ( transformer_num_patches_xy width value * transformer_patchsize_x * 32)")
|
||||
sys.exit(1)
|
||||
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
|
||||
if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0:
|
||||
print("Error: transformer_projection_dim error. The remainder when parameter transformer_projection_dim is divided by (transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
||||
sys.exit(1)
|
||||
|
||||
model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
|
||||
|
||||
#if you want to see the model structure just uncomment model summary.
|
||||
#model.summary()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue