mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
Merge pull request #18 from johnlockejrr/unifying-training-models
Deprecations in train.py and check an argument in inference.py
This commit is contained in:
commit
d57de478eb
2 changed files with 7 additions and 6 deletions
12
train.py
12
train.py
|
@ -278,16 +278,16 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
if (task == "segmentation" or task == "binarization"):
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||
if is_loss_soft_dice:
|
||||
model.compile(loss=soft_dice_loss,
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||
if weighted_loss:
|
||||
model.compile(loss=weighted_categorical_crossentropy(weights),
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||
elif task == "enhancement":
|
||||
model.compile(loss='mean_squared_error',
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
|
||||
|
||||
|
||||
# generating train and evaluation data
|
||||
|
@ -300,7 +300,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
model.fit_generator(
|
||||
model.fit(
|
||||
train_gen,
|
||||
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
|
||||
validation_data=val_gen,
|
||||
|
@ -388,7 +388,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
|
||||
#f1score_tot = [0]
|
||||
indexer_start = 0
|
||||
opt = SGD(lr=0.01, momentum=0.9)
|
||||
opt = SGD(learning_rate=0.01, momentum=0.9)
|
||||
opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
|
||||
model.compile(loss="binary_crossentropy",
|
||||
optimizer = opt_adam,metrics=['accuracy'])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue