From 1581094141a2eb8892fa58b09de7fe8500e73e08 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 4 Feb 2026 17:35:12 +0100 Subject: [PATCH] training: extend `index_start` to tasks classification and RO --- src/eynollah/training/train.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index de8cccd..168884a 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -423,11 +423,15 @@ def run(_config, #model.save(dir_output+'/'+'model'+'.h5') elif task=='classification': - model = resnet50_classifier(n_classes, - input_height, - input_width, - weight_decay, - pretraining) + if continue_training: + model = load_model(dir_of_start_model, compile=False) + else: + index_start = 0 + model = resnet50_classifier(n_classes, + input_height, + input_width, + weight_decay, + pretraining) model.compile(loss='categorical_crossentropy', optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate? @@ -453,7 +457,8 @@ def run(_config, verbose=1, epochs=n_epochs, metrics=[F1Score(average='macro', name='f1')], - callbacks=callbacks) + callbacks=callbacks, + initial_epoch=index_start) usable_checkpoints = np.flatnonzero(np.array(history['val_f1']) > f1_threshold_classification) if len(usable_checkpoints) >= 1: @@ -481,8 +486,15 @@ def run(_config, _log.info("ensemble model saved under '%s'", cp_path) elif task=='reading_order': - model = machine_based_reading_order_model( - n_classes, input_height, input_width, weight_decay, pretraining) + if continue_training: + model = load_model(dir_of_start_model, compile=False) + else: + index_start = 0 + model = machine_based_reading_order_model(n_classes, + input_height, + input_width, + weight_decay, + pretraining) dir_flow_train_imgs = os.path.join(dir_train, 'images') dir_flow_train_labels = os.path.join(dir_train, 'labels') @@ -495,7 +507,6 @@ def run(_config, #ls_test = os.listdir(dir_flow_train_labels) #f1score_tot = [0] - indexer_start = 0 model.compile(loss="binary_crossentropy", #optimizer=SGD(learning_rate=0.01, momentum=0.9), optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate? @@ -515,7 +526,8 @@ def run(_config, steps_per_epoch=num_rows / n_batch, verbose=1, epochs=n_epochs, - callbacks=callbacks) + callbacks=callbacks, + initial_epoch=index_start) ''' if f1score>f1score_tot[0]: f1score_tot[0] = f1score