Mirror of https://github.com/qurator-spk/eynollah.git (synced 2025-10-27 07:44:12 +01:00)

commit 54132a499a

    Merge remote-tracking branch 'pixelwise_local/ReduceLROnPlateau' into ReduceLROnPlateau

    # Conflicts:
    #	LICENSE
    #	README.md
    #	requirements.txt
    #	train.py

4 changed files with 174 additions and 37 deletions
.gitkeep (0 changes)
__init__.py (0 changes)

train.py (156 changes)
@@ -31,6 +31,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 import tensorflow as tf
 from tensorflow.compat.v1.keras.backend import set_session
 from tensorflow.keras.optimizers import SGD, Adam
+from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback, ModelCheckpoint
 from sacred import Experiment
 from tensorflow.keras.models import load_model
 from tqdm import tqdm
@@ -61,6 +62,20 @@ class SaveWeightsAfterSteps(Callback):
             json.dump(self._config, fp)  # encode dict into JSON
         print(f"saved model as steps {self.step_count} to {save_file}")
 
+
+def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
+    initial_learning_rate = start_lr
+    target_learning_rate = target_lr
+    warmup_steps = warmup_epochs * steps_per_epoch
+
+    lr_schedule = tf.keras.optimizers.schedules.LinearSchedule(
+        initial_learning_rate=initial_learning_rate,
+        final_learning_rate=target_learning_rate,
+        total_steps=warmup_steps
+    )
+
+    return lr_schedule
+
+
 def configuration():
     config = tf.compat.v1.ConfigProto()
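The helper above expresses the warmup as a schedule object; the run() function further down in this diff builds the same linear ramp directly with PolynomialDecay(power=1.0). A minimal, illustrative sketch of that equivalent schedule (the values are placeholders, not taken from the commit):

    # Illustrative only (not part of the commit): a linear warmup expressed with
    # PolynomialDecay, the schedule run() below actually uses.
    import tensorflow as tf

    def linear_warmup(start_lr=1e-6, target_lr=2e-5, warmup_epochs=5, steps_per_epoch=100):
        # power=1.0 makes PolynomialDecay interpolate linearly from start_lr to
        # target_lr over warmup_epochs * steps_per_epoch optimizer steps.
        return tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=start_lr,
            decay_steps=warmup_epochs * steps_per_epoch,
            end_learning_rate=target_lr,
            power=1.0
        )

    optimizer = tf.keras.optimizers.Adam(learning_rate=linear_warmup())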
@@ -80,7 +95,6 @@ def get_dirs_or_files(input_data):
 
 ex = Experiment(save_git_info=False)
 
-
 @ex.config
 def config_params():
     n_classes = None  # Number of classes. In the case of binary classification this should be 2.
@@ -145,6 +159,19 @@ def config_params():
     number_of_backgrounds_per_image = 1
     dir_rgb_backgrounds = None
     dir_rgb_foregrounds = None
+    reduce_lr_enabled = False  # Whether to use ReduceLROnPlateau callback
+    reduce_lr_monitor = 'val_loss'  # Metric to monitor for reducing learning rate
+    reduce_lr_factor = 0.5  # Factor to reduce learning rate by
+    reduce_lr_patience = 3  # Number of epochs to wait before reducing learning rate
+    reduce_lr_min_lr = 1e-6  # Minimum learning rate
+    reduce_lr_min_delta = 1e-5  # Minimum change in monitored value to be considered as improvement
+    early_stopping_enabled = False  # Whether to use EarlyStopping callback
+    early_stopping_monitor = 'val_loss'  # Metric to monitor for early stopping
+    early_stopping_patience = 10  # Number of epochs to wait before stopping
+    early_stopping_restore_best_weights = True  # Whether to restore best weights when stopping
+    warmup_enabled = False  # Whether to use learning rate warmup
+    warmup_epochs = 5  # Number of epochs for warmup
+    warmup_start_lr = 1e-6  # Starting learning rate for warmup
 
 @ex.automain
 def run(_config, n_classes, n_epochs, input_height,
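These defaults are ordinary sacred configuration entries: sacred injects them into run() as the extra keyword parameters shown in the next hunk, and each one can be overridden per run, from a JSON file or on the command line, without editing train.py. A hypothetical programmatic sketch (names and paths are assumptions, not from the commit):

    # Illustrative only: assumes train.py is importable from the working directory
    # and that the dataset paths referenced by the config actually exist.
    from train import ex

    # Load the JSON config added later in this commit, then apply an ad-hoc update.
    ex.add_config("train_no_patches_448x448.json")
    ex.run(config_updates={"early_stopping_patience": 10})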
@@ -159,7 +186,10 @@ def run(_config, n_classes, n_epochs, input_height,
         transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
         transformer_patchsize_x, transformer_patchsize_y,
         transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
-        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
+        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds,
+        reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr, reduce_lr_min_delta,
+        early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
+        warmup_enabled, warmup_epochs, warmup_start_lr):
 
     if dir_rgb_backgrounds:
         list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@@ -320,20 +350,91 @@ def run(_config, n_classes, n_epochs, input_height,
         #if you want to see the model structure just uncomment model summary.
         model.summary()
 
+        # Create callbacks list
+        callbacks = []
+        if reduce_lr_enabled:
+            reduce_lr = ReduceLROnPlateau(
+                monitor=reduce_lr_monitor,
+                factor=reduce_lr_factor,
+                patience=reduce_lr_patience,
+                min_lr=reduce_lr_min_lr,
+                min_delta=reduce_lr_min_delta,
+                verbose=1
+            )
+            callbacks.append(reduce_lr)
+
+        if early_stopping_enabled:
+            early_stopping = EarlyStopping(
+                monitor=early_stopping_monitor,
+                patience=early_stopping_patience,
+                restore_best_weights=early_stopping_restore_best_weights,
+                verbose=1
+            )
+            callbacks.append(early_stopping)
+
+        # Add checkpoint to save models every epoch
+        class ModelCheckpointWithConfig(ModelCheckpoint):
+            def __init__(self, *args, **kwargs):
+                self._config = _config
+                super().__init__(*args, **kwargs)
+
+            def on_epoch_end(self, epoch, logs=None):
+                super().on_epoch_end(epoch, logs)
+                model_dir = os.path.join(dir_output, f"model_{epoch+1}")
+                with open(os.path.join(model_dir, "config.json"), "w") as fp:
+                    json.dump(self._config, fp)
+
+        checkpoint_epoch = ModelCheckpointWithConfig(
+            os.path.join(dir_output, "model_{epoch}"),
+            save_freq='epoch',
+            save_weights_only=False,
+            save_best_only=False,
+            verbose=1
+        )
+        callbacks.append(checkpoint_epoch)
+
+        # Calculate steps per epoch
+        steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1
+
+        # Create optimizer with or without warmup
+        if warmup_enabled:
+            lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+                initial_learning_rate=warmup_start_lr,
+                decay_steps=warmup_epochs * steps_per_epoch,
+                end_learning_rate=learning_rate,
+                power=1.0  # Linear decay
+            )
+            optimizer = Adam(learning_rate=lr_schedule)
+        else:
+            optimizer = Adam(learning_rate=learning_rate)
+
+        if (task == "segmentation" or task == "binarization"):
+            if not is_loss_soft_dice and not weighted_loss:
+                model.compile(loss='categorical_crossentropy',
+                              optimizer=optimizer, metrics=['accuracy'])
+            if is_loss_soft_dice:
+                model.compile(loss=soft_dice_loss,
+                              optimizer=optimizer, metrics=['accuracy'])
+            if weighted_loss:
+                model.compile(loss=weighted_categorical_crossentropy(weights),
+                              optimizer=optimizer, metrics=['accuracy'])
+        elif task == "enhancement":
+            model.compile(loss='mean_squared_error',
+                          optimizer=optimizer, metrics=['accuracy'])
+
         if task == "segmentation" or task == "binarization":
             if not is_loss_soft_dice and not weighted_loss:
                 model.compile(loss='categorical_crossentropy',
-                              optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                              optimizer=optimizer, metrics=['accuracy'])
             if is_loss_soft_dice:
                 model.compile(loss=soft_dice_loss,
-                              optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                              optimizer=optimizer, metrics=['accuracy'])
             if weighted_loss:
                 model.compile(loss=weighted_categorical_crossentropy(weights),
-                              optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                              optimizer=optimizer, metrics=['accuracy'])
         elif task == "enhancement":
             model.compile(loss='mean_squared_error',
-                          optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
+                          optimizer=optimizer, metrics=['accuracy'])
 
 
         # generating train and evaluation data
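For readers unfamiliar with these Keras callbacks, the following self-contained toy run (illustrative only, not part of the commit; the model and data are dummies and the hyperparameters are placeholders in the spirit of the defaults above) shows ReduceLROnPlateau, EarlyStopping, and an every-epoch ModelCheckpoint working together the way the block above wires them into training:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint

    # Dummy regression model standing in for the real segmentation network.
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(8,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(1),
    ])
    model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(1e-3))

    callbacks = [
        # Halve the LR after 3 epochs without val_loss improvement, down to 1e-6.
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6, verbose=1),
        # Stop after 6 stale epochs and roll back to the best weights seen so far.
        EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True, verbose=1),
        # Write weights every epoch, mirroring the per-epoch checkpointing above.
        ModelCheckpoint('toy_{epoch}.weights.h5', save_freq='epoch',
                        save_weights_only=True, verbose=1),
    ]

    x = np.random.rand(256, 8).astype('float32')
    y = np.random.rand(256, 1).astype('float32')
    model.fit(x, y, validation_split=0.25, epochs=30, callbacks=callbacks)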
@@ -342,39 +443,22 @@ def run(_config, n_classes, n_epochs, input_height,
         val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
                            input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
 
-        ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
-        ##score_best=[]
-        ##score_best.append(0)
+        # Single fit call with all epochs
+        history = model.fit(
+            train_gen,
+            steps_per_epoch=steps_per_epoch,
+            validation_data=val_gen,
+            validation_steps=1,
+            epochs=n_epochs,
+            callbacks=callbacks
+        )
 
-        if save_interval:
-            save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
+        # Save the best model (either from early stopping or final model)
+        model.save(os.path.join(dir_output, 'model_best'))
 
-        for i in tqdm(range(index_start, n_epochs + index_start)):
-            if save_interval:
-                model.fit(
-                    train_gen,
-                    steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
-                    validation_data=val_gen,
-                    validation_steps=1,
-                    epochs=1, callbacks=[save_weights_callback])
-            else:
-                model.fit(
-                    train_gen,
-                    steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
-                    validation_data=val_gen,
-                    validation_steps=1,
-                    epochs=1)
-
-            model.save(os.path.join(dir_output,'model_'+str(i)))
-
-            with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
-                json.dump(_config, fp)  # encode dict into JSON
-
-        #os.system('rm -rf '+dir_train_flowing)
-        #os.system('rm -rf '+dir_eval_flowing)
-
-        #model.save(dir_output+'/'+'model'+'.h5')
+        with open(os.path.join(dir_output, 'model_best', "config.json"), "w") as fp:
+            json.dump(_config, fp)  # encode dict into JSON
+
     elif task=='classification':
         configuration()
         model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
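After training, the new code leaves a 'model_best' directory plus a config.json beside it in dir_output. A minimal sketch (not part of the commit) of reloading both for inference, assuming the TensorFlow SavedModel format that model.save() writes for a path without an extension, and using the dir_output value from the config file below as an example path:

    import json
    import os
    from tensorflow.keras.models import load_model

    model_dir = os.path.join("runs/sam_41_mss_npt_448x448", "model_best")
    # compile=False sidesteps the custom training losses (e.g. soft_dice_loss).
    model = load_model(model_dir, compile=False)

    with open(os.path.join(model_dir, "config.json")) as fp:
        config = json.load(fp)
    print(config["n_classes"], config["input_height"], config["input_width"])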
train_no_patches_448x448.json (new file, 53 lines)

@@ -0,0 +1,53 @@
+{
+    "backbone_type" : "nontransformer",
+    "task": "segmentation",
+    "n_classes" : 3,
+    "n_epochs" : 50,
+    "input_height" : 448,
+    "input_width" : 448,
+    "weight_decay" : 1e-4,
+    "n_batch" : 4,
+    "learning_rate": 2e-5,
+    "patches" : false,
+    "pretraining" : true,
+    "augmentation" : true,
+    "flip_aug" : false,
+    "blur_aug" : false,
+    "scaling" : true,
+    "degrading": true,
+    "brightening": false,
+    "binarization" : false,
+    "scaling_bluring" : false,
+    "scaling_binarization" : false,
+    "scaling_flip" : false,
+    "rotation": false,
+    "rotation_not_90": false,
+    "blur_k" : ["blur","guass","median"],
+    "scales" : [0.6, 0.7, 0.8, 0.9],
+    "brightness" : [1.3, 1.5, 1.7, 2],
+    "degrade_scales" : [0.2, 0.4],
+    "flip_index" : [0, 1, -1],
+    "thetha" : [10, -10],
+    "continue_training": false,
+    "index_start" : 0,
+    "dir_of_start_model" : " ",
+    "weighted_loss": false,
+    "is_loss_soft_dice": true,
+    "data_is_provided": true,
+    "dir_train": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/train",
+    "dir_eval": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/eval",
+    "dir_output": "runs/sam_41_mss_npt_448x448",
+    "reduce_lr_enabled": true,
+    "reduce_lr_monitor": "val_loss",
+    "reduce_lr_factor": 0.2,
+    "reduce_lr_patience": 3,
+    "reduce_lr_min_delta": 1e-5,
+    "reduce_lr_min_lr": 1e-6,
+    "early_stopping_enabled": true,
+    "early_stopping_monitor": "val_loss",
+    "early_stopping_patience": 6,
+    "early_stopping_restore_best_weights": true,
+    "warmup_enabled": true,
+    "warmup_epochs": 5,
+    "warmup_start_lr": 1e-6
+}
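This configuration enables all three new mechanisms at once: ReduceLROnPlateau with a more aggressive factor of 0.2 (the code default above is 0.5), early stopping after 6 stale epochs with best-weight restoration, and a 5-epoch linear warmup from 1e-6 up to the base learning rate of 2e-5. Assuming the standard sacred command line exposed by @ex.automain, the file would be supplied as a config update, e.g. "python train.py with train_no_patches_448x448.json", with individual keys still overridable by appending key=value pairs to that command.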