diff --git a/inference.py b/inference.py index 2b12ff7..db3b31f 100644 --- a/inference.py +++ b/inference.py @@ -567,6 +567,7 @@ class sbb_predict: img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) if self.save: cv2.imwrite(self.save,img_seg_overlayed) + if self.save_layout: cv2.imwrite(self.save_layout, only_layout) if self.ground_truth: diff --git a/train.py b/train.py index 7e3e390..4d9d8cb 100644 --- a/train.py +++ b/train.py @@ -278,16 +278,16 @@ def run(_config, n_classes, n_epochs, input_height, if (task == "segmentation" or task == "binarization"): if not is_loss_soft_dice and not weighted_loss: model.compile(loss='categorical_crossentropy', - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) if is_loss_soft_dice: model.compile(loss=soft_dice_loss, - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) if weighted_loss: model.compile(loss=weighted_categorical_crossentropy(weights), - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) elif task == "enhancement": model.compile(loss='mean_squared_error', - optimizer=Adam(lr=learning_rate), metrics=['accuracy']) + optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) # generating train and evaluation data @@ -300,7 +300,7 @@ def run(_config, n_classes, n_epochs, input_height, ##score_best=[] ##score_best.append(0) for i in tqdm(range(index_start, n_epochs + index_start)): - model.fit_generator( + model.fit( train_gen, steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, validation_data=val_gen, @@ -385,7 +385,7 @@ def run(_config, n_classes, n_epochs, input_height, #f1score_tot = [0] indexer_start = 0 - opt = SGD(lr=0.01, momentum=0.9) + opt = SGD(learning_rate=0.01, momentum=0.9) opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001) model.compile(loss="binary_crossentropy", optimizer = opt_adam,metrics=['accuracy'])