Fix `ReduceLROnPlateau` wrong logic

# Training Script Improvements

## Learning Rate Management Fixes

### 1. ReduceLROnPlateau Implementation
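- Fixed the learning rate reduction mechanism by replacing the manual epoch loop (one `model.fit(epochs=1)` call per epoch) with a single `model.fit()` call covering all epochs
- Each separate `fit()` call restarted the callbacks, so the plateau counter and best monitored value were reset every epoch and `patience` could never be reached; a single call lets the callbacks track validation metrics across epochs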
- Configured with:
  ```python
  reduce_lr = ReduceLROnPlateau(
      monitor='val_loss',
      factor=0.2,        # More aggressive reduction
      patience=3,        # Quick response to plateaus
      min_lr=1e-6,       # Minimum learning rate
      min_delta=1e-5,    # Minimum change to be considered improvement
      verbose=1
  )
  ```

### 2. Warmup Implementation
- Added learning rate warmup using TensorFlow's native scheduling
- Gradually increases learning rate from 1e-6 to target (2e-5) over 5 epochs
- Helps stabilize initial training phase
- Implemented using `PolynomialDecay` schedule:
  ```python
  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=warmup_start_lr,
      decay_steps=warmup_epochs * steps_per_epoch,
      end_learning_rate=learning_rate,
      power=1.0  # Linear ramp (increases, because end_learning_rate > initial_learning_rate)
  )
  ```

### 3. Early Stopping
- Added early stopping to prevent overfitting
- Configured with:
  ```python
  early_stopping = EarlyStopping(
      monitor='val_loss',
      patience=6,
      restore_best_weights=True,
      verbose=1
  )
  ```

## Model Saving Improvements

### 1. Epoch-based Model Saving
- Implemented custom `ModelCheckpointWithConfig` to save both model and config
- Saves after each epoch with corresponding config.json
- Maintains compatibility with original script's saving behavior

### 2. Best Model Saving
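- Saves a `model_best` directory (with its own `config.json`) once training ends
- If early stopping triggers: saves the best model from training (this relies on `early_stopping_restore_best_weights` being `true`, so the best weights are restored before saving)
- If no early stopping: saves the final model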

## Configuration
All parameters are configurable through the JSON config file:
```json
{
    "reduce_lr_enabled": true,
    "reduce_lr_monitor": "val_loss",
    "reduce_lr_factor": 0.2,
    "reduce_lr_patience": 3,
    "reduce_lr_min_lr": 1e-6,
    "reduce_lr_min_delta": 1e-5,
    "early_stopping_enabled": true,
    "early_stopping_monitor": "val_loss",
    "early_stopping_patience": 6,
    "early_stopping_restore_best_weights": true,
    "warmup_enabled": true,
    "warmup_epochs": 5,
    "warmup_start_lr": 1e-6
}
```

## Benefits
1. More stable training with proper learning rate management
2. Better handling of training plateaus
3. Automatic saving of best model
4. Maintained compatibility with existing config saving
5. Improved training monitoring and control
pull/25/head
johnlockejrr 7 days ago committed by GitHub
parent 7661080899
commit f298643fcf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -5,7 +5,7 @@ import tensorflow as tf
from tensorflow.compat.v1.keras.backend import set_session
import warnings
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback, ModelCheckpoint
from sacred import Experiment
from models import *
from utils import *
@ -30,22 +30,6 @@ def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
return lr_schedule
class WarmupScheduler(Callback):
def __init__(self, start_lr, target_lr, warmup_epochs):
super(WarmupScheduler, self).__init__()
self.start_lr = start_lr
self.target_lr = target_lr
self.warmup_epochs = warmup_epochs
self.current_epoch = 0
def on_epoch_begin(self, epoch, logs=None):
if self.current_epoch < self.warmup_epochs:
# Linear warmup
lr = self.start_lr + (self.target_lr - self.start_lr) * (self.current_epoch / self.warmup_epochs)
tf.keras.backend.set_value(self.model.optimizer.lr, lr)
self.current_epoch += 1
def configuration():
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
@ -133,6 +117,7 @@ def config_params():
reduce_lr_factor = 0.5 # Factor to reduce learning rate by
reduce_lr_patience = 3 # Number of epochs to wait before reducing learning rate
reduce_lr_min_lr = 1e-6 # Minimum learning rate
reduce_lr_min_delta = 1e-5 # Minimum change in monitored value to be considered as improvement
early_stopping_enabled = False # Whether to use EarlyStopping callback
early_stopping_monitor = 'val_loss' # Metric to monitor for early stopping
early_stopping_patience = 10 # Number of epochs to wait before stopping
@ -156,7 +141,7 @@ def run(_config, n_classes, n_epochs, input_height,
transformer_patchsize_x, transformer_patchsize_y,
transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds,
reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr,
reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr, reduce_lr_min_delta,
early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
warmup_enabled, warmup_epochs, warmup_start_lr):
@ -328,6 +313,7 @@ def run(_config, n_classes, n_epochs, input_height,
factor=reduce_lr_factor,
patience=reduce_lr_patience,
min_lr=reduce_lr_min_lr,
min_delta=reduce_lr_min_delta,
verbose=1
)
callbacks.append(reduce_lr)
@ -341,6 +327,27 @@ def run(_config, n_classes, n_epochs, input_height,
)
callbacks.append(early_stopping)
# Add checkpoint to save models every epoch
class ModelCheckpointWithConfig(ModelCheckpoint):
def __init__(self, *args, **kwargs):
self._config = _config
super().__init__(*args, **kwargs)
def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs)
model_dir = os.path.join(dir_output, f"model_{epoch+1}")
with open(os.path.join(model_dir, "config.json"), "w") as fp:
json.dump(self._config, fp)
checkpoint_epoch = ModelCheckpointWithConfig(
os.path.join(dir_output, "model_{epoch}"),
save_freq='epoch',
save_weights_only=False,
save_best_only=False,
verbose=1
)
callbacks.append(checkpoint_epoch)
# Calculate steps per epoch
steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1
@ -376,27 +383,22 @@ def run(_config, n_classes, n_epochs, input_height,
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
##score_best=[]
##score_best.append(0)
# Single fit call with all epochs
history = model.fit(
train_gen,
steps_per_epoch=steps_per_epoch,
validation_data=val_gen,
validation_steps=1,
epochs=n_epochs,
callbacks=callbacks
)
for i in tqdm(range(index_start, n_epochs + index_start)):
model.fit(
train_gen,
steps_per_epoch=steps_per_epoch,
validation_data=val_gen,
validation_steps=1,
epochs=1,
callbacks=callbacks)
model.save(os.path.join(dir_output,'model_'+str(i)))
# Save the best model (either from early stopping or final model)
model.save(os.path.join(dir_output, 'model_best'))
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
#os.system('rm -rf '+dir_train_flowing)
#os.system('rm -rf '+dir_eval_flowing)
with open(os.path.join(dir_output, 'model_best', "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
#model.save(dir_output+'/'+'model'+'.h5')
elif task=='classification':
configuration()
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)

@ -7,7 +7,7 @@
"input_width" : 448,
"weight_decay" : 1e-4,
"n_batch" : 4,
"learning_rate": 5e-5,
"learning_rate": 2e-5,
"patches" : false,
"pretraining" : true,
"augmentation" : true,
@ -39,8 +39,9 @@
"dir_output": "runs/sam_41_mss_npt_448x448",
"reduce_lr_enabled": true,
"reduce_lr_monitor": "val_loss",
"reduce_lr_factor": 0.5,
"reduce_lr_patience": 4,
"reduce_lr_factor": 0.2,
"reduce_lr_patience": 3,
"reduce_lr_min_delta": 1e-5,
"reduce_lr_min_lr": 1e-6,
"early_stopping_enabled": true,
"early_stopping_monitor": "val_loss",

Loading…
Cancel
Save