mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-31 01:59:27 +02:00
train params: drop reload_weights, re-use dir_of_start_model…
- drop ad-hoc configuration parameter `reload_weights` (used for conversion/export of models for inference, to be replaced by extra CLI) - re-interprete `dir_of_start_model` to also load weights if not `continue_training`
This commit is contained in:
parent
093030f503
commit
62b55a3809
1 changed files with 14 additions and 42 deletions
|
|
@ -347,10 +347,9 @@ def config_params():
|
||||||
dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
|
dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
|
||||||
pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder.
|
pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder.
|
||||||
save_interval = None # frequency for writing model checkpoints (positive integer for number of batches saved under "model_step_{batch:04d}", otherwise epoch saved under "model_{epoch:02d}")
|
save_interval = None # frequency for writing model checkpoints (positive integer for number of batches saved under "model_step_{batch:04d}", otherwise epoch saved under "model_{epoch:02d}")
|
||||||
reload_weights = False # Set true to build new model from config, load weights from dir_of_start_model, save under dir_output and exit.
|
|
||||||
continue_training = False # Whether to continue training an existing model.
|
continue_training = False # Whether to continue training an existing model.
|
||||||
|
dir_of_start_model = '' # Directory of model checkpoint to load to continue training or load weights from. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
|
||||||
if continue_training:
|
if continue_training:
|
||||||
dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
|
|
||||||
index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
|
index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
|
||||||
data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
|
data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
|
||||||
|
|
||||||
|
|
@ -371,7 +370,6 @@ def run(_config,
|
||||||
weight_decay,
|
weight_decay,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
continue_training,
|
continue_training,
|
||||||
reload_weights,
|
|
||||||
save_interval,
|
save_interval,
|
||||||
augmentation,
|
augmentation,
|
||||||
# dependent config keys need a default,
|
# dependent config keys need a default,
|
||||||
|
|
@ -475,6 +473,9 @@ def run(_config,
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
model = get_model(_config, _log)
|
model = get_model(_config, _log)
|
||||||
|
if dir_of_start_model:
|
||||||
|
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||||
|
_log.info("reloaded weights from %s", dir_of_start_model)
|
||||||
|
|
||||||
assert model is not None
|
assert model is not None
|
||||||
#if you want to see the model structure just uncomment model summary.
|
#if you want to see the model structure just uncomment model summary.
|
||||||
|
|
@ -505,16 +506,6 @@ def run(_config,
|
||||||
optimizer=Adam(learning_rate=learning_rate),
|
optimizer=Adam(learning_rate=learning_rate),
|
||||||
metrics=metrics)
|
metrics=metrics)
|
||||||
|
|
||||||
if reload_weights:
|
|
||||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
|
||||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
|
||||||
#model.save(dir_save, include_optimizer=False)
|
|
||||||
model.export(dir_save)
|
|
||||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
|
||||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not data_is_provided:
|
if not data_is_provided:
|
||||||
# first create a directory in output for both training and evaluations
|
# first create a directory in output for both training and evaluations
|
||||||
# in order to flow data from these directories.
|
# in order to flow data from these directories.
|
||||||
|
|
@ -656,6 +647,10 @@ def run(_config,
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
model = get_model(_config, _log)
|
model = get_model(_config, _log)
|
||||||
|
if dir_of_start_model:
|
||||||
|
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||||
|
_log.info("reloaded weights from %s", dir_of_start_model)
|
||||||
|
|
||||||
#initial_learning_rate = 1e-4
|
#initial_learning_rate = 1e-4
|
||||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||||
#alpha = 0.01
|
#alpha = 0.01
|
||||||
|
|
@ -666,16 +661,6 @@ def run(_config,
|
||||||
|
|
||||||
#print(model.summary())
|
#print(model.summary())
|
||||||
|
|
||||||
if reload_weights:
|
|
||||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
|
||||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
|
||||||
#model.save(dir_save, include_optimizer=False)
|
|
||||||
model.export(dir_save)
|
|
||||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
|
||||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
|
||||||
return
|
|
||||||
|
|
||||||
# todo: use Dataset.map() on Dataset.list_files()
|
# todo: use Dataset.map() on Dataset.list_files()
|
||||||
def get_dataset(dir_img, dir_lab):
|
def get_dataset(dir_img, dir_lab):
|
||||||
def gen():
|
def gen():
|
||||||
|
|
@ -718,20 +703,14 @@ def run(_config,
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
model = get_model(_config, _log)
|
model = get_model(_config, _log)
|
||||||
|
if dir_of_start_model:
|
||||||
|
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||||
|
_log.info("reloaded weights from %s", dir_of_start_model)
|
||||||
|
|
||||||
model.compile(loss='categorical_crossentropy',
|
model.compile(loss='categorical_crossentropy',
|
||||||
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
|
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
|
||||||
metrics=['accuracy', F1Score(average='macro', name='f1')])
|
metrics=['accuracy', F1Score(average='macro', name='f1')])
|
||||||
|
|
||||||
if reload_weights:
|
|
||||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
|
||||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
|
||||||
model.save(dir_save, include_optimizer=False)
|
|
||||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
|
||||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
|
||||||
return
|
|
||||||
|
|
||||||
list_classes = list(classification_classes_name.values())
|
list_classes = list(classification_classes_name.values())
|
||||||
data_args = dict(label_mode="categorical",
|
data_args = dict(label_mode="categorical",
|
||||||
class_names=list_classes,
|
class_names=list_classes,
|
||||||
|
|
@ -770,6 +749,9 @@ def run(_config,
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
model = get_model(_config, _log)
|
model = get_model(_config, _log)
|
||||||
|
if dir_of_start_model:
|
||||||
|
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
||||||
|
_log.info("reloaded weights from %s", dir_of_start_model)
|
||||||
|
|
||||||
#f1score_tot = [0]
|
#f1score_tot = [0]
|
||||||
model.compile(loss="binary_crossentropy",
|
model.compile(loss="binary_crossentropy",
|
||||||
|
|
@ -777,16 +759,6 @@ def run(_config,
|
||||||
optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate?
|
optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate?
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
|
||||||
if reload_weights:
|
|
||||||
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
|
|
||||||
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
|
|
||||||
#model.save(dir_save, include_optimizer=False)
|
|
||||||
model.export(dir_save)
|
|
||||||
with open(os.path.join(dir_save, "config.json"), "w") as fp:
|
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
|
||||||
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
|
|
||||||
return
|
|
||||||
|
|
||||||
dir_flow_train_imgs = os.path.join(dir_train, 'images')
|
dir_flow_train_imgs = os.path.join(dir_train, 'images')
|
||||||
dir_flow_train_labels = os.path.join(dir_train, 'labels')
|
dir_flow_train_labels = os.path.join(dir_train, 'labels')
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue