diff --git a/src/eynollah/training/metrics.py b/src/eynollah/training/metrics.py
index 5955888..60ac421 100644
--- a/src/eynollah/training/metrics.py
+++ b/src/eynollah/training/metrics.py
@@ -3,7 +3,7 @@ import os
 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
 import tensorflow as tf
 from tensorflow.keras import backend as K
-from tensorflow.keras.metrics import Metric
+from tensorflow.keras.metrics import Metric, MeanMetricWrapper, get
 from tensorflow.keras.initializers import Zeros
 import numpy as np
 
@@ -369,6 +369,34 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100):
     return (1 - jac) * smooth
 
 
+def metrics_superposition(*metrics, weights=None):
+    """
+    return a single metric derived as the weighted mean of all given metrics
+
+    default weights are uniform
+    """
+    if weights is None:
+        weights = len(metrics) * [tf.constant(1.0)]
+    def mixed(y_true, y_pred):
+        results = []
+        for metric, weight in zip(metrics, weights):
+            results.append(metric(y_true, y_pred) * weight)
+        return tf.reduce_mean(tf.stack(results), 0)
+    mixed.__name__ = '/'.join(m.__name__ for m in metrics)
+    return mixed
+
+
+class Superposition(MeanMetricWrapper):
+    def __init__(self, metrics, weights=None, dtype=None):
+        self._metrics = metrics
+        self._weights = weights
+        mixed = metrics_superposition(*metrics, weights=weights)
+        super().__init__(mixed, name=mixed.__name__, dtype=dtype)
+    def get_config(self):
+        return dict(metrics=self._metrics,
+                    weights=self._weights,
+                    **super().get_config())
+
 class ConfusionMatrix(Metric):
     def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32):
         super().__init__(name=name, dtype=dtype)
diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 4d997e5..a12b9c7 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -27,6 +27,8 @@ from matplotlib import pyplot as plt # for plot_confusion_matrix
 from .metrics import (
     soft_dice_loss,
     weighted_categorical_crossentropy,
+    get as get_metric,
+    metrics_superposition,
     ConfusionMatrix,
 )
 from .models import (
@@ -306,6 +308,7 @@ def config_params():
     if task in ["segmentation", "binarization"]:
         is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
         weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
+        add_ncc_loss = 0 # Add regression loss for number of connected components. When non-zero, use this as weight for the nCC term.
     elif task == "classification":
         f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
         classification_classes_name = None # Dictionary of classification classes names.
@@ -416,6 +419,7 @@ def run(_config,
         thetha=None,
         is_loss_soft_dice=False,
         weighted_loss=False,
+        add_ncc_loss=0,
         ## if continue_training
         index_start=0,
         dir_of_start_model=None,
@@ -605,7 +609,10 @@
         elif weighted_loss:
             loss = weighted_categorical_crossentropy(weights)
         else:
-            loss = 'categorical_crossentropy'
+            loss = get_metric('categorical_crossentropy')
+        if add_ncc_loss:
+            loss = metrics_superposition(loss, num_connected_components_regression(0.1),
+                                         weights=[1 - add_ncc_loss, add_ncc_loss])
         metrics.append(num_connected_components_regression(0.1))
         metrics.append(MeanIoU(n_classes, name='iou',