Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git (synced 2025-06-20 01:00:02 +02:00)
LR Warmup and Optimization Implementation
# Learning Rate Warmup and Optimization Implementation

## Overview

Added learning rate warmup functionality to improve training stability, especially when using pretrained weights. The implementation uses TensorFlow's native learning rate scheduling for better performance.

## Changes Made

### 1. Configuration Updates (`runs/train_no_patches_448x448.json`)

Added new configuration parameters for warmup:

```json
{
    "warmup_enabled": true,
    "warmup_epochs": 5,
    "warmup_start_lr": 1e-6
}
```

### 2. Training Script Updates (`train.py`)

#### A. Optimizer and Learning Rate Schedule

- Replaced the fixed learning rate with dynamic scheduling
- Implemented warmup using `tf.keras.optimizers.schedules.PolynomialDecay`
- Maintained compatibility with the existing ReduceLROnPlateau and EarlyStopping callbacks

```python
if warmup_enabled:
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=warmup_start_lr,
        decay_steps=warmup_epochs * steps_per_epoch,
        end_learning_rate=learning_rate,
        power=1.0  # Linear decay
    )
    optimizer = Adam(learning_rate=lr_schedule)
else:
    optimizer = Adam(learning_rate=learning_rate)
```

#### B. Learning Rate Behavior

- Initial learning rate: 1e-6 (configurable via `warmup_start_lr`)
- Target learning rate: 5e-5 (configurable via `learning_rate`)
- Linear increase over 5 epochs (configurable via `warmup_epochs`)
- After warmup, the learning rate remains at the target value until ReduceLROnPlateau triggers

## Benefits

1. Improved training stability during the initial epochs
2. Better handling of pretrained weights
3. Efficient implementation using TensorFlow's native scheduling
4. Configurable through the JSON configuration file
5. Maintains compatibility with existing callbacks (ReduceLROnPlateau, EarlyStopping)

## Usage

To enable warmup:

1. Set `warmup_enabled: true` in the configuration file
2. Adjust `warmup_epochs` and `warmup_start_lr` as needed
3. The warmup automatically integrates with the existing learning rate reduction and early stopping

To disable warmup:

- Set `warmup_enabled: false` or remove the warmup parameters from the configuration file
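As a quick sanity check of the ramp described under "Learning Rate Behavior", the schedule can be evaluated at a few steps. This is a minimal sketch, assuming the configured values above and an illustrative `steps_per_epoch` of 100 (in `train.py` the real value is computed from the dataset size):

```python
import tensorflow as tf

# Values from the configuration above; steps_per_epoch is an assumed
# illustrative value (train.py derives it from the training set size).
warmup_start_lr = 1e-6
learning_rate = 5e-5
warmup_epochs = 5
steps_per_epoch = 100

warmup_steps = warmup_epochs * steps_per_epoch
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=warmup_start_lr,
    decay_steps=warmup_steps,
    end_learning_rate=learning_rate,
    power=1.0,  # linear interpolation from start to target
)

# Ramps linearly from 1e-6 to 5e-5 over warmup_steps, then stays at 5e-5.
for step in [0, warmup_steps // 2, warmup_steps, 2 * warmup_steps]:
    print(step, float(lr_schedule(step)))
```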
Parent: 1bf801985b
Commit: 7661080899
2 changed files with 143 additions and 8 deletions
train.py (99 changed lines)
```diff
@@ -5,6 +5,7 @@ import tensorflow as tf
 from tensorflow.compat.v1.keras.backend import set_session
 import warnings
 from tensorflow.keras.optimizers import *
+from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback
 from sacred import Experiment
 from models import *
 from utils import *
```
```diff
@@ -15,6 +16,36 @@ import json
 from sklearn.metrics import f1_score
 
 
+def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
+    initial_learning_rate = start_lr
+    target_learning_rate = target_lr
+    warmup_steps = warmup_epochs * steps_per_epoch
+
+    lr_schedule = tf.keras.optimizers.schedules.LinearSchedule(
+        initial_learning_rate=initial_learning_rate,
+        final_learning_rate=target_learning_rate,
+        total_steps=warmup_steps
+    )
+
+    return lr_schedule
+
+
+class WarmupScheduler(Callback):
+    def __init__(self, start_lr, target_lr, warmup_epochs):
+        super(WarmupScheduler, self).__init__()
+        self.start_lr = start_lr
+        self.target_lr = target_lr
+        self.warmup_epochs = warmup_epochs
+        self.current_epoch = 0
+
+    def on_epoch_begin(self, epoch, logs=None):
+        if self.current_epoch < self.warmup_epochs:
+            # Linear warmup
+            lr = self.start_lr + (self.target_lr - self.start_lr) * (self.current_epoch / self.warmup_epochs)
+            tf.keras.backend.set_value(self.model.optimizer.lr, lr)
+        self.current_epoch += 1
+
+
 def configuration():
     config = tf.compat.v1.ConfigProto()
     config.gpu_options.allow_growth = True
```
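For reference, the epoch-level `WarmupScheduler` callback defined in the hunk above could be exercised roughly as follows. This is a minimal sketch, not part of the commit: the tiny model and random data are stand-ins, and it assumes the `WarmupScheduler` class above is in scope and that the optimizer uses a plain float learning rate (so `tf.keras.backend.set_value` can overwrite `optimizer.lr` each epoch).

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in model and data, just to exercise the WarmupScheduler callback;
# train.py uses its real model and data generators instead.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,), activation="softmax")])
model.compile(loss="categorical_crossentropy",
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6),
              metrics=["accuracy"])

x = np.random.rand(32, 4).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(0, 2, 32), num_classes=2)

warmup_cb = WarmupScheduler(start_lr=1e-6, target_lr=5e-5, warmup_epochs=5)
model.fit(x, y, batch_size=8, epochs=10, callbacks=[warmup_cb], verbose=0)

print(float(model.optimizer.lr))  # learning rate as last set by the warmup callback
```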
```diff
@@ -97,6 +128,18 @@ def config_params():
     number_of_backgrounds_per_image = 1
     dir_rgb_backgrounds = None
     dir_rgb_foregrounds = None
+    reduce_lr_enabled = False  # Whether to use ReduceLROnPlateau callback
+    reduce_lr_monitor = 'val_loss'  # Metric to monitor for reducing learning rate
+    reduce_lr_factor = 0.5  # Factor to reduce learning rate by
+    reduce_lr_patience = 3  # Number of epochs to wait before reducing learning rate
+    reduce_lr_min_lr = 1e-6  # Minimum learning rate
+    early_stopping_enabled = False  # Whether to use EarlyStopping callback
+    early_stopping_monitor = 'val_loss'  # Metric to monitor for early stopping
+    early_stopping_patience = 10  # Number of epochs to wait before stopping
+    early_stopping_restore_best_weights = True  # Whether to restore best weights when stopping
+    warmup_enabled = False  # Whether to use learning rate warmup
+    warmup_epochs = 5  # Number of epochs for warmup
+    warmup_start_lr = 1e-6  # Starting learning rate for warmup
 
 
 @ex.automain
```
```diff
@@ -112,7 +155,10 @@ def run(_config, n_classes, n_epochs, input_height,
         transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
         transformer_patchsize_x, transformer_patchsize_y,
         transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
-        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
+        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds,
+        reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr,
+        early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
+        warmup_enabled, warmup_epochs, warmup_start_lr):
 
     if dir_rgb_backgrounds:
         list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
```
```diff
@@ -274,20 +320,55 @@ def run(_config, n_classes, n_epochs, input_height,
     model.summary()
 
 
+    # Create callbacks list
+    callbacks = []
+    if reduce_lr_enabled:
+        reduce_lr = ReduceLROnPlateau(
+            monitor=reduce_lr_monitor,
+            factor=reduce_lr_factor,
+            patience=reduce_lr_patience,
+            min_lr=reduce_lr_min_lr,
+            verbose=1
+        )
+        callbacks.append(reduce_lr)
+
+    if early_stopping_enabled:
+        early_stopping = EarlyStopping(
+            monitor=early_stopping_monitor,
+            patience=early_stopping_patience,
+            restore_best_weights=early_stopping_restore_best_weights,
+            verbose=1
+        )
+        callbacks.append(early_stopping)
+
+    # Calculate steps per epoch
+    steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1
+
+    # Create optimizer with or without warmup
+    if warmup_enabled:
+        lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+            initial_learning_rate=warmup_start_lr,
+            decay_steps=warmup_epochs * steps_per_epoch,
+            end_learning_rate=learning_rate,
+            power=1.0  # Linear decay
+        )
+        optimizer = Adam(learning_rate=lr_schedule)
+    else:
+        optimizer = Adam(learning_rate=learning_rate)
+
     if (task == "segmentation" or task == "binarization"):
         if not is_loss_soft_dice and not weighted_loss:
             model.compile(loss='categorical_crossentropy',
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
         if is_loss_soft_dice:
             model.compile(loss=soft_dice_loss,
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
         if weighted_loss:
             model.compile(loss=weighted_categorical_crossentropy(weights),
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
     elif task == "enhancement":
         model.compile(loss='mean_squared_error',
-                      optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
-
+                      optimizer=optimizer, metrics=['accuracy'])
 
     # generating train and evaluation data
     train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
```
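To make the `decay_steps` value above concrete: with the batch size and warmup length from the configuration below and a hypothetical dataset size, the warmup span works out as in this small illustrative calculation (the dataset size is an assumption, not taken from the commit):

```python
# Illustrative arithmetic only: n_images is a hypothetical dataset size,
# n_batch and warmup_epochs match the JSON configuration added below.
n_images = 2000
n_batch = 4
warmup_epochs = 5

steps_per_epoch = int(n_images / n_batch) - 1    # 499, mirroring train.py
warmup_steps = warmup_epochs * steps_per_epoch   # 2495 optimizer steps of warmup
print(steps_per_epoch, warmup_steps)
```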
```diff
@@ -298,13 +379,15 @@ def run(_config, n_classes, n_epochs, input_height,
         ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
         ##score_best=[]
         ##score_best.append(0)
+
         for i in tqdm(range(index_start, n_epochs + index_start)):
             model.fit(
                 train_gen,
-                steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
+                steps_per_epoch=steps_per_epoch,
                 validation_data=val_gen,
                 validation_steps=1,
-                epochs=1)
+                epochs=1,
+                callbacks=callbacks)
             model.save(os.path.join(dir_output,'model_'+str(i)))
 
             with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
```
train_no_patches_448x448.json (new file, 52 lines)
```diff
@@ -0,0 +1,52 @@
+{
+    "backbone_type" : "nontransformer",
+    "task": "segmentation",
+    "n_classes" : 3,
+    "n_epochs" : 50,
+    "input_height" : 448,
+    "input_width" : 448,
+    "weight_decay" : 1e-4,
+    "n_batch" : 4,
+    "learning_rate": 5e-5,
+    "patches" : false,
+    "pretraining" : true,
+    "augmentation" : true,
+    "flip_aug" : false,
+    "blur_aug" : false,
+    "scaling" : true,
+    "degrading": true,
+    "brightening": false,
+    "binarization" : false,
+    "scaling_bluring" : false,
+    "scaling_binarization" : false,
+    "scaling_flip" : false,
+    "rotation": false,
+    "rotation_not_90": false,
+    "blur_k" : ["blur","guass","median"],
+    "scales" : [0.6, 0.7, 0.8, 0.9],
+    "brightness" : [1.3, 1.5, 1.7, 2],
+    "degrade_scales" : [0.2, 0.4],
+    "flip_index" : [0, 1, -1],
+    "thetha" : [10, -10],
+    "continue_training": false,
+    "index_start" : 0,
+    "dir_of_start_model" : " ",
+    "weighted_loss": false,
+    "is_loss_soft_dice": true,
+    "data_is_provided": true,
+    "dir_train": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/train",
+    "dir_eval": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/eval",
+    "dir_output": "runs/sam_41_mss_npt_448x448",
+    "reduce_lr_enabled": true,
+    "reduce_lr_monitor": "val_loss",
+    "reduce_lr_factor": 0.5,
+    "reduce_lr_patience": 4,
+    "reduce_lr_min_lr": 1e-6,
+    "early_stopping_enabled": true,
+    "early_stopping_monitor": "val_loss",
+    "early_stopping_patience": 6,
+    "early_stopping_restore_best_weights": true,
+    "warmup_enabled": true,
+    "warmup_epochs": 5,
+    "warmup_start_lr": 1e-6
+}
```
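For a rough sense of how plateau reductions continue from the warmup target, the configured `reduce_lr_factor` and `reduce_lr_min_lr` imply the following post-warmup learning-rate ladder; a small illustrative calculation using the values above:

```python
# Successive ReduceLROnPlateau steps from the post-warmup learning rate,
# using learning_rate = 5e-5, reduce_lr_factor = 0.5, reduce_lr_min_lr = 1e-6.
lr = 5e-5
ladder = [lr]
while lr > 1e-6:
    lr = max(lr * 0.5, 1e-6)
    ladder.append(lr)
print(ladder)  # [5e-05, 2.5e-05, 1.25e-05, 6.25e-06, 3.125e-06, 1.5625e-06, 1e-06]
```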