From 82d649061a7d932df25828081c01b25a6acae012 Mon Sep 17 00:00:00 2001
From: Robert Sachunsky
Date: Thu, 5 Feb 2026 11:57:38 +0100
Subject: [PATCH] training.train: fix F1 metric score setup

---
 src/eynollah/training/train.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 4aafcf2..effc920 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -34,7 +34,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
 import tensorflow as tf
 from tensorflow.keras.optimizers import SGD, Adam
-from tensorflow.keras.metrics import MeanIoU
+from tensorflow.keras.metrics import MeanIoU, F1Score
 from tensorflow.keras.models import load_model
 from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
 from sacred import Experiment
@@ -427,8 +427,8 @@ def run(_config,
 
         model.compile(loss='categorical_crossentropy',
                       optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
-                      metrics=['accuracy'])
-
+                      metrics=['accuracy', F1Score(average='macro', name='f1')])
+
         list_classes = list(classification_classes_name.values())
         trainXY = generate_data_from_folder_training(
             dir_train, n_batch, input_height, input_width, n_classes, list_classes)
@@ -440,7 +440,8 @@ def run(_config,
 
         callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                      SaveWeightsAfterSteps(0, dir_output, _config, monitor='val_f1',
-                                           save_best_only=True, mode='max')]
+                                           #save_best_only=True, # we need all for ensembling
+                                           mode='max')]
 
         history = model.fit(trainXY,
                             steps_per_epoch=num_rows / n_batch,
@@ -448,17 +449,17 @@ def run(_config,
                             validation_data=testXY,
                             verbose=1,
                             epochs=n_epochs,
-                            metrics=[F1Score(average='macro', name='f1')],
                             callbacks=callbacks,
                             initial_epoch=index_start)
 
-        usable_checkpoints = np.flatnonzero(np.array(history['val_f1']) > f1_threshold_classification)
+        usable_checkpoints = np.flatnonzero(np.array(history.history['val_f1']) >
+                                            f1_threshold_classification)
         if len(usable_checkpoints) >= 1:
             _log.info("averaging over usable checkpoints: %s", str(usable_checkpoints))
             all_weights = []
             for epoch in usable_checkpoints:
-                cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch))
-                assert os.path.isdir(cp_path)
+                cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
+                assert os.path.isdir(cp_path), cp_path
                 model = load_model(cp_path, compile=False)
                 all_weights.append(model.get_weights())
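
Note on the last hunk (not part of the patch itself): the diff context ends right after the per-checkpoint weights are collected into all_weights, so the ensembling step proper is not visible here. As a rough sketch only, reusing the names np, model, all_weights and dir_output from the hunk above (the output name 'model_ens_avg' is a made-up placeholder), a per-layer weight average over the usable checkpoints could look like this:

    # zip(*all_weights) groups the k-th weight tensor of every usable checkpoint;
    # the elementwise mean keeps each tensor's original shape
    mean_weights = [np.mean(np.stack(tensors), axis=0)
                    for tensors in zip(*all_weights)]
    # write the averaged weights back into the last loaded model and save it
    model.set_weights(mean_weights)
    model.save(os.path.join(dir_output, 'model_ens_avg'))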