mirror of https://github.com/qurator-spk/eynollah.git
training.train: fix F1 metric score setup
commit 82d649061a
parent 5c7801a1d6
1 changed file with 9 additions and 8 deletions
@@ -34,7 +34,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
 import tensorflow as tf
 from tensorflow.keras.optimizers import SGD, Adam
-from tensorflow.keras.metrics import MeanIoU
+from tensorflow.keras.metrics import MeanIoU, F1Score
 from tensorflow.keras.models import load_model
 from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
 from sacred import Experiment
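Note (not part of the commit): F1Score ships with Keras 2.13 and later, so it should also be available through the legacy Keras 2 package selected by TF_USE_LEGACY_KERAS=1 on TF 2.15. A quick standalone sanity check of the macro-averaged variant imported here, with made-up one-hot data:

import numpy as np
import tensorflow as tf

# Toy data of shape (batch, n_classes); with the default threshold=None the
# metric takes the argmax of y_pred as the predicted class.
f1 = tf.keras.metrics.F1Score(average='macro', name='f1')
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype='float32')
y_pred = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.3, 0.3, 0.4]], dtype='float32')
f1.update_state(y_true, y_pred)
print(float(f1.result()))  # 1.0 here, since every argmax matches its label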
@@ -427,8 +427,8 @@ def run(_config,
 
     model.compile(loss='categorical_crossentropy',
                   optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
-                  metrics=['accuracy'])
+                  metrics=['accuracy', F1Score(average='macro', name='f1')])
 
     list_classes = list(classification_classes_name.values())
     trainXY = generate_data_from_folder_training(
         dir_train, n_batch, input_height, input_width, n_classes, list_classes)
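Why the metric moves into compile(): Keras Model.fit() has no metrics keyword (which is also why the metrics=... argument is dropped from model.fit() further down), so F1Score has to be registered at compile time. It then appears as 'f1'/'val_f1' in the epoch logs and in history.history, which the val_f1 monitoring and thresholding below rely on. A minimal sketch with a stand-in model (hypothetical shapes and data, not the project's classifier):

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

# Stand-in classifier just to show where the metric is declared and how it is
# reported; input size, layer sizes and data are made up.
n_classes = 4
model = models.Sequential([
    layers.Input(shape=(32,)),
    layers.Dense(16, activation='relu'),
    layers.Dense(n_classes, activation='softmax'),
])
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              metrics=['accuracy',
                       tf.keras.metrics.F1Score(average='macro', name='f1')])

x = np.random.rand(64, 32).astype('float32')
y = tf.keras.utils.to_categorical(np.random.randint(n_classes, size=64), n_classes)
history = model.fit(x, y, validation_split=0.25, epochs=2, verbose=0)
print(sorted(history.history))  # includes 'f1' and 'val_f1'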
@@ -440,7 +440,8 @@ def run(_config,
     callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                  SaveWeightsAfterSteps(0, dir_output, _config,
                                        monitor='val_f1',
-                                       save_best_only=True, mode='max')]
+                                       #save_best_only=True, # we need all for ensembling
+                                       mode='max')]
 
     history = model.fit(trainXY,
                         steps_per_epoch=num_rows / n_batch,
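SaveWeightsAfterSteps is the project's own callback and its keyword handling is not visible in this hunk; the point of commenting out save_best_only=True is to keep a checkpoint for every epoch, because the averaging step below needs all of them. For orientation only, a rough stock-Keras analogue using ModelCheckpoint (hypothetical dir_output; not the project's callback):

import os
import tensorflow as tf

dir_output = '/tmp/classifier_run'  # assumption for the sketch

# One saved model per epoch, named like the cp_path pattern in the next hunk;
# save_best_only=False keeps every epoch instead of only the best val_f1.
per_epoch_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(dir_output, 'model_{epoch:02d}'),
    monitor='val_f1',
    mode='max',
    save_best_only=False,
    save_freq='epoch')

callbacks = [tf.keras.callbacks.TensorBoard(os.path.join(dir_output, 'logs'),
                                            write_graph=False),
             per_epoch_checkpoint]

ModelCheckpoint formats {epoch} with the 1-based epoch number, which would be consistent with the epoch + 1 change in the next hunk; whether SaveWeightsAfterSteps numbers its directories the same way is not shown here.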
@@ -448,17 +449,17 @@ def run(_config,
                         validation_data=testXY,
                         verbose=1,
                         epochs=n_epochs,
-                        metrics=[F1Score(average='macro', name='f1')],
                         callbacks=callbacks,
                         initial_epoch=index_start)
 
-    usable_checkpoints = np.flatnonzero(np.array(history['val_f1']) > f1_threshold_classification)
+    usable_checkpoints = np.flatnonzero(np.array(history.history['val_f1']) >
+                                        f1_threshold_classification)
     if len(usable_checkpoints) >= 1:
         _log.info("averaging over usable checkpoints: %s", str(usable_checkpoints))
         all_weights = []
         for epoch in usable_checkpoints:
-            cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch))
-            assert os.path.isdir(cp_path)
+            cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
+            assert os.path.isdir(cp_path), cp_path
             model = load_model(cp_path, compile=False)
             all_weights.append(model.get_weights())
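The fixes in this hunk make the checkpoint selection work: the fit() result stores its per-epoch metrics in history.history (the bare history object is not subscriptable), and epoch + 1 lines the loop up with checkpoint directories that are presumably numbered from 1 (the added assert message makes a mismatch easy to spot). The continuation of the block, not shown in the hunk, averages the collected weights; a small sketch of that step over all_weights, with made-up toy checkpoints:

import numpy as np

# all_weights is assumed to be a list of per-checkpoint weight lists, as
# collected by the loop above via model.get_weights().
def average_weights(all_weights):
    return [np.mean([weights[i] for weights in all_weights], axis=0)
            for i in range(len(all_weights[0]))]

# Toy check with two fake "checkpoints" of two tensors each.
w_a = [np.ones((2, 2)), np.zeros(3)]
w_b = [3 * np.ones((2, 2)), np.ones(3)]
averaged = average_weights([w_a, w_b])
print(averaged[0])  # all 2.0
print(averaged[1])  # all 0.5
# The averaged weights would then be applied with model.set_weights(averaged).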