mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-20 01:00:02 +02:00
Fix ReduceONPlateau
wrong logic
# Training Script Improvements ## Learning Rate Management Fixes ### 1. ReduceLROnPlateau Implementation - Fixed the learning rate reduction mechanism by replacing the manual epoch loop with a single `model.fit()` call - This ensures proper tracking of validation metrics across epochs - Configured with: ```python reduce_lr = ReduceLROnPlateau( monitor='val_loss', factor=0.2, # More aggressive reduction patience=3, # Quick response to plateaus min_lr=1e-6, # Minimum learning rate min_delta=1e-5, # Minimum change to be considered improvement verbose=1 ) ``` ### 2. Warmup Implementation - Added learning rate warmup using TensorFlow's native scheduling - Gradually increases learning rate from 1e-6 to target (2e-5) over 5 epochs - Helps stabilize initial training phase - Implemented using `PolynomialDecay` schedule: ```python lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( initial_learning_rate=warmup_start_lr, decay_steps=warmup_epochs * steps_per_epoch, end_learning_rate=learning_rate, power=1.0 # Linear decay ) ``` ### 3. Early Stopping - Added early stopping to prevent overfitting - Configured with: ```python early_stopping = EarlyStopping( monitor='val_loss', patience=6, restore_best_weights=True, verbose=1 ) ``` ## Model Saving Improvements ### 1. Epoch-based Model Saving - Implemented custom `ModelCheckpointWithConfig` to save both model and config - Saves after each epoch with corresponding config.json - Maintains compatibility with original script's saving behavior ### 2. Best Model Saving - Saves the best model at training end - If early stopping triggers: saves the best model from training - If no early stopping: saves the final model ## Configuration All parameters are configurable through the JSON config file: ```json { "reduce_lr_enabled": true, "reduce_lr_monitor": "val_loss", "reduce_lr_factor": 0.2, "reduce_lr_patience": 3, "reduce_lr_min_lr": 1e-6, "reduce_lr_min_delta": 1e-5, "early_stopping_enabled": true, "early_stopping_monitor": "val_loss", "early_stopping_patience": 6, "early_stopping_restore_best_weights": true, "warmup_enabled": true, "warmup_epochs": 5, "warmup_start_lr": 1e-6 } ``` ## Benefits 1. More stable training with proper learning rate management 2. Better handling of training plateaus 3. Automatic saving of best model 4. Maintained compatibility with existing config saving 5. Improved training monitoring and control
This commit is contained in:
parent
7661080899
commit
f298643fcf
2 changed files with 44 additions and 41 deletions
78
train.py
78
train.py
|
@ -5,7 +5,7 @@ import tensorflow as tf
|
|||
from tensorflow.compat.v1.keras.backend import set_session
|
||||
import warnings
|
||||
from tensorflow.keras.optimizers import *
|
||||
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback
|
||||
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback, ModelCheckpoint
|
||||
from sacred import Experiment
|
||||
from models import *
|
||||
from utils import *
|
||||
|
@ -30,22 +30,6 @@ def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
|
|||
return lr_schedule
|
||||
|
||||
|
||||
class WarmupScheduler(Callback):
|
||||
def __init__(self, start_lr, target_lr, warmup_epochs):
|
||||
super(WarmupScheduler, self).__init__()
|
||||
self.start_lr = start_lr
|
||||
self.target_lr = target_lr
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.current_epoch = 0
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
if self.current_epoch < self.warmup_epochs:
|
||||
# Linear warmup
|
||||
lr = self.start_lr + (self.target_lr - self.start_lr) * (self.current_epoch / self.warmup_epochs)
|
||||
tf.keras.backend.set_value(self.model.optimizer.lr, lr)
|
||||
self.current_epoch += 1
|
||||
|
||||
|
||||
def configuration():
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
|
@ -133,6 +117,7 @@ def config_params():
|
|||
reduce_lr_factor = 0.5 # Factor to reduce learning rate by
|
||||
reduce_lr_patience = 3 # Number of epochs to wait before reducing learning rate
|
||||
reduce_lr_min_lr = 1e-6 # Minimum learning rate
|
||||
reduce_lr_min_delta = 1e-5 # Minimum change in monitored value to be considered as improvement
|
||||
early_stopping_enabled = False # Whether to use EarlyStopping callback
|
||||
early_stopping_monitor = 'val_loss' # Metric to monitor for early stopping
|
||||
early_stopping_patience = 10 # Number of epochs to wait before stopping
|
||||
|
@ -156,7 +141,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
transformer_patchsize_x, transformer_patchsize_y,
|
||||
transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
|
||||
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds,
|
||||
reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr,
|
||||
reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr, reduce_lr_min_delta,
|
||||
early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
|
||||
warmup_enabled, warmup_epochs, warmup_start_lr):
|
||||
|
||||
|
@ -328,6 +313,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
factor=reduce_lr_factor,
|
||||
patience=reduce_lr_patience,
|
||||
min_lr=reduce_lr_min_lr,
|
||||
min_delta=reduce_lr_min_delta,
|
||||
verbose=1
|
||||
)
|
||||
callbacks.append(reduce_lr)
|
||||
|
@ -341,6 +327,27 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
)
|
||||
callbacks.append(early_stopping)
|
||||
|
||||
# Add checkpoint to save models every epoch
|
||||
class ModelCheckpointWithConfig(ModelCheckpoint):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._config = _config
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
super().on_epoch_end(epoch, logs)
|
||||
model_dir = os.path.join(dir_output, f"model_{epoch+1}")
|
||||
with open(os.path.join(model_dir, "config.json"), "w") as fp:
|
||||
json.dump(self._config, fp)
|
||||
|
||||
checkpoint_epoch = ModelCheckpointWithConfig(
|
||||
os.path.join(dir_output, "model_{epoch}"),
|
||||
save_freq='epoch',
|
||||
save_weights_only=False,
|
||||
save_best_only=False,
|
||||
verbose=1
|
||||
)
|
||||
callbacks.append(checkpoint_epoch)
|
||||
|
||||
# Calculate steps per epoch
|
||||
steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1
|
||||
|
||||
|
@ -376,27 +383,22 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
|
||||
input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
|
||||
|
||||
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
|
||||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
model.fit(
|
||||
train_gen,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
validation_data=val_gen,
|
||||
validation_steps=1,
|
||||
epochs=1,
|
||||
callbacks=callbacks)
|
||||
model.save(os.path.join(dir_output,'model_'+str(i)))
|
||||
# Single fit call with all epochs
|
||||
history = model.fit(
|
||||
train_gen,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
validation_data=val_gen,
|
||||
validation_steps=1,
|
||||
epochs=n_epochs,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
#os.system('rm -rf '+dir_train_flowing)
|
||||
#os.system('rm -rf '+dir_eval_flowing)
|
||||
|
||||
#model.save(dir_output+'/'+'model'+'.h5')
|
||||
# Save the best model (either from early stopping or final model)
|
||||
model.save(os.path.join(dir_output, 'model_best'))
|
||||
|
||||
with open(os.path.join(dir_output, 'model_best', "config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
elif task=='classification':
|
||||
configuration()
|
||||
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
"input_width" : 448,
|
||||
"weight_decay" : 1e-4,
|
||||
"n_batch" : 4,
|
||||
"learning_rate": 5e-5,
|
||||
"learning_rate": 2e-5,
|
||||
"patches" : false,
|
||||
"pretraining" : true,
|
||||
"augmentation" : true,
|
||||
|
@ -39,8 +39,9 @@
|
|||
"dir_output": "runs/sam_41_mss_npt_448x448",
|
||||
"reduce_lr_enabled": true,
|
||||
"reduce_lr_monitor": "val_loss",
|
||||
"reduce_lr_factor": 0.5,
|
||||
"reduce_lr_patience": 4,
|
||||
"reduce_lr_factor": 0.2,
|
||||
"reduce_lr_patience": 3,
|
||||
"reduce_lr_min_delta": 1e-5,
|
||||
"reduce_lr_min_lr": 1e-6,
|
||||
"early_stopping_enabled": true,
|
||||
"early_stopping_monitor": "val_loss",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue