Changed deprecated `lr` to `learning_rate` and `model.fit_generator` to `model.fit`

pull/18/head
johnlockejrr 2 months ago committed by GitHub
parent df4a47ae6f
commit 451188c3b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -277,16 +277,16 @@ def run(_config, n_classes, n_epochs, input_height,
     if (task == "segmentation" or task == "binarization"):
         if not is_loss_soft_dice and not weighted_loss:
             model.compile(loss='categorical_crossentropy',
-                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
         if is_loss_soft_dice:
             model.compile(loss=soft_dice_loss,
-                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
         if weighted_loss:
             model.compile(loss=weighted_categorical_crossentropy(weights),
-                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
     elif task == "enhancement":
         model.compile(loss='mean_squared_error',
-                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+                      optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
     # generating train and evaluation data
@@ -299,7 +299,7 @@ def run(_config, n_classes, n_epochs, input_height,
     ##score_best=[]
     ##score_best.append(0)
     for i in tqdm(range(index_start, n_epochs + index_start)):
-        model.fit_generator(
+        model.fit(
             train_gen,
             steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
             validation_data=val_gen,
@@ -384,7 +384,7 @@ def run(_config, n_classes, n_epochs, input_height,
     #f1score_tot = [0]
     indexer_start = 0
-    opt = SGD(lr=0.01, momentum=0.9)
+    opt = SGD(learning_rate=0.01, momentum=0.9)
     opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
     model.compile(loss="binary_crossentropy",
                   optimizer = opt_adam,metrics=['accuracy'])

Loading…
Cancel
Save