mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training: re-instate index_start, reflect cfg dependency
- `index_start`: re-introduce cfg key, pass to Keras `Model.fit` as `initial_epoch` - make config keys `index_start` and `dir_of_start_model` dependent on `continue_training` - improve description
This commit is contained in:
parent
25153ad307
commit
e85003db4a
1 changed files with 10 additions and 5 deletions
|
|
@ -157,10 +157,12 @@ def config_params():
|
||||||
dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels".
|
dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels".
|
||||||
dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
|
dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
|
||||||
pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder.
|
pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder.
|
||||||
save_interval = None # frequency for writing model checkpoints (nonzero integer for number of batches, or zero for epoch)
|
save_interval = None # frequency for writing model checkpoints (positive integer for number of batches saved under "model_step_{batch:04d}", otherwise epoch saved under "model_{epoch:02d}")
|
||||||
continue_training = False # Set to true if you would like to continue training an already trained a model.
|
continue_training = False # Whether to continue training an existing model.
|
||||||
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
|
if continue_training:
|
||||||
data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".
|
dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
|
||||||
|
index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
|
||||||
|
data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
|
||||||
if backbone_type == "transformer":
|
if backbone_type == "transformer":
|
||||||
transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
|
transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
|
||||||
transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
|
transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
|
||||||
|
|
@ -190,6 +192,7 @@ def run(_config,
|
||||||
weight_decay,
|
weight_decay,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
continue_training,
|
continue_training,
|
||||||
|
index_start,
|
||||||
dir_of_start_model,
|
dir_of_start_model,
|
||||||
save_interval,
|
save_interval,
|
||||||
augmentation,
|
augmentation,
|
||||||
|
|
@ -312,6 +315,7 @@ def run(_config,
|
||||||
custom_objects = {"PatchEncoder": PatchEncoder,
|
custom_objects = {"PatchEncoder": PatchEncoder,
|
||||||
"Patches": Patches})
|
"Patches": Patches})
|
||||||
else:
|
else:
|
||||||
|
index_start = 0
|
||||||
if backbone_type == 'nontransformer':
|
if backbone_type == 'nontransformer':
|
||||||
model = resnet50_unet(n_classes,
|
model = resnet50_unet(n_classes,
|
||||||
input_height,
|
input_height,
|
||||||
|
|
@ -410,7 +414,8 @@ def run(_config,
|
||||||
#validation_steps=1, # rs: only one batch??
|
#validation_steps=1, # rs: only one batch??
|
||||||
validation_steps=steps_val,
|
validation_steps=steps_val,
|
||||||
epochs=n_epochs,
|
epochs=n_epochs,
|
||||||
callbacks=callbacks)
|
callbacks=callbacks,
|
||||||
|
initial_epoch=index_start)
|
||||||
|
|
||||||
#os.system('rm -rf '+dir_train_flowing)
|
#os.system('rm -rf '+dir_train_flowing)
|
||||||
#os.system('rm -rf '+dir_eval_flowing)
|
#os.system('rm -rf '+dir_eval_flowing)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue