Merge remote-tracking branch 'pixelwise_local/ReduceLROnPlateau' into ReduceLROnPlateau

# Conflicts:
#	LICENSE
#	README.md
#	requirements.txt
#	train.py
kba 2025-10-16 20:20:06 +02:00
commit 54132a499a
4 changed files with 174 additions and 37 deletions

train.py (158 changed lines)

@@ -31,6 +31,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.compat.v1.keras.backend import set_session
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback, ModelCheckpoint
from sacred import Experiment
from tensorflow.keras.models import load_model
from tqdm import tqdm
@@ -61,6 +62,20 @@ class SaveWeightsAfterSteps(Callback):
            json.dump(self._config, fp)  # encode dict into JSON
        print(f"saved model at step {self.step_count} to {save_file}")
def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
    initial_learning_rate = start_lr
    target_learning_rate = target_lr
    warmup_steps = warmup_epochs * steps_per_epoch
    lr_schedule = tf.keras.optimizers.schedules.LinearSchedule(
        initial_learning_rate=initial_learning_rate,
        final_learning_rate=target_learning_rate,
        total_steps=warmup_steps
    )
    return lr_schedule
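Note: depending on the TensorFlow/Keras version, tf.keras.optimizers.schedules may not provide a LinearSchedule, and this helper does not appear to be referenced by the training loop further down, which builds its warmup ramp directly. A minimal equivalent sketch (an assumption mirroring the PolynomialDecay call used below; the helper name is hypothetical):

# Assumes `import tensorflow as tf` from the imports above.
def get_warmup_schedule_poly(start_lr, target_lr, warmup_epochs, steps_per_epoch):
    # Linear ramp from start_lr to target_lr over warmup_epochs * steps_per_epoch steps.
    # PolynomialDecay with power=1.0 interpolates linearly and holds target_lr afterwards.
    return tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=start_lr,
        decay_steps=warmup_epochs * steps_per_epoch,
        end_learning_rate=target_lr,
        power=1.0
    )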
def configuration():
    config = tf.compat.v1.ConfigProto()
@@ -80,7 +95,6 @@ def get_dirs_or_files(input_data):

ex = Experiment(save_git_info=False)

@ex.config
def config_params():
    n_classes = None  # Number of classes. In the case of binary classification this should be 2.
@@ -145,6 +159,19 @@ def config_params():
    number_of_backgrounds_per_image = 1
    dir_rgb_backgrounds = None
    dir_rgb_foregrounds = None

    reduce_lr_enabled = False  # Whether to use ReduceLROnPlateau callback
    reduce_lr_monitor = 'val_loss'  # Metric to monitor for reducing learning rate
    reduce_lr_factor = 0.5  # Factor to reduce learning rate by
    reduce_lr_patience = 3  # Number of epochs to wait before reducing learning rate
    reduce_lr_min_lr = 1e-6  # Minimum learning rate
    reduce_lr_min_delta = 1e-5  # Minimum change in monitored value to be considered as improvement
    early_stopping_enabled = False  # Whether to use EarlyStopping callback
    early_stopping_monitor = 'val_loss'  # Metric to monitor for early stopping
    early_stopping_patience = 10  # Number of epochs to wait before stopping
    early_stopping_restore_best_weights = True  # Whether to restore best weights when stopping
    warmup_enabled = False  # Whether to use learning rate warmup
    warmup_epochs = 5  # Number of epochs for warmup
    warmup_start_lr = 1e-6  # Starting learning rate for warmup
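Since these are ordinary Sacred config entries, they can presumably be switched on per run without code changes, either in the JSON config file or as command-line updates; the config file name below is only an example of the usual invocation:

python train.py with config_params.json reduce_lr_enabled=True early_stopping_enabled=True warmup_enabled=True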
@ex.automain
def run(_config, n_classes, n_epochs, input_height,
@@ -159,7 +186,10 @@ def run(_config, n_classes, n_epochs, input_height,
        transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
        transformer_patchsize_x, transformer_patchsize_y,
        transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image, dir_rgb_backgrounds, dir_rgb_foregrounds):
        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image, dir_rgb_backgrounds, dir_rgb_foregrounds,
        reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr, reduce_lr_min_delta,
        early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
        warmup_enabled, warmup_epochs, warmup_start_lr):

    if dir_rgb_backgrounds:
        list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@@ -320,20 +350,91 @@ def run(_config, n_classes, n_epochs, input_height,
        # Print the model structure (comment the next line out to silence it).
        model.summary()
        # Create callbacks list
        callbacks = []

        if reduce_lr_enabled:
            reduce_lr = ReduceLROnPlateau(
                monitor=reduce_lr_monitor,
                factor=reduce_lr_factor,
                patience=reduce_lr_patience,
                min_lr=reduce_lr_min_lr,
                min_delta=reduce_lr_min_delta,
                verbose=1
            )
            callbacks.append(reduce_lr)

        if early_stopping_enabled:
            early_stopping = EarlyStopping(
                monitor=early_stopping_monitor,
                patience=early_stopping_patience,
                restore_best_weights=early_stopping_restore_best_weights,
                verbose=1
            )
            callbacks.append(early_stopping)
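A rough feel for how these two callbacks interact, under a hypothetical base learning rate of 1e-4: ReduceLROnPlateau multiplies the learning rate by reduce_lr_factor each time val_loss fails to improve by more than reduce_lr_min_delta for reduce_lr_patience consecutive epochs, flooring at reduce_lr_min_lr, while EarlyStopping independently ends training after early_stopping_patience such epochs and, if configured, restores the best weights seen so far.

# Hedged sketch of the reduction sequence with the defaults above
# (factor=0.5, min_lr=1e-6) and a hypothetical starting learning rate of 1e-4.
lr = 1e-4
for plateau in range(10):
    lr = max(lr * 0.5, 1e-6)  # one reduction per exhausted patience window
    print(f"after plateau {plateau + 1}: lr = {lr:g}")  # 5e-05, 2.5e-05, ..., floors at 1e-06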
        # Add checkpoint to save models every epoch
        class ModelCheckpointWithConfig(ModelCheckpoint):
            def __init__(self, *args, **kwargs):
                self._config = _config
                super().__init__(*args, **kwargs)

            def on_epoch_end(self, epoch, logs=None):
                super().on_epoch_end(epoch, logs)
                model_dir = os.path.join(dir_output, f"model_{epoch+1}")
                with open(os.path.join(model_dir, "config.json"), "w") as fp:
                    json.dump(self._config, fp)

        checkpoint_epoch = ModelCheckpointWithConfig(
            os.path.join(dir_output, "model_{epoch}"),
            save_freq='epoch',
            save_weights_only=False,
            save_best_only=False,
            verbose=1
        )
        callbacks.append(checkpoint_epoch)
        # Calculate steps per epoch
        steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1

        # Create optimizer with or without warmup
        if warmup_enabled:
            lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=warmup_start_lr,
                decay_steps=warmup_epochs * steps_per_epoch,
                end_learning_rate=learning_rate,
                power=1.0  # Linear decay
            )
            optimizer = Adam(learning_rate=lr_schedule)
        else:
            optimizer = Adam(learning_rate=learning_rate)
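As a quick sanity check of the warmup ramp, the schedule can be evaluated at a few steps. The numbers below are purely illustrative (warmup_start_lr=1e-6, learning_rate=1e-4, 5 warmup epochs of 100 steps each), not values fixed by the config:

# Hedged sketch: probe the linear warmup at the start, middle and end of the ramp.
# PolynomialDecay clamps the step at decay_steps, so the LR stays at learning_rate afterwards.
sched = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-6,
    decay_steps=5 * 100,
    end_learning_rate=1e-4,
    power=1.0
)
print(float(sched(0)), float(sched(250)), float(sched(500)))  # ~1e-06, ~5.05e-05, 1e-04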
        if (task == "segmentation" or task == "binarization"):
            if not is_loss_soft_dice and not weighted_loss:
                model.compile(loss='categorical_crossentropy',
                              optimizer=optimizer, metrics=['accuracy'])
            if is_loss_soft_dice:
                model.compile(loss=soft_dice_loss,
                              optimizer=optimizer, metrics=['accuracy'])
            if weighted_loss:
                model.compile(loss=weighted_categorical_crossentropy(weights),
                              optimizer=optimizer, metrics=['accuracy'])
        elif task == "enhancement":
            model.compile(loss='mean_squared_error',
                          optimizer=optimizer, metrics=['accuracy'])

        if task == "segmentation" or task == "binarization":
            if not is_loss_soft_dice and not weighted_loss:
                model.compile(loss='categorical_crossentropy',
                              optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
                              optimizer=optimizer, metrics=['accuracy'])
            if is_loss_soft_dice:
                model.compile(loss=soft_dice_loss,
                              optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
                              optimizer=optimizer, metrics=['accuracy'])
            if weighted_loss:
                model.compile(loss=weighted_categorical_crossentropy(weights),
                              optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
                              optimizer=optimizer, metrics=['accuracy'])
        elif task == "enhancement":
            model.compile(loss='mean_squared_error',
                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
                          optimizer=optimizer, metrics=['accuracy'])

        # generating train and evaluation data
@@ -342,39 +443,22 @@ def run(_config, n_classes, n_epochs, input_height,
        val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
                           input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)

        ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
        ##score_best=[]
        ##score_best.append(0)

        # Single fit call with all epochs
        history = model.fit(
            train_gen,
            steps_per_epoch=steps_per_epoch,
            validation_data=val_gen,
            validation_steps=1,
            epochs=n_epochs,
            callbacks=callbacks
        )

        if save_interval:
            save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
        for i in tqdm(range(index_start, n_epochs + index_start)):
            if save_interval:
                model.fit(
                    train_gen,
                    steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
                    validation_data=val_gen,
                    validation_steps=1,
                    epochs=1, callbacks=[save_weights_callback])
            else:
                model.fit(
                    train_gen,
                    steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
                    validation_data=val_gen,
                    validation_steps=1,
                    epochs=1)
            model.save(os.path.join(dir_output, 'model_' + str(i)))

        # Save the best model (either from early stopping or final model)
        model.save(os.path.join(dir_output, 'model_best'))

            with open(os.path.join(os.path.join(dir_output, 'model_' + str(i)), "config.json"), "w") as fp:
                json.dump(_config, fp)  # encode dict into JSON
        #os.system('rm -rf '+dir_train_flowing)
        #os.system('rm -rf '+dir_eval_flowing)
        #model.save(dir_output+'/'+'model'+'.h5')

        with open(os.path.join(dir_output, 'model_best', "config.json"), "w") as fp:
            json.dump(_config, fp)  # encode dict into JSON
    elif task == 'classification':
        configuration()
        model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)