Mirror of https://github.com/qurator-spk/eynollah.git (synced 2026-02-20 16:32:03 +01:00)
training: re-instate index_start, reflect cfg dependency
- `index_start`: re-introduce cfg key, pass to Keras `Model.fit` as `initial_epoch`
- make config keys `index_start` and `dir_of_start_model` dependent on `continue_training`
- improve description
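For context, a minimal runnable sketch of the Keras `Model.fit` behaviour that `initial_epoch` provides; the toy data and model here are invented for illustration and are not part of the eynollah training code:

import numpy as np
import tensorflow as tf

# Toy data and model, only to demonstrate the epoch counting.
x = np.random.rand(32, 4).astype("float32")
y = np.random.randint(0, 2, size=(32, 1)).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss="binary_crossentropy")

index_start = 3  # epochs already completed in a previous run
n_epochs = 5     # `epochs` is the final epoch index, not the number of additional epochs

# Runs only epochs 4/5 and 5/5: the epoch numbering continues where the
# previous run stopped instead of restarting at 1.
model.fit(x, y, epochs=n_epochs, initial_epoch=index_start)

So resuming a model that was already trained for three epochs amounts to loading its checkpoint and calling `fit` with `initial_epoch=3`, which is what the diff below wires up through the `index_start` config key.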
parent 25153ad307
commit e85003db4a
1 changed file with 10 additions and 5 deletions
@@ -157,10 +157,12 @@ def config_params():
     dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels".
     dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
     pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder.
-    save_interval = None # frequency for writing model checkpoints (nonzero integer for number of batches, or zero for epoch)
-    continue_training = False # Set to true if you would like to continue training an already trained a model.
-    dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
-    data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".
+    save_interval = None # frequency for writing model checkpoints (positive integer for number of batches saved under "model_step_{batch:04d}", otherwise epoch saved under "model_{epoch:02d}")
+    continue_training = False # Whether to continue training an existing model.
+    if continue_training:
+        dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
+        index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
+    data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
     if backbone_type == "transformer":
         transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
         transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
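A hedged sketch of how these conditional config keys behave, assuming the training script builds its configuration with a Sacred `Experiment` (the `_config` parameter of `run()` points that way); the experiment and file names below are made up:

from sacred import Experiment

ex = Experiment("train_sketch")

@ex.config
def config_params():
    continue_training = False  # Whether to continue training an existing model.
    if continue_training:
        # These keys only become part of the configuration when
        # continue_training is overridden to True, e.g. on the command line:
        #   python train_sketch.py with continue_training=True \
        #       dir_of_start_model=dir_output/model_03 index_start=3
        dir_of_start_model = ''
        index_start = 0

@ex.automain
def run(_config):
    print(_config)

Keeping `dir_of_start_model` and `index_start` inside the `if continue_training:` block means a fresh training run is not cluttered with keys that only matter when resuming.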
@@ -190,6 +192,7 @@ def run(_config,
         weight_decay,
         learning_rate,
         continue_training,
+        index_start,
         dir_of_start_model,
         save_interval,
         augmentation,
@@ -312,6 +315,7 @@ def run(_config,
                                custom_objects = {"PatchEncoder": PatchEncoder,
                                                  "Patches": Patches})
     else:
+        index_start = 0
         if backbone_type == 'nontransformer':
             model = resnet50_unet(n_classes,
                                   input_height,
@@ -410,7 +414,8 @@ def run(_config,
               #validation_steps=1, # rs: only one batch??
               validation_steps=steps_val,
               epochs=n_epochs,
-              callbacks=callbacks)
+              callbacks=callbacks,
+              initial_epoch=index_start)
 
     #os.system('rm -rf '+dir_train_flowing)
     #os.system('rm -rf '+dir_eval_flowing)
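And a small illustration, under the 1-based file-name convention that the `index_start` comment above describes, of how the resumed epoch numbers map onto checkpoint names:

# fit(initial_epoch=3, epochs=5) iterates over the remaining epochs 3 and 4
# internally, which are reported (and saved) as epochs 4 and 5:
for epoch in range(3, 5):
    print("Epoch {}/{} -> model_{:02d}".format(epoch + 1, 5, epoch + 1))
# Epoch 4/5 -> model_04
# Epoch 5/5 -> model_05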