train params: drop reload_weights, re-use dir_of_start_model

- drop ad-hoc configuration parameter `reload_weights`
  (used for conversion/export of models for inference,
   to be replaced by extra CLI)
- re-interpret `dir_of_start_model` to also load weights
  if not `continue_training`
This commit is contained in:
Robert Sachunsky 2026-05-28 17:42:55 +02:00
parent 093030f503
commit 62b55a3809

View file

@ -347,10 +347,9 @@ def config_params():
dir_output = None # Directory where the augmented training data and the model checkpoints will be saved.
pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder.
save_interval = None # frequency for writing model checkpoints (positive integer for number of batches saved under "model_step_{batch:04d}", otherwise epoch saved under "model_{epoch:02d}")
reload_weights = False # Set true to build new model from config, load weights from dir_of_start_model, save under dir_output and exit.
continue_training = False # Whether to continue training an existing model.
dir_of_start_model = '' # Directory of model checkpoint to load to continue training or load weights from. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
if continue_training:
dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
@ -371,7 +370,6 @@ def run(_config,
weight_decay,
learning_rate,
continue_training,
reload_weights,
save_interval,
augmentation,
# dependent config keys need a default,
@ -475,6 +473,9 @@ def run(_config,
else:
index_start = 0
model = get_model(_config, _log)
if dir_of_start_model:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
_log.info("reloaded weights from %s", dir_of_start_model)
assert model is not None
#if you want to see the model structure just uncomment model summary.
@ -505,16 +506,6 @@ def run(_config,
optimizer=Adam(learning_rate=learning_rate),
metrics=metrics)
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
#model.save(dir_save, include_optimizer=False)
model.export(dir_save)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
return
if not data_is_provided:
# first create a directory in output for both training and evaluations
# in order to flow data from these directories.
@ -656,6 +647,10 @@ def run(_config,
else:
index_start = 0
model = get_model(_config, _log)
if dir_of_start_model:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
_log.info("reloaded weights from %s", dir_of_start_model)
#initial_learning_rate = 1e-4
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
#alpha = 0.01
@ -666,16 +661,6 @@ def run(_config,
#print(model.summary())
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
#model.save(dir_save, include_optimizer=False)
model.export(dir_save)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
return
# todo: use Dataset.map() on Dataset.list_files()
def get_dataset(dir_img, dir_lab):
def gen():
@ -718,20 +703,14 @@ def run(_config,
else:
index_start = 0
model = get_model(_config, _log)
if dir_of_start_model:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
_log.info("reloaded weights from %s", dir_of_start_model)
model.compile(loss='categorical_crossentropy',
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
metrics=['accuracy', F1Score(average='macro', name='f1')])
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
model.save(dir_save, include_optimizer=False)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
return
list_classes = list(classification_classes_name.values())
data_args = dict(label_mode="categorical",
class_names=list_classes,
@ -770,6 +749,9 @@ def run(_config,
else:
index_start = 0
model = get_model(_config, _log)
if dir_of_start_model:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
_log.info("reloaded weights from %s", dir_of_start_model)
#f1score_tot = [0]
model.compile(loss="binary_crossentropy",
@ -777,16 +759,6 @@ def run(_config,
optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate?
metrics=['accuracy'])
if reload_weights:
model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial()
dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model)))
#model.save(dir_save, include_optimizer=False)
model.export(dir_save)
with open(os.path.join(dir_save, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("reloaded model from %s to %s", dir_of_start_model, dir_save)
return
dir_flow_train_imgs = os.path.join(dir_train, 'images')
dir_flow_train_labels = os.path.join(dir_train, 'labels')