diff --git a/train.py b/train.py index 84c9d3b..9e06a66 100644 --- a/train.py +++ b/train.py @@ -96,7 +96,7 @@ def run(_config, n_classes, n_epochs, input_height, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name): - if task == "segmentation" or task == "enhancement": + if task == "segmentation" or task == "enhancement" or task == "binarization": if data_is_provided: dir_train_flowing = os.path.join(dir_output, 'train') dir_eval_flowing = os.path.join(dir_output, 'eval') @@ -194,16 +194,16 @@ def run(_config, n_classes, n_epochs, input_height, if continue_training: if backbone_type=='nontransformer': - if is_loss_soft_dice and task == "segmentation": + if is_loss_soft_dice and (task == "segmentation" or task == "binarization"): model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) - if weighted_loss and task == "segmentation": + if weighted_loss and (task == "segmentation" or task == "binarization"): model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: model = load_model(dir_of_start_model , compile=True) elif backbone_type=='transformer': - if is_loss_soft_dice and task == "segmentation": + if is_loss_soft_dice and (task == "segmentation" or task == "binarization"): model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss}) - if weighted_loss and task == "segmentation": + if weighted_loss and (task == "segmentation" or task == "binarization"): model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) if not is_loss_soft_dice and not weighted_loss: model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) @@ -224,8 +224,9 @@ def run(_config, n_classes, n_epochs, input_height, #if you want to see the model structure just uncomment model summary. #model.summary() + - if task == "segmentation": + if (task == "segmentation" or task == "binarization"): if not is_loss_soft_dice and not weighted_loss: model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=learning_rate), metrics=['accuracy']) diff --git a/utils.py b/utils.py index a2e8a9c..605d8d1 100644 --- a/utils.py +++ b/utils.py @@ -309,7 +309,7 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize img[i - c] = train_img # add to array - img[0], img[1], and so on. - if task == "segmentation": + if task == "segmentation" or task=="binarization": train_mask = cv2.imread(mask_folder + '/' + filename + '.png') train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width, n_classes) @@ -569,7 +569,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow indexer = 0 for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)): img_name = im.split('.')[0] - if task == "segmentation": + if task == "segmentation" or task == "binarization": dir_of_label_file = os.path.join(dir_seg, img_name + '.png') elif task=="enhancement": dir_of_label_file = os.path.join(dir_seg, im)