diff --git a/train.py b/train.py
index 4cc3cbb..668d6aa 100644
--- a/train.py
+++ b/train.py
@@ -5,6 +5,7 @@ import tensorflow as tf
 from tensorflow.compat.v1.keras.backend import set_session
 import warnings
 from tensorflow.keras.optimizers import *
+from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback
 from sacred import Experiment
 from models import *
 from utils import *
@@ -15,6 +16,36 @@ import json
 from sklearn.metrics import f1_score
 
 
+def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
+    # Linear ramp from start_lr to target_lr over warmup_epochs * steps_per_epoch steps.
+    warmup_steps = warmup_epochs * steps_per_epoch
+
+    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+        initial_learning_rate=start_lr,
+        end_learning_rate=target_lr,
+        decay_steps=warmup_steps,
+        power=1.0
+    )
+
+    return lr_schedule
+
+
+class WarmupScheduler(Callback):
+    def __init__(self, start_lr, target_lr, warmup_epochs):
+        super(WarmupScheduler, self).__init__()
+        self.start_lr = start_lr
+        self.target_lr = target_lr
+        self.warmup_epochs = warmup_epochs
+        self.current_epoch = 0
+
+    def on_epoch_begin(self, epoch, logs=None):
+        if self.current_epoch < self.warmup_epochs:
+            # Linear warmup towards the target learning rate
+            lr = self.start_lr + (self.target_lr - self.start_lr) * (self.current_epoch / self.warmup_epochs)
+            tf.keras.backend.set_value(self.model.optimizer.lr, lr)
+        self.current_epoch += 1
+
+
 def configuration():
     config = tf.compat.v1.ConfigProto()
     config.gpu_options.allow_growth = True
@@ -97,6 +128,18 @@ def config_params():
     number_of_backgrounds_per_image = 1
     dir_rgb_backgrounds = None
     dir_rgb_foregrounds = None
+    reduce_lr_enabled = False  # Whether to use ReduceLROnPlateau callback
+    reduce_lr_monitor = 'val_loss'  # Metric to monitor for reducing learning rate
+    reduce_lr_factor = 0.5  # Factor to reduce learning rate by
+    reduce_lr_patience = 3  # Number of epochs to wait before reducing learning rate
+    reduce_lr_min_lr = 1e-6  # Minimum learning rate
+    early_stopping_enabled = False  # Whether to use EarlyStopping callback
+    early_stopping_monitor = 'val_loss'  # Metric to monitor for early stopping
+    early_stopping_patience = 10  # Number of epochs to wait before stopping
+    early_stopping_restore_best_weights = True  # Whether to restore best weights when stopping
+    warmup_enabled = False  # Whether to use learning rate warmup
+    warmup_epochs = 5  # Number of epochs for warmup
+    warmup_start_lr = 1e-6  # Starting learning rate for warmup
 
 
 @ex.automain
@@ -112,7 +155,10 @@ def run(_config, n_classes, n_epochs, input_height,
         transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
         transformer_patchsize_x, transformer_patchsize_y,
         transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
-        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
+        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds,
+        reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr,
+        early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
+        warmup_enabled, warmup_epochs, warmup_start_lr):
 
     if dir_rgb_backgrounds:
         list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@@ -274,20 +320,55 @@ def run(_config, n_classes, n_epochs, input_height,
     model.summary()
 
+    # Create callbacks list
+    callbacks = []
+    if reduce_lr_enabled:
+        reduce_lr = ReduceLROnPlateau(
+            monitor=reduce_lr_monitor,
+            factor=reduce_lr_factor,
+            patience=reduce_lr_patience,
+            min_lr=reduce_lr_min_lr,
+            verbose=1
+        )
+        callbacks.append(reduce_lr)
+
+    if early_stopping_enabled:
+        early_stopping = EarlyStopping(
+            monitor=early_stopping_monitor,
+            patience=early_stopping_patience,
+            restore_best_weights=early_stopping_restore_best_weights,
+            verbose=1
+        )
+        callbacks.append(early_stopping)
+
+    # Calculate steps per epoch
+    steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1
+
+    # Create optimizer with or without warmup
+    if warmup_enabled:
+        lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+            initial_learning_rate=warmup_start_lr,
+            decay_steps=warmup_epochs * steps_per_epoch,
+            end_learning_rate=learning_rate,
+            power=1.0  # Linear decay
+        )
+        optimizer = Adam(learning_rate=lr_schedule)
+    else:
+        optimizer = Adam(learning_rate=learning_rate)
+
     if (task == "segmentation" or task == "binarization"):
         if not is_loss_soft_dice and not weighted_loss:
             model.compile(loss='categorical_crossentropy',
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
         if is_loss_soft_dice:
             model.compile(loss=soft_dice_loss,
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
         if weighted_loss:
             model.compile(loss=weighted_categorical_crossentropy(weights),
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
     elif task == "enhancement":
         model.compile(loss='mean_squared_error',
-                      optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
-    
+                      optimizer=optimizer, metrics=['accuracy'])
 
     # generating train and evaluation data
     train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
@@ -298,13 +379,15 @@ def run(_config, n_classes, n_epochs, input_height,
     ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
     ##score_best=[]
     ##score_best.append(0)
+
     for i in tqdm(range(index_start, n_epochs + index_start)):
         model.fit(
             train_gen,
-            steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
+            steps_per_epoch=steps_per_epoch,
             validation_data=val_gen,
             validation_steps=1,
-            epochs=1)
+            epochs=1,
+            callbacks=callbacks)
         model.save(os.path.join(dir_output,'model_'+str(i)))
 
         with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
diff --git a/train_no_patches_448x448.json b/train_no_patches_448x448.json
new file mode 100644
index 0000000..c3d0e10
--- /dev/null
+++ b/train_no_patches_448x448.json
@@ -0,0 +1,52 @@
+{
+    "backbone_type" : "nontransformer",
+    "task": "segmentation",
+    "n_classes" : 3,
+    "n_epochs" : 50,
+    "input_height" : 448,
+    "input_width" : 448,
+    "weight_decay" : 1e-4,
+    "n_batch" : 4,
+    "learning_rate": 5e-5,
+    "patches" : false,
+    "pretraining" : true,
+    "augmentation" : true,
+    "flip_aug" : false,
+    "blur_aug" : false,
+    "scaling" : true,
+    "degrading": true,
+    "brightening": false,
+    "binarization" : false,
+    "scaling_bluring" : false,
+    "scaling_binarization" : false,
+    "scaling_flip" : false,
+    "rotation": false,
+    "rotation_not_90": false,
+    "blur_k" : ["blur","guass","median"],
+    "scales" : [0.6, 0.7, 0.8, 0.9],
+    "brightness" : [1.3, 1.5, 1.7, 2],
+    "degrade_scales" : [0.2, 0.4],
+    "flip_index" : [0, 1, -1],
+    "thetha" : [10, -10],
"continue_training": false, + "index_start" : 0, + "dir_of_start_model" : " ", + "weighted_loss": false, + "is_loss_soft_dice": true, + "data_is_provided": true, + "dir_train": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/train", + "dir_eval": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/eval", + "dir_output": "runs/sam_41_mss_npt_448x448", + "reduce_lr_enabled": true, + "reduce_lr_monitor": "val_loss", + "reduce_lr_factor": 0.5, + "reduce_lr_patience": 4, + "reduce_lr_min_lr": 1e-6, + "early_stopping_enabled": true, + "early_stopping_monitor": "val_loss", + "early_stopping_patience": 6, + "early_stopping_restore_best_weights": true, + "warmup_enabled": true, + "warmup_epochs": 5, + "warmup_start_lr": 1e-6 +}