Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git (synced 2025-06-18 08:10:04 +02:00)
binarization as a separate task of segmentation
Commit 2aa216e388 (parent 41a0e15e79)
2 changed files with 9 additions and 8 deletions
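With this change, the training entry point accepts "binarization" wherever it previously accepted "segmentation". A minimal sketch of the parameters a binarization run would set, written as a Python dict; only the key names visible in the diffed run() signature come from the source, the values, paths, and any remaining configuration are assumptions:

```python
# Hypothetical parameter set for a binarization run. Key names mirror arguments
# of run() shown in the diff below; values and paths are illustrative only.
binarization_params = {
    "task": "binarization",            # the new task value introduced by this commit
    "n_classes": 2,                    # assumed: foreground (ink) vs. background
    "backbone_type": "nontransformer", # or "transformer", per the continue_training branch
    "dir_eval": "data/eval",           # hypothetical directory layout
    "dir_output": "runs/binarization",
    "learning_rate": 1e-4,
}
```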
train.py — 13 changed lines
@@ -96,7 +96,7 @@ def run(_config, n_classes, n_epochs, input_height,
         transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
         pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
 
-    if task == "segmentation" or task == "enhancement":
+    if task == "segmentation" or task == "enhancement" or task == "binarization":
         if data_is_provided:
             dir_train_flowing = os.path.join(dir_output, 'train')
             dir_eval_flowing = os.path.join(dir_output, 'eval')
@@ -194,16 +194,16 @@ def run(_config, n_classes, n_epochs, input_height,
 
     if continue_training:
         if backbone_type=='nontransformer':
-            if is_loss_soft_dice and task == "segmentation":
+            if is_loss_soft_dice and (task == "segmentation" or task == "binarization"):
                 model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
-            if weighted_loss and task == "segmentation":
+            if weighted_loss and (task == "segmentation" or task == "binarization"):
                 model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
             if not is_loss_soft_dice and not weighted_loss:
                 model = load_model(dir_of_start_model , compile=True)
         elif backbone_type=='transformer':
-            if is_loss_soft_dice and task == "segmentation":
+            if is_loss_soft_dice and (task == "segmentation" or task == "binarization"):
                 model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
-            if weighted_loss and task == "segmentation":
+            if weighted_loss and (task == "segmentation" or task == "binarization"):
                 model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
             if not is_loss_soft_dice and not weighted_loss:
                 model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
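The parentheses in the new conditions are not cosmetic: Python's `and` binds tighter than `or`, so without them the soft-dice or weighted branch would be entered for every binarization run, even with the corresponding flag disabled. A quick standalone check:

```python
is_loss_soft_dice = False
task = "binarization"

# Unparenthesized: parsed as (is_loss_soft_dice and task == "segmentation") or (task == "binarization")
print(is_loss_soft_dice and task == "segmentation" or task == "binarization")    # True  -> wrong branch taken

# As written in the diff:
print(is_loss_soft_dice and (task == "segmentation" or task == "binarization"))  # False -> branch skipped
```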
@@ -224,8 +224,9 @@ def run(_config, n_classes, n_epochs, input_height,
 
     #if you want to see the model structure just uncomment model summary.
     #model.summary()
 
 
-    if task == "segmentation":
+    if (task == "segmentation" or task == "binarization"):
         if not is_loss_soft_dice and not weighted_loss:
             model.compile(loss='categorical_crossentropy',
                           optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
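Resuming a binarization run goes through the continue_training branch above, where load_model is called with compile=True; Keras can only rebuild a compiled custom loss if it is registered via custom_objects. A self-contained sketch of that pattern; the dice implementation and checkpoint path are placeholders, not the repository's own code:

```python
import tensorflow as tf

def soft_dice_loss(y_true, y_pred, smooth=1.0):
    # Placeholder dice loss over (batch, H, W, classes) one-hot tensors;
    # the repository defines its own soft_dice_loss.
    intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3))
    union = tf.reduce_sum(y_true, axis=(1, 2, 3)) + tf.reduce_sum(y_pred, axis=(1, 2, 3))
    return 1.0 - tf.reduce_mean((2.0 * intersection + smooth) / (union + smooth))

# Without custom_objects, load_model(..., compile=True) cannot rebuild the custom loss.
model = tf.keras.models.load_model(
    "runs/binarization/model_last",                    # hypothetical checkpoint path
    compile=True,
    custom_objects={"soft_dice_loss": soft_dice_loss},
)
```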
utils.py — 4 changed lines
@@ -309,7 +309,7 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c
                                        interpolation=cv2.INTER_NEAREST)  # Read an image from folder and resize
 
             img[i - c] = train_img  # add to array - img[0], img[1], and so on.
-            if task == "segmentation":
+            if task == "segmentation" or task=="binarization":
                 train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
                 train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width,
                                          n_classes)
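In data_gen, the binarization ground truth now takes the same path as a segmentation mask: it is read as an image, resized, and one-hot encoded into n_classes channels. A toy illustration of that framing with two classes; the shapes and class count are assumptions, not taken from the repository:

```python
import numpy as np

# A 4x4 binarization ground truth: 1 = ink (foreground), 0 = background.
gt = np.array([[0, 0, 1, 1],
               [0, 1, 1, 0],
               [0, 0, 0, 0],
               [1, 1, 0, 0]])

# One-hot encode into (H, W, n_classes), the layout a two-label
# segmentation mask would presumably also get from get_one_hot.
n_classes = 2
one_hot = np.eye(n_classes, dtype=np.float32)[gt]
print(one_hot.shape)  # (4, 4, 2)
```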
@@ -569,7 +569,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow
     indexer = 0
     for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)):
         img_name = im.split('.')[0]
-        if task == "segmentation":
+        if task == "segmentation" or task == "binarization":
             dir_of_label_file = os.path.join(dir_seg, img_name + '.png')
         elif task=="enhancement":
             dir_of_label_file = os.path.join(dir_seg, im)
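provide_patches now also looks up binarization labels the way segmentation labels are found: a PNG in dir_seg named after the source image. A small standalone sketch of that lookup; the directory and file names are hypothetical:

```python
import os

dir_seg = "data/labels_binarization"   # hypothetical label directory
im = "page_0001.tif"                   # hypothetical training image filename

img_name = im.split('.')[0]
dir_of_label_file = os.path.join(dir_seg, img_name + '.png')
print(dir_of_label_file)               # data/labels_binarization/page_0001.png
```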