training: add config param add_ncc_loss for layout/binarization…

- add `metrics.metrics_superposition` and `metrics.Superposition`
- if non-zero, mix configured loss with weighted nCC metric
This commit is contained in:
Robert Sachunsky 2026-02-27 12:55:15 +01:00
parent c6d9dd7945
commit 7e06ab2c8c
2 changed files with 37 additions and 2 deletions

View file

@ -3,7 +3,7 @@ import os
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import backend as K from tensorflow.keras import backend as K
from tensorflow.keras.metrics import Metric from tensorflow.keras.metrics import Metric, MeanMetricWrapper, get
from tensorflow.keras.initializers import Zeros from tensorflow.keras.initializers import Zeros
import numpy as np import numpy as np
@ -369,6 +369,34 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100):
return (1 - jac) * smooth return (1 - jac) * smooth
def metrics_superposition(*metrics, weights=None):
"""
return a single metric derived by adding all given metrics
default weights are uniform
"""
if weights is None:
weights = len(metrics) * [tf.constant(1.0)]
def mixed(y_true, y_pred):
results = []
for metric, weight in zip(metrics, weights):
results.append(metric(y_true, y_pred) * weight)
return tf.reduce_mean(tf.stack(results), 0)
mixed.__name__ = '/'.join(m.__name__ for m in metrics)
return mixed
class Superposition(MeanMetricWrapper):
def __init__(self, metrics, weights=None, dtype=None):
self._metrics = metrics
self._weights = weights
mixed = metrics_superposition(*metrics, weights=weights)
super().__init__(mixed, name=mixed.__name__, dtype=dtype)
def get_config(self):
return dict(metrics=self._metrics,
weights=self._weights,
**super().get_config())
class ConfusionMatrix(Metric): class ConfusionMatrix(Metric):
def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32): def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32):
super().__init__(name=name, dtype=dtype) super().__init__(name=name, dtype=dtype)

View file

@ -27,6 +27,8 @@ from matplotlib import pyplot as plt # for plot_confusion_matrix
from .metrics import ( from .metrics import (
soft_dice_loss, soft_dice_loss,
weighted_categorical_crossentropy, weighted_categorical_crossentropy,
get as get_metric,
metrics_superposition,
ConfusionMatrix, ConfusionMatrix,
) )
from .models import ( from .models import (
@ -306,6 +308,7 @@ def config_params():
if task in ["segmentation", "binarization"]: if task in ["segmentation", "binarization"]:
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false. weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
add_ncc_loss = 0 # Add regression loss for number of connected components. When non-zero, use this as weight for the nCC term.
elif task == "classification": elif task == "classification":
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output. f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
classification_classes_name = None # Dictionary of classification classes names. classification_classes_name = None # Dictionary of classification classes names.
@ -416,6 +419,7 @@ def run(_config,
thetha=None, thetha=None,
is_loss_soft_dice=False, is_loss_soft_dice=False,
weighted_loss=False, weighted_loss=False,
add_ncc_loss=None,
## if continue_training ## if continue_training
index_start=0, index_start=0,
dir_of_start_model=None, dir_of_start_model=None,
@ -605,7 +609,10 @@ def run(_config,
elif weighted_loss: elif weighted_loss:
loss = weighted_categorical_crossentropy(weights) loss = weighted_categorical_crossentropy(weights)
else: else:
loss = 'categorical_crossentropy' loss = get_metric('categorical_crossentropy')
if add_ncc_loss:
loss = metrics_superposition(loss, num_connected_components_regression(0.1),
weights=[1 - add_ncc_loss, add_ncc_loss])
metrics.append(num_connected_components_regression(0.1)) metrics.append(num_connected_components_regression(0.1))
metrics.append(MeanIoU(n_classes, metrics.append(MeanIoU(n_classes,
name='iou', name='iou',