Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git (synced 2025-06-09 20:00:05 +02:00)
updating train.py nontransformer backend
This commit is contained in:
parent 815e5a1d35
commit 41a0e15e79
2 changed files with 18 additions and 7 deletions
train.py (12 changes)
@@ -97,8 +97,6 @@ def run(_config, n_classes, n_epochs, input_height,
         pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
 
     if task == "segmentation" or task == "enhancement":
-
-        num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
         if data_is_provided:
             dir_train_flowing = os.path.join(dir_output, 'train')
             dir_eval_flowing = os.path.join(dir_output, 'eval')
@@ -213,7 +211,15 @@ def run(_config, n_classes, n_epochs, input_height,
     index_start = 0
 
     if backbone_type=='nontransformer':
         model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
-    elif backbone_type=='nontransformer':
+    elif backbone_type=='transformer':
+        num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
+
+        if not (num_patches == (input_width / 32) * (input_height / 32)):
+            print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)))
+            sys.exit(1)
+        if not (transformer_patchsize == 1):
+            print("Error: transformer patchsize error. Parameter transformer_patchsize should be set to 1")
+            sys.exit(1)
+        model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
 
     #if you want to see the model structure just uncomment model summary.
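For reference, the two guards added in the transformer branch require transformer_num_patches_xy to multiply out to (input_width / 32) * (input_height / 32) and transformer_patchsize to be exactly 1. Below is a minimal sketch of parameter values that satisfy both checks; the concrete 448x448 input size is an illustrative assumption, not taken from the repository's configuration files.

# Illustrative values only; the parameter names follow the diff above.
input_height = 448
input_width = 448

# The first guard expects num_patches == (input_width / 32) * (input_height / 32),
# i.e. one patch per 32x32 block of the input image.
transformer_num_patches_xy = (input_width // 32, input_height // 32)  # (14, 14)
transformer_patchsize = 1  # the second guard requires exactly 1

num_patches = transformer_num_patches_xy[0] * transformer_num_patches_xy[1]
assert num_patches == (input_width // 32) * (input_height // 32)  # 196 patches

With values like these neither sys.exit(1) is reached and the model is built with vit_resnet50_unet.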