mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-07 19:05:24 +02:00
Merge 451188c3b9
into 872e5b0b3a
This commit is contained in:
commit
0261225610
2 changed files with 7 additions and 6 deletions
|
@ -567,6 +567,7 @@ class sbb_predict:
|
|||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||
if self.save:
|
||||
cv2.imwrite(self.save,img_seg_overlayed)
|
||||
if self.save_layout:
|
||||
cv2.imwrite(self.save_layout, only_layout)
|
||||
|
||||
if self.ground_truth:
|
||||
|
|
12
train.py
12
train.py
|
@ -278,16 +278,16 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
if (task == "segmentation" or task == "binarization"):
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||
if is_loss_soft_dice:
|
||||
model.compile(loss=soft_dice_loss,
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||
if weighted_loss:
|
||||
model.compile(loss=weighted_categorical_crossentropy(weights),
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||
elif task == "enhancement":
|
||||
model.compile(loss='mean_squared_error',
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||
|
||||
|
||||
# generating train and evaluation data
|
||||
|
@ -300,7 +300,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
model.fit_generator(
|
||||
model.fit(
|
||||
train_gen,
|
||||
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
|
||||
validation_data=val_gen,
|
||||
|
@ -385,7 +385,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
|
||||
#f1score_tot = [0]
|
||||
indexer_start = 0
|
||||
opt = SGD(lr=0.01, momentum=0.9)
|
||||
opt = SGD(learning_rate=0.01, momentum=0.9)
|
||||
opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
|
||||
model.compile(loss="binary_crossentropy",
|
||||
optimizer = opt_adam,metrics=['accuracy'])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue