training.train: fix F1 metric score setup

This commit is contained in:
Robert Sachunsky 2026-02-05 11:57:38 +01:00
parent 5c7801a1d6
commit 82d649061a

View file

@ -34,7 +34,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.metrics import MeanIoU, F1Score
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from sacred import Experiment
@ -427,8 +427,8 @@ def run(_config,
model.compile(loss='categorical_crossentropy',
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
metrics=['accuracy'])
metrics=['accuracy', F1Score(average='macro', name='f1')])
list_classes = list(classification_classes_name.values())
trainXY = generate_data_from_folder_training(
dir_train, n_batch, input_height, input_width, n_classes, list_classes)
@ -440,7 +440,8 @@ def run(_config,
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
SaveWeightsAfterSteps(0, dir_output, _config,
monitor='val_f1',
save_best_only=True, mode='max')]
#save_best_only=True, # we need all for ensembling
mode='max')]
history = model.fit(trainXY,
steps_per_epoch=num_rows / n_batch,
@ -448,17 +449,17 @@ def run(_config,
validation_data=testXY,
verbose=1,
epochs=n_epochs,
metrics=[F1Score(average='macro', name='f1')],
callbacks=callbacks,
initial_epoch=index_start)
usable_checkpoints = np.flatnonzero(np.array(history['val_f1']) > f1_threshold_classification)
usable_checkpoints = np.flatnonzero(np.array(history.history['val_f1']) >
f1_threshold_classification)
if len(usable_checkpoints) >= 1:
_log.info("averaging over usable checkpoints: %s", str(usable_checkpoints))
all_weights = []
for epoch in usable_checkpoints:
cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch))
assert os.path.isdir(cp_path)
cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
assert os.path.isdir(cp_path), cp_path
model = load_model(cp_path, compile=False)
all_weights.append(model.get_weights())