mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: add config param add_ncc_loss for layout/binarization…
- add `metrics.metrics_superposition` and `metrics.Superposition` - if non-zero, mix configured loss with weighted nCC metric
This commit is contained in:
parent
c6d9dd7945
commit
7e06ab2c8c
2 changed files with 37 additions and 2 deletions
|
|
@ -3,7 +3,7 @@ import os
|
||||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras import backend as K
|
from tensorflow.keras import backend as K
|
||||||
from tensorflow.keras.metrics import Metric
|
from tensorflow.keras.metrics import Metric, MeanMetricWrapper, get
|
||||||
from tensorflow.keras.initializers import Zeros
|
from tensorflow.keras.initializers import Zeros
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -369,6 +369,34 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
||||||
return (1 - jac) * smooth
|
return (1 - jac) * smooth
|
||||||
|
|
||||||
|
|
||||||
|
def metrics_superposition(*metrics, weights=None):
    """Combine several metric/loss functions into a single callable.

    Each metric's value is scaled by its weight, and the weighted values
    are then averaged (``tf.reduce_mean`` over the stacked results).  With
    the default uniform weights of 1.0 this is the plain mean of all
    metrics.

    Arguments:
        *metrics: callables with signature ``(y_true, y_pred)``.
        weights: optional sequence of scalar weights, one per metric;
            defaults to ``tf.constant(1.0)`` for every metric.

    Returns:
        A function ``(y_true, y_pred) -> tensor`` whose ``__name__`` is the
        component metric names joined with ``/``.
    """
    if weights is None:
        weights = [tf.constant(1.0)] * len(metrics)

    def combined(y_true, y_pred):
        # Evaluate every metric, scale by its weight, then average the
        # stacked per-metric results along the new leading axis.
        weighted = [m(y_true, y_pred) * w for m, w in zip(metrics, weights)]
        return tf.reduce_mean(tf.stack(weighted), 0)

    combined.__name__ = '/'.join(m.__name__ for m in metrics)
    return combined
|
||||||
|
|
||||||
|
|
||||||
|
class Superposition(MeanMetricWrapper):
    """Keras metric wrapping a weighted superposition of several metrics.

    Builds a single combined function via ``metrics_superposition`` and
    exposes it as a streaming (mean-aggregated) Keras metric.
    """

    def __init__(self, metrics, weights=None, dtype=None):
        """
        Arguments:
            metrics: sequence of metric callables ``(y_true, y_pred)``.
            weights: optional per-metric weights, forwarded to
                ``metrics_superposition`` (defaults to uniform 1.0).
            dtype: optional result dtype, forwarded to ``MeanMetricWrapper``.
        """
        mixed = metrics_superposition(*metrics, weights=weights)
        # Call super().__init__() before assigning instance attributes:
        # tf.keras Layer/Metric subclasses track attributes in __setattr__,
        # and setting attributes on an uninitialized layer can raise.
        super().__init__(mixed, name=mixed.__name__, dtype=dtype)
        self._metrics = metrics
        self._weights = weights

    def get_config(self):
        # NOTE(review): `metrics` holds raw callables and `weights` may hold
        # tf constants — neither is JSON-serializable as-is; full model
        # (de)serialization may need keras serialize/deserialize — confirm.
        return dict(metrics=self._metrics,
                    weights=self._weights,
                    **super().get_config())
|
||||||
|
|
||||||
class ConfusionMatrix(Metric):
|
class ConfusionMatrix(Metric):
|
||||||
def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32):
|
def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32):
|
||||||
super().__init__(name=name, dtype=dtype)
|
super().__init__(name=name, dtype=dtype)
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,8 @@ from matplotlib import pyplot as plt # for plot_confusion_matrix
|
||||||
from .metrics import (
|
from .metrics import (
|
||||||
soft_dice_loss,
|
soft_dice_loss,
|
||||||
weighted_categorical_crossentropy,
|
weighted_categorical_crossentropy,
|
||||||
|
get as get_metric,
|
||||||
|
metrics_superposition,
|
||||||
ConfusionMatrix,
|
ConfusionMatrix,
|
||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
|
|
@ -306,6 +308,7 @@ def config_params():
|
||||||
if task in ["segmentation", "binarization"]:
|
if task in ["segmentation", "binarization"]:
|
||||||
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
||||||
weighted_loss = False # Use weighted categorical cross entropy as loss function. When set to true, "is_loss_soft_dice" must be false.
|
weighted_loss = False # Use weighted categorical cross entropy as loss function. When set to true, "is_loss_soft_dice" must be false.
|
||||||
|
add_ncc_loss = 0 # Add regression loss for number of connected components. When non-zero, use this as weight for the nCC term.
|
||||||
elif task == "classification":
|
elif task == "classification":
|
||||||
f1_threshold_classification = None # Only models with an evaluation f1 score above this threshold are considered. The selected model weights undergo weights ensembling, and the averaged ensembled model will be written to output.
|
f1_threshold_classification = None # Only models with an evaluation f1 score above this threshold are considered. The selected model weights undergo weights ensembling, and the averaged ensembled model will be written to output.
|
||||||
classification_classes_name = None # Dictionary of classification classes names.
|
classification_classes_name = None # Dictionary of classification classes names.
|
||||||
|
|
@ -416,6 +419,7 @@ def run(_config,
|
||||||
thetha=None,
|
thetha=None,
|
||||||
is_loss_soft_dice=False,
|
is_loss_soft_dice=False,
|
||||||
weighted_loss=False,
|
weighted_loss=False,
|
||||||
|
add_ncc_loss=None,
|
||||||
## if continue_training
|
## if continue_training
|
||||||
index_start=0,
|
index_start=0,
|
||||||
dir_of_start_model=None,
|
dir_of_start_model=None,
|
||||||
|
|
@ -605,7 +609,10 @@ def run(_config,
|
||||||
elif weighted_loss:
|
elif weighted_loss:
|
||||||
loss = weighted_categorical_crossentropy(weights)
|
loss = weighted_categorical_crossentropy(weights)
|
||||||
else:
|
else:
|
||||||
loss = 'categorical_crossentropy'
|
loss = get_metric('categorical_crossentropy')
|
||||||
|
if add_ncc_loss:
|
||||||
|
loss = metrics_superposition(loss, num_connected_components_regression(0.1),
|
||||||
|
weights=[1 - add_ncc_loss, add_ncc_loss])
|
||||||
metrics.append(num_connected_components_regression(0.1))
|
metrics.append(num_connected_components_regression(0.1))
|
||||||
metrics.append(MeanIoU(n_classes,
|
metrics.append(MeanIoU(n_classes,
|
||||||
name='iou',
|
name='iou',
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue