From 766108089983e787870fc6d7dc4f9bbde7bd5b75 Mon Sep 17 00:00:00 2001 From: johnlockejrr <16368414+johnlockejrr@users.noreply.github.com> Date: Sat, 17 May 2025 16:17:38 +0300 Subject: [PATCH] LR Warmup and Optimization Implementation # Learning Rate Warmup and Optimization Implementation ## Overview Added learning rate warmup functionality to improve training stability, especially when using pretrained weights. The implementation uses TensorFlow's native learning rate scheduling for better performance. ## Changes Made ### 1. Configuration Updates (`runs/train_no_patches_448x448.json`) Added new configuration parameters for warmup: ```json { "warmup_enabled": true, "warmup_epochs": 5, "warmup_start_lr": 1e-6 } ``` ### 2. Training Script Updates (`train.py`) #### A. Optimizer and Learning Rate Schedule - Replaced fixed learning rate with dynamic scheduling - Implemented warmup using `tf.keras.optimizers.schedules.PolynomialDecay` - Maintained compatibility with existing ReduceLROnPlateau and EarlyStopping ```python if warmup_enabled: lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( initial_learning_rate=warmup_start_lr, decay_steps=warmup_epochs * steps_per_epoch, end_learning_rate=learning_rate, power=1.0 # Linear decay ) optimizer = Adam(learning_rate=lr_schedule) else: optimizer = Adam(learning_rate=learning_rate) ``` #### B. Learning Rate Behavior - Initial learning rate: 1e-6 (configurable via `warmup_start_lr`) - Target learning rate: 5e-5 (configurable via `learning_rate`) - Linear increase over 5 epochs (configurable via `warmup_epochs`) - After warmup, learning rate remains at target value until ReduceLROnPlateau triggers ## Benefits 1. Improved training stability during initial epochs 2. Better handling of pretrained weights 3. Efficient implementation using TensorFlow's native scheduling 4. Configurable through JSON configuration file 5. Maintains compatibility with existing callbacks (ReduceLROnPlateau, EarlyStopping) ## Usage To enable warmup: 1. Set `warmup_enabled: true` in the configuration file 2. Adjust `warmup_epochs` and `warmup_start_lr` as needed 3. The warmup will automatically integrate with existing learning rate reduction and early stopping To disable warmup: - Set `warmup_enabled: false` or remove the warmup parameters from the configuration file --- train.py | 99 ++++++++++++++++++++++++++++++++--- train_no_patches_448x448.json | 52 ++++++++++++++++++ 2 files changed, 143 insertions(+), 8 deletions(-) create mode 100644 train_no_patches_448x448.json diff --git a/train.py b/train.py index 4cc3cbb..668d6aa 100644 --- a/train.py +++ b/train.py @@ -5,6 +5,7 @@ import tensorflow as tf from tensorflow.compat.v1.keras.backend import set_session import warnings from tensorflow.keras.optimizers import * +from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback from sacred import Experiment from models import * from utils import * @@ -15,6 +16,36 @@ import json from sklearn.metrics import f1_score +def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch): + initial_learning_rate = start_lr + target_learning_rate = target_lr + warmup_steps = warmup_epochs * steps_per_epoch + + lr_schedule = tf.keras.optimizers.schedules.LinearSchedule( + initial_learning_rate=initial_learning_rate, + final_learning_rate=target_learning_rate, + total_steps=warmup_steps + ) + + return lr_schedule + + +class WarmupScheduler(Callback): + def __init__(self, start_lr, target_lr, warmup_epochs): + super(WarmupScheduler, self).__init__() + self.start_lr = start_lr + self.target_lr = target_lr + self.warmup_epochs = warmup_epochs + self.current_epoch = 0 + + def on_epoch_begin(self, epoch, logs=None): + if self.current_epoch < self.warmup_epochs: + # Linear warmup + lr = self.start_lr + (self.target_lr - self.start_lr) * (self.current_epoch / self.warmup_epochs) + tf.keras.backend.set_value(self.model.optimizer.lr, lr) + self.current_epoch += 1 + + def configuration(): config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True @@ -97,6 +128,18 @@ def config_params(): number_of_backgrounds_per_image = 1 dir_rgb_backgrounds = None dir_rgb_foregrounds = None + reduce_lr_enabled = False # Whether to use ReduceLROnPlateau callback + reduce_lr_monitor = 'val_loss' # Metric to monitor for reducing learning rate + reduce_lr_factor = 0.5 # Factor to reduce learning rate by + reduce_lr_patience = 3 # Number of epochs to wait before reducing learning rate + reduce_lr_min_lr = 1e-6 # Minimum learning rate + early_stopping_enabled = False # Whether to use EarlyStopping callback + early_stopping_monitor = 'val_loss' # Metric to monitor for early stopping + early_stopping_patience = 10 # Number of epochs to wait before stopping + early_stopping_restore_best_weights = True # Whether to restore best weights when stopping + warmup_enabled = False # Whether to use learning rate warmup + warmup_epochs = 5 # Number of epochs for warmup + warmup_start_lr = 1e-6 # Starting learning rate for warmup @ex.automain @@ -112,7 +155,10 @@ def run(_config, n_classes, n_epochs, input_height, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first, transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output, - pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds): + pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds, + reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr, + early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights, + warmup_enabled, warmup_epochs, warmup_start_lr): if dir_rgb_backgrounds: list_all_possible_background_images = os.listdir(dir_rgb_backgrounds) @@ -274,20 +320,55 @@ def run(_config, n_classes, n_epochs, input_height, model.summary() + # Create callbacks list + callbacks = [] + if reduce_lr_enabled: + reduce_lr = ReduceLROnPlateau( + monitor=reduce_lr_monitor, + factor=reduce_lr_factor, + patience=reduce_lr_patience, + min_lr=reduce_lr_min_lr, + verbose=1 + ) + callbacks.append(reduce_lr) + + if early_stopping_enabled: + early_stopping = EarlyStopping( + monitor=early_stopping_monitor, + patience=early_stopping_patience, + restore_best_weights=early_stopping_restore_best_weights, + verbose=1 + ) + callbacks.append(early_stopping) + + # Calculate steps per epoch + steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1 + + # Create optimizer with or without warmup + if warmup_enabled: + lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=warmup_start_lr, + decay_steps=warmup_epochs * steps_per_epoch, + end_learning_rate=learning_rate, + power=1.0 # Linear decay + ) + optimizer = Adam(learning_rate=lr_schedule) + else: + optimizer = Adam(learning_rate=learning_rate) + if (task == "segmentation" or task == "binarization"): if not is_loss_soft_dice and not weighted_loss: model.compile(loss='categorical_crossentropy', - optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) + optimizer=optimizer, metrics=['accuracy']) if is_loss_soft_dice: model.compile(loss=soft_dice_loss, - optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) + optimizer=optimizer, metrics=['accuracy']) if weighted_loss: model.compile(loss=weighted_categorical_crossentropy(weights), - optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) + optimizer=optimizer, metrics=['accuracy']) elif task == "enhancement": model.compile(loss='mean_squared_error', - optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy']) - + optimizer=optimizer, metrics=['accuracy']) # generating train and evaluation data train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch, @@ -298,13 +379,15 @@ def run(_config, n_classes, n_epochs, input_height, ##img_validation_patches = os.listdir(dir_flow_eval_imgs) ##score_best=[] ##score_best.append(0) + for i in tqdm(range(index_start, n_epochs + index_start)): model.fit( train_gen, - steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1, + steps_per_epoch=steps_per_epoch, validation_data=val_gen, validation_steps=1, - epochs=1) + epochs=1, + callbacks=callbacks) model.save(os.path.join(dir_output,'model_'+str(i))) with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp: diff --git a/train_no_patches_448x448.json b/train_no_patches_448x448.json new file mode 100644 index 0000000..c3d0e10 --- /dev/null +++ b/train_no_patches_448x448.json @@ -0,0 +1,52 @@ +{ + "backbone_type" : "nontransformer", + "task": "segmentation", + "n_classes" : 3, + "n_epochs" : 50, + "input_height" : 448, + "input_width" : 448, + "weight_decay" : 1e-4, + "n_batch" : 4, + "learning_rate": 5e-5, + "patches" : false, + "pretraining" : true, + "augmentation" : true, + "flip_aug" : false, + "blur_aug" : false, + "scaling" : true, + "degrading": true, + "brightening": false, + "binarization" : false, + "scaling_bluring" : false, + "scaling_binarization" : false, + "scaling_flip" : false, + "rotation": false, + "rotation_not_90": false, + "blur_k" : ["blur","guass","median"], + "scales" : [0.6, 0.7, 0.8, 0.9], + "brightness" : [1.3, 1.5, 1.7, 2], + "degrade_scales" : [0.2, 0.4], + "flip_index" : [0, 1, -1], + "thetha" : [10, -10], + "continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": true, + "data_is_provided": true, + "dir_train": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/train", + "dir_eval": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/eval", + "dir_output": "runs/sam_41_mss_npt_448x448", + "reduce_lr_enabled": true, + "reduce_lr_monitor": "val_loss", + "reduce_lr_factor": 0.5, + "reduce_lr_patience": 4, + "reduce_lr_min_lr": 1e-6, + "early_stopping_enabled": true, + "early_stopping_monitor": "val_loss", + "early_stopping_patience": 6, + "early_stopping_restore_best_weights": true, + "warmup_enabled": true, + "warmup_epochs": 5, + "warmup_start_lr": 1e-6 +}