diff --git a/.gitkeep b/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 97736e0..6102d31 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -31,6 +31,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 import tensorflow as tf
 from tensorflow.compat.v1.keras.backend import set_session
 from tensorflow.keras.optimizers import SGD, Adam
+from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback, ModelCheckpoint
 from sacred import Experiment
 from tensorflow.keras.models import load_model
 from tqdm import tqdm
@@ -61,6 +62,20 @@ class SaveWeightsAfterSteps(Callback):
             json.dump(self._config, fp)  # encode dict into JSON
         print(f"saved model as steps {self.step_count} to {save_file}")
 
+def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
+    # Ramp the learning rate linearly from start_lr to target_lr over
+    # warmup_epochs * steps_per_epoch steps. PolynomialDecay with power=1.0
+    # is a linear ramp that holds end_learning_rate once decay_steps is
+    # reached (tf.keras.optimizers.schedules has no LinearSchedule).
+    warmup_steps = warmup_epochs * steps_per_epoch
+    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+        initial_learning_rate=start_lr,
+        decay_steps=warmup_steps,
+        end_learning_rate=target_lr,
+        power=1.0
+    )
+    return lr_schedule
+
+
 
 def configuration():
     config = tf.compat.v1.ConfigProto()
@@ -80,7 +95,6 @@ def get_dirs_or_files(input_data):
 
 ex = Experiment(save_git_info=False)
 
-
 @ex.config
 def config_params():
     n_classes = None  # Number of classes. In the case of binary classification this should be 2.
@@ -145,6 +159,19 @@ def config_params():
     number_of_backgrounds_per_image = 1
     dir_rgb_backgrounds = None
     dir_rgb_foregrounds = None
+    reduce_lr_enabled = False  # Whether to use the ReduceLROnPlateau callback
+    reduce_lr_monitor = 'val_loss'  # Metric to monitor for reducing the learning rate
+    reduce_lr_factor = 0.5  # Factor by which the learning rate is reduced
+    reduce_lr_patience = 3  # Number of epochs to wait before reducing the learning rate
+    reduce_lr_min_lr = 1e-6  # Minimum learning rate
+    reduce_lr_min_delta = 1e-5  # Minimum change in the monitored value to count as an improvement
+    early_stopping_enabled = False  # Whether to use the EarlyStopping callback
+    early_stopping_monitor = 'val_loss'  # Metric to monitor for early stopping
+    early_stopping_patience = 10  # Number of epochs to wait before stopping
+    early_stopping_restore_best_weights = True  # Whether to restore the best weights when stopping
+    warmup_enabled = False  # Whether to use learning-rate warmup
+    warmup_epochs = 5  # Number of epochs of warmup
+    warmup_start_lr = 1e-6  # Starting learning rate for the warmup ramp
 
 @ex.automain
 def run(_config, n_classes, n_epochs, input_height,
@@ -159,7 +186,10 @@ def run(_config, n_classes, n_epochs, input_height,
         transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
         transformer_patchsize_x, transformer_patchsize_y, transformer_num_patches_xy,
         backbone_type, save_interval, flip_index, dir_eval, dir_output,
-        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
+        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds,
+        reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr, reduce_lr_min_delta,
+        early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
+        warmup_enabled, warmup_epochs, warmup_start_lr):
 
     if dir_rgb_backgrounds:
         list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@@ -320,20 +350,76 @@ def run(_config, n_classes, n_epochs, input_height,
         #if you want to see the model structure just uncomment model summary.
         model.summary()
 
+        # Assemble the callbacks list
+        callbacks = []
+        if reduce_lr_enabled:
+            reduce_lr = ReduceLROnPlateau(
+                monitor=reduce_lr_monitor,
+                factor=reduce_lr_factor,
+                patience=reduce_lr_patience,
+                min_lr=reduce_lr_min_lr,
+                min_delta=reduce_lr_min_delta,
+                verbose=1
+            )
+            callbacks.append(reduce_lr)
+
+        if early_stopping_enabled:
+            early_stopping = EarlyStopping(
+                monitor=early_stopping_monitor,
+                patience=early_stopping_patience,
+                restore_best_weights=early_stopping_restore_best_weights,
+                verbose=1
+            )
+            callbacks.append(early_stopping)
+
+        # Checkpoint that saves the model every epoch, together with the run config
+        class ModelCheckpointWithConfig(ModelCheckpoint):
+            def __init__(self, *args, **kwargs):
+                self._config = _config
+                super().__init__(*args, **kwargs)
+
+            def on_epoch_end(self, epoch, logs=None):
+                super().on_epoch_end(epoch, logs)
+                # ModelCheckpoint formats "{epoch}" with epoch + 1, so this path
+                # matches the directory the checkpoint has just written
+                model_dir = os.path.join(dir_output, f"model_{epoch+1}")
+                with open(os.path.join(model_dir, "config.json"), "w") as fp:
+                    json.dump(self._config, fp)
+
+        checkpoint_epoch = ModelCheckpointWithConfig(
+            os.path.join(dir_output, "model_{epoch}"),
+            save_freq='epoch',
+            save_weights_only=False,
+            save_best_only=False,
+            verbose=1
+        )
+        callbacks.append(checkpoint_epoch)
+
+        # Calculate steps per epoch (dropping the last, possibly partial, batch)
+        steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1
+
+        # Create the optimizer, with or without learning-rate warmup
+        if warmup_enabled:
+            lr_schedule = get_warmup_schedule(warmup_start_lr, learning_rate,
+                                              warmup_epochs, steps_per_epoch)
+            optimizer = Adam(learning_rate=lr_schedule)
+        else:
+            optimizer = Adam(learning_rate=learning_rate)
+
         if task == "segmentation" or task == "binarization":
             if not is_loss_soft_dice and not weighted_loss:
                 model.compile(loss='categorical_crossentropy',
-                              optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                              optimizer=optimizer, metrics=['accuracy'])
             if is_loss_soft_dice:
                 model.compile(loss=soft_dice_loss,
-                              optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                              optimizer=optimizer, metrics=['accuracy'])
             if weighted_loss:
                 model.compile(loss=weighted_categorical_crossentropy(weights),
-                              optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                              optimizer=optimizer, metrics=['accuracy'])
         elif task == "enhancement":
             model.compile(loss='mean_squared_error',
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
 
         # generating train and evaluation data
@@ -342,39 +443,22 @@ def run(_config, n_classes, n_epochs, input_height,
         val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
                            input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
 
-        ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
-        ##score_best=[]
-        ##score_best.append(0)
+        # Single fit call with all epochs
+        history = model.fit(
+            train_gen,
+            steps_per_epoch=steps_per_epoch,
+            validation_data=val_gen,
+            validation_steps=1,
+            epochs=n_epochs,
+            callbacks=callbacks
+        )
 
-        if save_interval:
-            save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
-
-
-        for i in tqdm(range(index_start, n_epochs + index_start)):
-            if save_interval:
-                model.fit(
-                    train_gen,
-                    steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
-                    validation_data=val_gen,
-                    validation_steps=1,
-                    epochs=1, callbacks=[save_weights_callback])
-            else:
-                model.fit(
-                    train_gen,
-                    steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
-                    validation_data=val_gen,
-                    validation_steps=1,
-                    epochs=1)
-
-            model.save(os.path.join(dir_output,'model_'+str(i)))
+        # Save the best model (either from early stopping or final model)
+        model.save(os.path.join(dir_output, 'model_best'))
 
-            with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
-                json.dump(_config, fp)  # encode dict into JSON
-
-        #os.system('rm -rf '+dir_train_flowing)
-        #os.system('rm -rf '+dir_eval_flowing)
-
-        #model.save(dir_output+'/'+'model'+'.h5')
+        with open(os.path.join(dir_output, 'model_best', "config.json"), "w") as fp:
+            json.dump(_config, fp)  # encode dict into JSON
 
     elif task=='classification':
         configuration()
         model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
diff --git a/train_no_patches_448x448.json b/train_no_patches_448x448.json
new file mode 100644
index 0000000..aaab12b
--- /dev/null
+++ b/train_no_patches_448x448.json
@@ -0,0 +1,53 @@
+{
+    "backbone_type" : "nontransformer",
+    "task": "segmentation",
+    "n_classes" : 3,
+    "n_epochs" : 50,
+    "input_height" : 448,
+    "input_width" : 448,
+    "weight_decay" : 1e-4,
+    "n_batch" : 4,
+    "learning_rate": 2e-5,
+    "patches" : false,
+    "pretraining" : true,
+    "augmentation" : true,
+    "flip_aug" : false,
+    "blur_aug" : false,
+    "scaling" : true,
+    "degrading": true,
+    "brightening": false,
+    "binarization" : false,
+    "scaling_bluring" : false,
+    "scaling_binarization" : false,
+    "scaling_flip" : false,
+    "rotation": false,
+    "rotation_not_90": false,
+    "blur_k" : ["blur","guass","median"],
+    "scales" : [0.6, 0.7, 0.8, 0.9],
+    "brightness" : [1.3, 1.5, 1.7, 2],
+    "degrade_scales" : [0.2, 0.4],
+    "flip_index" : [0, 1, -1],
+    "thetha" : [10, -10],
+    "continue_training": false,
+    "index_start" : 0,
+    "dir_of_start_model" : " ",
+    "weighted_loss": false,
+    "is_loss_soft_dice": true,
+    "data_is_provided": true,
+    "dir_train": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/train",
+    "dir_eval": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/eval",
+    "dir_output": "runs/sam_41_mss_npt_448x448",
+    "reduce_lr_enabled": true,
+    "reduce_lr_monitor": "val_loss",
+    "reduce_lr_factor": 0.2,
+    "reduce_lr_patience": 3,
+    "reduce_lr_min_delta": 1e-5,
+    "reduce_lr_min_lr": 1e-6,
+    "early_stopping_enabled": true,
+    "early_stopping_monitor": "val_loss",
+    "early_stopping_patience": 6,
+    "early_stopping_restore_best_weights": true,
+    "warmup_enabled": true,
+    "warmup_epochs": 5,
+    "warmup_start_lr": 1e-6
+}
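
Sanity check for the warmup schedule (not part of the diff): the snippet below plugs the values from train_no_patches_448x448.json into the same PolynomialDecay construction that get_warmup_schedule uses and prints the learning rate at a few steps. steps_per_epoch = 100 is an assumed placeholder; train.py derives the real value from the training-image directory.

    import tensorflow as tf

    warmup_start_lr = 1e-6   # from train_no_patches_448x448.json
    learning_rate = 2e-5     # from train_no_patches_448x448.json
    warmup_epochs = 5        # from train_no_patches_448x448.json
    steps_per_epoch = 100    # assumed; train.py computes this from dir_train

    schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=warmup_start_lr,
        decay_steps=warmup_epochs * steps_per_epoch,
        end_learning_rate=learning_rate,
        power=1.0,  # linear ramp
    )

    for step in [0, 250, 500, 1000]:
        print(step, float(schedule(step)))
    # 0    -> 1.00e-06  (warmup starts at warmup_start_lr)
    # 250  -> 1.05e-05  (halfway through the 500-step ramp)
    # 500  -> 2.00e-05  (target learning_rate reached)
    # 1000 -> 2.00e-05  (the schedule clamps past decay_steps, so the LR holds)

One caveat: the JSON enables warmup_enabled and reduce_lr_enabled together, but tf.keras's ReduceLROnPlateau reads and writes a scalar optimizer learning rate and, depending on the TF version, can fail when the optimizer was built with a LearningRateSchedule object. If that combination misbehaves, warmup and plateau-based reduction would need to be merged into a single schedule or a custom callback.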
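
To launch a run with the new config, the usual sacred invocation should work (module path assumed from the file layout in this diff):

    python -m eynollah.training.train with train_no_patches_448x448.json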