diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 2cb42b6..f4cf08b 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -347,10 +347,9 @@ def config_params(): dir_output = None # Directory where the augmented training data and the model checkpoints will be saved. pretraining = False # Set to true to (down)load pretrained weights of ResNet50 encoder. save_interval = None # frequency for writing model checkpoints (positive integer for number of batches saved under "model_step_{batch:04d}", otherwise epoch saved under "model_{epoch:02d}") - reload_weights = False # Set true to build new model from config, load weights from dir_of_start_model, save under dir_output and exit. continue_training = False # Whether to continue training an existing model. + dir_of_start_model = '' # Directory of model checkpoint to load to continue training or load weights from. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".) if continue_training: - dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".) index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.) data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run). @@ -371,7 +370,6 @@ def run(_config, weight_decay, learning_rate, continue_training, - reload_weights, save_interval, augmentation, # dependent config keys need a default, @@ -475,6 +473,9 @@ def run(_config, else: index_start = 0 model = get_model(_config, _log) + if dir_of_start_model: + model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() + _log.info("reloaded weights from %s", dir_of_start_model) assert model is not None #if you want to see the model structure just uncomment model summary. @@ -505,16 +506,6 @@ def run(_config, optimizer=Adam(learning_rate=learning_rate), metrics=metrics) - if reload_weights: - model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() - dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model))) - #model.save(dir_save, include_optimizer=False) - model.export(dir_save) - with open(os.path.join(dir_save, "config.json"), "w") as fp: - json.dump(_config, fp) # encode dict into JSON - _log.info("reloaded model from %s to %s", dir_of_start_model, dir_save) - return - if not data_is_provided: # first create a directory in output for both training and evaluations # in order to flow data from these directories. @@ -656,6 +647,10 @@ def run(_config, else: index_start = 0 model = get_model(_config, _log) + if dir_of_start_model: + model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() + _log.info("reloaded weights from %s", dir_of_start_model) + #initial_learning_rate = 1e-4 #decay_steps = int (n_epochs * ( len_dataset / n_batch )) #alpha = 0.01 @@ -666,16 +661,6 @@ def run(_config, #print(model.summary()) - if reload_weights: - model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() - dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model))) - #model.save(dir_save, include_optimizer=False) - model.export(dir_save) - with open(os.path.join(dir_save, "config.json"), "w") as fp: - json.dump(_config, fp) # encode dict into JSON - _log.info("reloaded model from %s to %s", dir_of_start_model, dir_save) - return - # todo: use Dataset.map() on Dataset.list_files() def get_dataset(dir_img, dir_lab): def gen(): @@ -718,20 +703,14 @@ def run(_config, else: index_start = 0 model = get_model(_config, _log) + if dir_of_start_model: + model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() + _log.info("reloaded weights from %s", dir_of_start_model) model.compile(loss='categorical_crossentropy', optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate? metrics=['accuracy', F1Score(average='macro', name='f1')]) - if reload_weights: - model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() - dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model))) - model.save(dir_save, include_optimizer=False) - with open(os.path.join(dir_save, "config.json"), "w") as fp: - json.dump(_config, fp) # encode dict into JSON - _log.info("reloaded model from %s to %s", dir_of_start_model, dir_save) - return - list_classes = list(classification_classes_name.values()) data_args = dict(label_mode="categorical", class_names=list_classes, @@ -770,6 +749,9 @@ def run(_config, else: index_start = 0 model = get_model(_config, _log) + if dir_of_start_model: + model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() + _log.info("reloaded weights from %s", dir_of_start_model) #f1score_tot = [0] model.compile(loss="binary_crossentropy", @@ -777,16 +759,6 @@ def run(_config, optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate? metrics=['accuracy']) - if reload_weights: - model.load_weights(dir_of_start_model).assert_existing_objects_matched().expect_partial() - dir_save = os.path.join(dir_output, os.path.basename(os.path.normpath(dir_of_start_model))) - #model.save(dir_save, include_optimizer=False) - model.export(dir_save) - with open(os.path.join(dir_save, "config.json"), "w") as fp: - json.dump(_config, fp) # encode dict into JSON - _log.info("reloaded model from %s to %s", dir_of_start_model, dir_save) - return - dir_flow_train_imgs = os.path.join(dir_train, 'images') dir_flow_train_labels = os.path.join(dir_train, 'labels')