diff --git a/train.py b/train.py
index 668d6aa..177f408 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@ import tensorflow as tf
 from tensorflow.compat.v1.keras.backend import set_session
 import warnings
 from tensorflow.keras.optimizers import *
-from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback
+from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback, ModelCheckpoint
 from sacred import Experiment
 from models import *
 from utils import *
@@ -30,22 +30,6 @@ def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
     return lr_schedule
 
 
-class WarmupScheduler(Callback):
-    def __init__(self, start_lr, target_lr, warmup_epochs):
-        super(WarmupScheduler, self).__init__()
-        self.start_lr = start_lr
-        self.target_lr = target_lr
-        self.warmup_epochs = warmup_epochs
-        self.current_epoch = 0
-
-    def on_epoch_begin(self, epoch, logs=None):
-        if self.current_epoch < self.warmup_epochs:
-            # Linear warmup
-            lr = self.start_lr + (self.target_lr - self.start_lr) * (self.current_epoch / self.warmup_epochs)
-            tf.keras.backend.set_value(self.model.optimizer.lr, lr)
-        self.current_epoch += 1
-
-
 def configuration():
     config = tf.compat.v1.ConfigProto()
     config.gpu_options.allow_growth = True
@@ -133,6 +117,7 @@ def config_params():
     reduce_lr_factor = 0.5  # Factor to reduce learning rate by
     reduce_lr_patience = 3  # Number of epochs to wait before reducing learning rate
     reduce_lr_min_lr = 1e-6  # Minimum learning rate
+    reduce_lr_min_delta = 1e-5  # Minimum change in monitored value to be considered as improvement
     early_stopping_enabled = False  # Whether to use EarlyStopping callback
     early_stopping_monitor = 'val_loss'  # Metric to monitor for early stopping
     early_stopping_patience = 10  # Number of epochs to wait before stopping
@@ -156,7 +141,7 @@ def run(_config, n_classes, n_epochs, input_height,
         transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
         pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds,
-        reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr,
+        reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr, reduce_lr_min_delta,
         early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
         warmup_enabled, warmup_epochs, warmup_start_lr):
@@ -328,6 +313,7 @@ def run(_config, n_classes, n_epochs, input_height,
                 factor=reduce_lr_factor,
                 patience=reduce_lr_patience,
                 min_lr=reduce_lr_min_lr,
+                min_delta=reduce_lr_min_delta,
                 verbose=1
             )
             callbacks.append(reduce_lr)
@@ -341,6 +327,27 @@ def run(_config, n_classes, n_epochs, input_height,
             )
             callbacks.append(early_stopping)
 
+        # Add checkpoint to save models every epoch
+        class ModelCheckpointWithConfig(ModelCheckpoint):
+            def __init__(self, *args, **kwargs):
+                self._config = _config
+                super().__init__(*args, **kwargs)
+
+            def on_epoch_end(self, epoch, logs=None):
+                super().on_epoch_end(epoch, logs)
+                model_dir = os.path.join(dir_output, f"model_{epoch+1}")
+                with open(os.path.join(model_dir, "config.json"), "w") as fp:
+                    json.dump(self._config, fp)
+
+        checkpoint_epoch = ModelCheckpointWithConfig(
+            os.path.join(dir_output, "model_{epoch}"),
+            save_freq='epoch',
+            save_weights_only=False,
+            save_best_only=False,
+            verbose=1
+        )
+        callbacks.append(checkpoint_epoch)
+
         # Calculate steps per epoch
         steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1
@@ -376,27 +383,22 @@ def run(_config, n_classes, n_epochs, input_height,
             val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch, input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
 
-            ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
-            ##score_best=[]
-            ##score_best.append(0)
-
-            for i in tqdm(range(index_start, n_epochs + index_start)):
-                model.fit(
-                    train_gen,
-                    steps_per_epoch=steps_per_epoch,
-                    validation_data=val_gen,
-                    validation_steps=1,
-                    epochs=1,
-                    callbacks=callbacks)
-                model.save(os.path.join(dir_output,'model_'+str(i)))
+            # Single fit call with all epochs
+            history = model.fit(
+                train_gen,
+                steps_per_epoch=steps_per_epoch,
+                validation_data=val_gen,
+                validation_steps=1,
+                epochs=n_epochs,
+                callbacks=callbacks
+            )
 
-                with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
-                    json.dump(_config, fp)  # encode dict into JSON
-
-            #os.system('rm -rf '+dir_train_flowing)
-            #os.system('rm -rf '+dir_eval_flowing)
-
-            #model.save(dir_output+'/'+'model'+'.h5')
+            # Save the best model (either from early stopping or final model)
+            model.save(os.path.join(dir_output, 'model_best'))
+
+            with open(os.path.join(dir_output, 'model_best', "config.json"), "w") as fp:
+                json.dump(_config, fp)  # encode dict into JSON
+
     elif task=='classification':
         configuration()
         model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
diff --git a/train_no_patches_448x448.json b/train_no_patches_448x448.json
index c3d0e10..aaab12b 100644
--- a/train_no_patches_448x448.json
+++ b/train_no_patches_448x448.json
@@ -7,7 +7,7 @@
     "input_width" : 448,
     "weight_decay" : 1e-4,
     "n_batch" : 4,
-    "learning_rate": 5e-5,
+    "learning_rate": 2e-5,
     "patches" : false,
     "pretraining" : true,
     "augmentation" : true,
@@ -39,8 +39,9 @@
     "dir_output": "runs/sam_41_mss_npt_448x448",
     "reduce_lr_enabled": true,
     "reduce_lr_monitor": "val_loss",
-    "reduce_lr_factor": 0.5,
-    "reduce_lr_patience": 4,
+    "reduce_lr_factor": 0.2,
+    "reduce_lr_patience": 3,
+    "reduce_lr_min_delta": 1e-5,
     "reduce_lr_min_lr": 1e-6,
     "early_stopping_enabled": true,
     "early_stopping_monitor": "val_loss",
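
Note for reviewers: Keras formats the ModelCheckpoint filepath with epoch + 1, so the "model_{epoch}" pattern above yields model_1, model_2, ..., which is why the subclass writes config.json into f"model_{epoch+1}" even though on_epoch_end receives the zero-based epoch index. Below is a minimal, self-contained sketch of the same pattern, assuming TF 2.x with the legacy Keras that the script's tf.compat.v1 usage implies; the toy model, data, config dict, and /tmp path are stand-ins, not part of this change.

    # Sketch of the checkpoint-with-config pattern from the diff above.
    # Assumptions: TF 2.x / legacy Keras; "out" stands in for dir_output,
    # and the dict passed below stands in for the sacred _config.
    import os
    import json
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.callbacks import ModelCheckpoint

    out = "/tmp/ckpt_demo"  # hypothetical output directory
    os.makedirs(out, exist_ok=True)

    class ModelCheckpointWithConfig(ModelCheckpoint):
        """Save the model each epoch, then drop the run config next to it."""
        def __init__(self, config, *args, **kwargs):
            self._config = config
            super().__init__(*args, **kwargs)

        def on_epoch_end(self, epoch, logs=None):
            super().on_epoch_end(epoch, logs)
            # Keras saved to "model_{epoch}".format(epoch=epoch + 1),
            # so the matching directory is model_{epoch+1}.
            model_dir = os.path.join(out, f"model_{epoch + 1}")
            with open(os.path.join(model_dir, "config.json"), "w") as fp:
                json.dump(self._config, fp)

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer="adam", loss="mse")

    ckpt = ModelCheckpointWithConfig(
        {"learning_rate": 2e-5},             # stand-in for _config
        os.path.join(out, "model_{epoch}"),  # "{epoch}" -> 1, 2, ... at save time
        save_freq="epoch",
        save_weights_only=False,
        save_best_only=False,
        verbose=1,
    )

    x, y = np.random.rand(8, 4), np.random.rand(8, 1)
    model.fit(x, y, epochs=2, callbacks=[ckpt], verbose=0)
    # Afterwards, /tmp/ckpt_demo/model_1/config.json and model_2/config.json exist.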