Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git, synced 2025-06-08 03:10:55 +02:00
Merge pull request #18 from johnlockejrr/unifying-training-models

Deprecations in train.py and check an argument in inference.py: the diff renames the deprecated lr keyword of the Keras optimizer constructors to learning_rate, replaces the deprecated model.fit_generator call with model.fit, and guards the layout image write in inference.py behind its own output argument.
Commit d57de478eb, 2 changed files with 7 additions and 6 deletions.
1 inference.py

@@ -567,6 +567,7 @@ class sbb_predict:
         img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
         if self.save:
             cv2.imwrite(self.save,img_seg_overlayed)
+        if self.save_layout:
             cv2.imwrite(self.save_layout, only_layout)
 
         if self.ground_truth:
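Previously the layout image write had no guard of its own, so it could run with an unset output path. A minimal sketch of the corrected pattern, assuming save and save_layout are optional output paths that default to None when the CLI options are not given (write_outputs is a hypothetical stand-in for the corresponding branch in sbb_predict):

```python
import cv2
import numpy as np

# Sketch of the guarded save logic; `write_outputs` is a hypothetical
# stand-in for the branch shown in the diff above.
def write_outputs(img_seg_overlayed, only_layout, save=None, save_layout=None):
    if save:  # write the overlay only when an output path was given
        cv2.imwrite(save, img_seg_overlayed)
    if save_layout:  # the new check: skip the layout image when no path is set
        cv2.imwrite(save_layout, only_layout)

# Usage: writes only the layout image; no error despite save being None.
dummy = np.zeros((16, 16, 3), dtype=np.uint8)
write_outputs(dummy, dummy, save_layout="layout.png")
```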
12 train.py
@@ -278,16 +278,16 @@ def run(_config, n_classes, n_epochs, input_height,
     if (task == "segmentation" or task == "binarization"):
         if not is_loss_soft_dice and not weighted_loss:
             model.compile(loss='categorical_crossentropy',
-                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
         if is_loss_soft_dice:
             model.compile(loss=soft_dice_loss,
-                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
         if weighted_loss:
             model.compile(loss=weighted_categorical_crossentropy(weights),
-                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
     elif task == "enhancement":
         model.compile(loss='mean_squared_error',
-                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+                      optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
 
 
     # generating train and evaluation data
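All four compile calls make the same fix: the lr keyword of tf.keras optimizers was deprecated in favor of learning_rate, and newer TensorFlow releases reject the old name entirely. A minimal sketch of the updated pattern (the toy model is illustrative, not from the repository):

```python
import tensorflow as tf

# Same compile pattern as in the diff above, with the renamed keyword;
# the optimizer behaves identically, only the argument name changed.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, activation='softmax')])
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              metrics=['accuracy'])
```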
@@ -300,7 +300,7 @@ def run(_config, n_classes, n_epochs, input_height,
     ##score_best=[]
     ##score_best.append(0)
     for i in tqdm(range(index_start, n_epochs + index_start)):
-        model.fit_generator(
+        model.fit(
            train_gen,
            steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
            validation_data=val_gen,
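Model.fit_generator was deprecated in TensorFlow 2 because Model.fit accepts Python generators (and keras.utils.Sequence objects) directly, so the call can be swapped one-for-one. A runnable sketch with a toy generator (names like toy_gen are illustrative, not from the repository):

```python
import numpy as np
import tensorflow as tf

def toy_gen(batch=4):
    # Keras expects an endless generator when steps_per_epoch is given.
    while True:
        x = np.random.rand(batch, 8).astype('float32')
        y = np.random.randint(0, 2, size=(batch, 1)).astype('float32')
        yield x, y

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(toy_gen(), steps_per_epoch=5, epochs=1)  # replaces fit_generator(...)
```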
@@ -388,7 +388,7 @@ def run(_config, n_classes, n_epochs, input_height,
 
     #f1score_tot = [0]
     indexer_start = 0
-    opt = SGD(lr=0.01, momentum=0.9)
+    opt = SGD(learning_rate=0.01, momentum=0.9)
     opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
     model.compile(loss="binary_crossentropy",
             optimizer = opt_adam,metrics=['accuracy'])
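The SGD constructor gets the same keyword rename as the Adam calls above; momentum and the rest of the call are untouched:

```python
import tensorflow as tf

# Deprecated form: tf.keras.optimizers.SGD(lr=0.01, momentum=0.9)
opt = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
```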