
# Learning Rate Warmup and Optimization Implementation

## Overview
Added learning rate warmup to improve training stability, especially when fine-tuning from pretrained weights. The warmup is implemented with TensorFlow's native learning rate schedules, so the ramp is applied per optimizer step rather than per epoch.

## Changes Made

### 1. Configuration Updates (`runs/train_no_patches_448x448.json`)
Added new configuration parameters for warmup:
```json
{
    "warmup_enabled": true,
    "warmup_epochs": 5,
    "warmup_start_lr": 1e-6
}
```
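These keys work together with the existing `learning_rate` key, which serves as the warmup target. As a minimal standalone sketch (not part of `train.py`, which receives these values through its sacred configuration; it assumes the file is read from the repository root):

```python
import json

# Standalone illustration only; train.py loads this configuration via sacred.
with open("runs/train_no_patches_448x448.json") as fp:
    cfg = json.load(fp)

print(f"warmup enabled: {cfg['warmup_enabled']}")
print(f"learning rate ramps from {cfg['warmup_start_lr']:g} to {cfg['learning_rate']:g} "
      f"over {cfg['warmup_epochs']} epochs")
```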

### 2. Training Script Updates (`train.py`)

#### A. Optimizer and Learning Rate Schedule
- Replaced fixed learning rate with dynamic scheduling
- Implemented warmup using `tf.keras.optimizers.schedules.PolynomialDecay`
- Maintained compatibility with existing ReduceLROnPlateau and EarlyStopping

```python
if warmup_enabled:
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=warmup_start_lr,
        decay_steps=warmup_epochs * steps_per_epoch,
        end_learning_rate=learning_rate,
        power=1.0  # Linear decay
    )
    optimizer = Adam(learning_rate=lr_schedule)
else:
    optimizer = Adam(learning_rate=learning_rate)
```

#### B. Learning Rate Behavior
- Initial learning rate: 1e-6 (configurable via `warmup_start_lr`)
- Target learning rate: 5e-5 (configurable via `learning_rate`)
- Linear increase over 5 epochs (configurable via `warmup_epochs`)
- After warmup, the learning rate remains at the target value until `ReduceLROnPlateau` lowers it (see the sketch below)
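Since `PolynomialDecay` with `power=1.0` interpolates linearly and then holds the end value, the per-step behavior can be inspected directly. A minimal sketch, assuming a hypothetical `steps_per_epoch` of 100 (in `train.py` this value is derived from the number of training images and `n_batch`):

```python
import tensorflow as tf

# Values from the configuration above; steps_per_epoch is an assumed example value.
warmup_start_lr, learning_rate = 1e-6, 5e-5
warmup_epochs, steps_per_epoch = 5, 100

schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=warmup_start_lr,
    decay_steps=warmup_epochs * steps_per_epoch,  # 500 warmup steps
    end_learning_rate=learning_rate,
    power=1.0)                                    # linear ramp

for step in (0, 250, 500, 2000):
    print(step, float(schedule(step)))
# 0    -> 1.0e-06  (warmup start)
# 250  -> 2.55e-05 (halfway through warmup)
# 500  -> 5.0e-05  (target reached after warmup_epochs * steps_per_epoch steps)
# 2000 -> 5.0e-05  (the schedule holds the target afterwards)
```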

## Benefits
1. Improved training stability during initial epochs
2. Better handling of pretrained weights
3. Efficient implementation using TensorFlow's native scheduling
4. Configurable through JSON configuration file
5. Maintains compatibility with existing callbacks (ReduceLROnPlateau, EarlyStopping)

## Usage
To enable warmup:
1. Set `warmup_enabled: true` in the configuration file
2. Adjust `warmup_epochs` and `warmup_start_lr` as needed
3. The warmup will automatically integrate with existing learning rate reduction and early stopping

To disable warmup:
- Set `warmup_enabled: false` or remove the warmup parameters from the configuration file
## Full Diff

### `train.py`

```diff
@@ -5,6 +5,7 @@ import tensorflow as tf
 from tensorflow.compat.v1.keras.backend import set_session
 import warnings
 from tensorflow.keras.optimizers import *
+from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback
 from sacred import Experiment
 from models import *
 from utils import *
@@ -15,6 +16,36 @@ import json
 from sklearn.metrics import f1_score


+def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
+    initial_learning_rate = start_lr
+    target_learning_rate = target_lr
+    warmup_steps = warmup_epochs * steps_per_epoch
+    # PolynomialDecay with power=1.0 ramps linearly from start_lr to target_lr
+    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+        initial_learning_rate=initial_learning_rate,
+        decay_steps=warmup_steps,
+        end_learning_rate=target_learning_rate,
+        power=1.0
+    )
+    return lr_schedule
+
+
+class WarmupScheduler(Callback):
+    def __init__(self, start_lr, target_lr, warmup_epochs):
+        super(WarmupScheduler, self).__init__()
+        self.start_lr = start_lr
+        self.target_lr = target_lr
+        self.warmup_epochs = warmup_epochs
+        self.current_epoch = 0
+
+    def on_epoch_begin(self, epoch, logs=None):
+        if self.current_epoch < self.warmup_epochs:
+            # Linear warmup: interpolate between start_lr and target_lr per epoch
+            lr = self.start_lr + (self.target_lr - self.start_lr) * (self.current_epoch / self.warmup_epochs)
+            tf.keras.backend.set_value(self.model.optimizer.lr, lr)
+        self.current_epoch += 1
+
+
 def configuration():
     config = tf.compat.v1.ConfigProto()
     config.gpu_options.allow_growth = True
```
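The hunk above also adds an epoch-wise `WarmupScheduler` callback alongside the step-wise schedule helper. The training loop below relies on the schedule path; as a hedged sketch of the alternative (reusing the `WarmupScheduler` class defined above and assuming a plain float learning rate, since the callback overwrites `optimizer.lr` directly), it would be wired in like this:

```python
from tensorflow.keras.optimizers import Adam

# Hypothetical callback-based warmup; values mirror the configuration file.
warmup_start_lr, learning_rate, warmup_epochs = 1e-6, 5e-5, 5

optimizer = Adam(learning_rate=warmup_start_lr)  # plain float, so set_value() can update it
callbacks = [WarmupScheduler(start_lr=warmup_start_lr,
                             target_lr=learning_rate,
                             warmup_epochs=warmup_epochs)]
# then model.compile(..., optimizer=optimizer) and model.fit(..., callbacks=callbacks)
```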
```diff
@@ -97,6 +128,18 @@ def config_params():
     number_of_backgrounds_per_image = 1
     dir_rgb_backgrounds = None
     dir_rgb_foregrounds = None
+    reduce_lr_enabled = False  # Whether to use ReduceLROnPlateau callback
+    reduce_lr_monitor = 'val_loss'  # Metric to monitor for reducing learning rate
+    reduce_lr_factor = 0.5  # Factor to reduce learning rate by
+    reduce_lr_patience = 3  # Number of epochs to wait before reducing learning rate
+    reduce_lr_min_lr = 1e-6  # Minimum learning rate
+    early_stopping_enabled = False  # Whether to use EarlyStopping callback
+    early_stopping_monitor = 'val_loss'  # Metric to monitor for early stopping
+    early_stopping_patience = 10  # Number of epochs to wait before stopping
+    early_stopping_restore_best_weights = True  # Whether to restore best weights when stopping
+    warmup_enabled = False  # Whether to use learning rate warmup
+    warmup_epochs = 5  # Number of epochs for warmup
+    warmup_start_lr = 1e-6  # Starting learning rate for warmup


 @ex.automain
@@ -112,7 +155,10 @@ def run(_config, n_classes, n_epochs, input_height,
         transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
         transformer_patchsize_x, transformer_patchsize_y,
         transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
-        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image, dir_rgb_backgrounds, dir_rgb_foregrounds):
+        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image, dir_rgb_backgrounds, dir_rgb_foregrounds,
+        reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr,
+        early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
+        warmup_enabled, warmup_epochs, warmup_start_lr):

     if dir_rgb_backgrounds:
         list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@@ -274,20 +320,55 @@ def run(_config, n_classes, n_epochs, input_height,
     model.summary()

+    # Create callbacks list
+    callbacks = []
+
+    if reduce_lr_enabled:
+        reduce_lr = ReduceLROnPlateau(
+            monitor=reduce_lr_monitor,
+            factor=reduce_lr_factor,
+            patience=reduce_lr_patience,
+            min_lr=reduce_lr_min_lr,
+            verbose=1
+        )
+        callbacks.append(reduce_lr)
+
+    if early_stopping_enabled:
+        early_stopping = EarlyStopping(
+            monitor=early_stopping_monitor,
+            patience=early_stopping_patience,
+            restore_best_weights=early_stopping_restore_best_weights,
+            verbose=1
+        )
+        callbacks.append(early_stopping)
+
+    # Calculate steps per epoch
+    steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1
+
+    # Create optimizer with or without warmup
+    if warmup_enabled:
+        lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+            initial_learning_rate=warmup_start_lr,
+            decay_steps=warmup_epochs * steps_per_epoch,
+            end_learning_rate=learning_rate,
+            power=1.0  # Linear decay
+        )
+        optimizer = Adam(learning_rate=lr_schedule)
+    else:
+        optimizer = Adam(learning_rate=learning_rate)
+
     if (task == "segmentation" or task == "binarization"):
         if not is_loss_soft_dice and not weighted_loss:
             model.compile(loss='categorical_crossentropy',
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
         if is_loss_soft_dice:
             model.compile(loss=soft_dice_loss,
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
         if weighted_loss:
             model.compile(loss=weighted_categorical_crossentropy(weights),
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
     elif task == "enhancement":
         model.compile(loss='mean_squared_error',
-                      optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                      optimizer=optimizer, metrics=['accuracy'])

     # generating train and evaluation data
     train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
```
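To put the step counts in perspective (with an assumed dataset size; the real count comes from the files in `dir_flow_train_imgs`):

```python
# Hypothetical example: 1000 generated training images with the configured batch size.
n_train_images, n_batch, warmup_epochs = 1000, 4, 5

steps_per_epoch = int(n_train_images / n_batch) - 1  # 249, using the same formula as train.py
warmup_steps = warmup_epochs * steps_per_epoch       # 1245 optimizer updates spent in warmup
print(steps_per_epoch, warmup_steps)
```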
```diff
@@ -298,13 +379,15 @@ def run(_config, n_classes, n_epochs, input_height,
     ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
     ##score_best=[]
     ##score_best.append(0)

     for i in tqdm(range(index_start, n_epochs + index_start)):
         model.fit(
             train_gen,
-            steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
+            steps_per_epoch=steps_per_epoch,
             validation_data=val_gen,
             validation_steps=1,
-            epochs=1)
+            epochs=1,
+            callbacks=callbacks)
         model.save(os.path.join(dir_output,'model_'+str(i)))

         with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
```
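Putting the pieces together, the following condensed sketch mirrors the optimizer, warmup, and callback wiring from the diff, but runs on a tiny stand-in model with random data (both hypothetical) so it can be executed outside the project; only `EarlyStopping` is attached here, since `ReduceLROnPlateau` is appended to the same list in `train.py` when `reduce_lr_enabled` is set:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

# Values mirroring the configuration file; steps_per_epoch is assumed for the stand-in data.
learning_rate, warmup_start_lr, warmup_epochs = 5e-5, 1e-6, 5
steps_per_epoch = 20

# Warmup schedule and optimizer, constructed as in train.py.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=warmup_start_lr,
    decay_steps=warmup_epochs * steps_per_epoch,
    end_learning_rate=learning_rate,
    power=1.0)
optimizer = Adam(learning_rate=lr_schedule)

callbacks = [EarlyStopping(monitor='val_loss', patience=6,
                           restore_best_weights=True, verbose=1)]

# Stand-in model and data; the real training uses the project's segmentation
# models and data generators.
model = tf.keras.Sequential([tf.keras.Input(shape=(8,)),
                             tf.keras.layers.Dense(3, activation='softmax')])
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

x = np.random.rand(80, 8).astype('float32')
y = tf.keras.utils.to_categorical(np.random.randint(0, 3, 80), 3)
model.fit(x, y, batch_size=4, validation_split=0.1, epochs=2, callbacks=callbacks)
```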
### `runs/train_no_patches_448x448.json` (new file)

```json
{
    "backbone_type": "nontransformer",
    "task": "segmentation",
    "n_classes": 3,
    "n_epochs": 50,
    "input_height": 448,
    "input_width": 448,
    "weight_decay": 1e-4,
    "n_batch": 4,
    "learning_rate": 5e-5,
    "patches": false,
    "pretraining": true,
    "augmentation": true,
    "flip_aug": false,
    "blur_aug": false,
    "scaling": true,
    "degrading": true,
    "brightening": false,
    "binarization": false,
    "scaling_bluring": false,
    "scaling_binarization": false,
    "scaling_flip": false,
    "rotation": false,
    "rotation_not_90": false,
    "blur_k": ["blur", "guass", "median"],
    "scales": [0.6, 0.7, 0.8, 0.9],
    "brightness": [1.3, 1.5, 1.7, 2],
    "degrade_scales": [0.2, 0.4],
    "flip_index": [0, 1, -1],
    "thetha": [10, -10],
    "continue_training": false,
    "index_start": 0,
    "dir_of_start_model": " ",
    "weighted_loss": false,
    "is_loss_soft_dice": true,
    "data_is_provided": true,
    "dir_train": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/train",
    "dir_eval": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/eval",
    "dir_output": "runs/sam_41_mss_npt_448x448",
    "reduce_lr_enabled": true,
    "reduce_lr_monitor": "val_loss",
    "reduce_lr_factor": 0.5,
    "reduce_lr_patience": 4,
    "reduce_lr_min_lr": 1e-6,
    "early_stopping_enabled": true,
    "early_stopping_monitor": "val_loss",
    "early_stopping_patience": 6,
    "early_stopping_restore_best_weights": true,
    "warmup_enabled": true,
    "warmup_epochs": 5,
    "warmup_start_lr": 1e-6
}
```