training: reinstate index_start, reflect config dependency on continue_training

- `index_start`: re-introduce cfg key, pass to Keras `Model.fit`
  as `initial_epoch`
- make config keys `index_start` and `dir_of_start_model` dependent
  on `continue_training`
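- improve the descriptions of the config keys `save_interval`,
  `continue_training`, `dir_of_start_model` and `data_is_provided`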
This commit is contained in:
Robert Sachunsky 2026-02-04 17:32:24 +01:00
parent 25153ad307
commit e85003db4a

View file

@ -157,10 +157,12 @@ def config_params():
dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels".
dir_output = None # Directory where the augmented training data and the model checkpoints will be saved. dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder. pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder.
save_interval = None # frequency for writing model checkpoints (nonzero integer for number of batches, or zero for epoch) save_interval = None # frequency for writing model checkpoints (positive integer for number of batches saved under "model_step_{batch:04d}", otherwise epoch saved under "model_{epoch:02d}")
continue_training = False # Set to true if you would like to continue training an already trained a model. continue_training = False # Whether to continue training an existing model.
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. if continue_training:
data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output". dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
if backbone_type == "transformer": if backbone_type == "transformer":
transformer_patchsize_x = None # Patch size of vision transformer patches in x direction. transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
transformer_patchsize_y = None # Patch size of vision transformer patches in y direction. transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
@ -190,6 +192,7 @@ def run(_config,
weight_decay, weight_decay,
learning_rate, learning_rate,
continue_training, continue_training,
index_start,
dir_of_start_model, dir_of_start_model,
save_interval, save_interval,
augmentation, augmentation,
@ -312,6 +315,7 @@ def run(_config,
custom_objects = {"PatchEncoder": PatchEncoder, custom_objects = {"PatchEncoder": PatchEncoder,
"Patches": Patches}) "Patches": Patches})
else: else:
index_start = 0
if backbone_type == 'nontransformer': if backbone_type == 'nontransformer':
model = resnet50_unet(n_classes, model = resnet50_unet(n_classes,
input_height, input_height,
@ -410,7 +414,8 @@ def run(_config,
#validation_steps=1, # rs: only one batch?? #validation_steps=1, # rs: only one batch??
validation_steps=steps_val, validation_steps=steps_val,
epochs=n_epochs, epochs=n_epochs,
callbacks=callbacks) callbacks=callbacks,
initial_epoch=index_start)
#os.system('rm -rf '+dir_train_flowing) #os.system('rm -rf '+dir_train_flowing)
#os.system('rm -rf '+dir_eval_flowing) #os.system('rm -rf '+dir_eval_flowing)