mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
training: add config param add_ncc_loss for layout/binarization…
- add `metrics.metrics_superposition` and `metrics.Superposition` - if non-zero, mix configured loss with weighted nCC metric
This commit is contained in:
parent
c6d9dd7945
commit
7e06ab2c8c
2 changed files with 37 additions and 2 deletions
|
|
@ -3,7 +3,7 @@ import os
|
|||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import backend as K
|
||||
from tensorflow.keras.metrics import Metric
|
||||
from tensorflow.keras.metrics import Metric, MeanMetricWrapper, get
|
||||
from tensorflow.keras.initializers import Zeros
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -369,6 +369,34 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
|||
return (1 - jac) * smooth
|
||||
|
||||
|
||||
def metrics_superposition(*metrics, weights=None):
|
||||
"""
|
||||
return a single metric derived by adding all given metrics
|
||||
|
||||
default weights are uniform
|
||||
"""
|
||||
if weights is None:
|
||||
weights = len(metrics) * [tf.constant(1.0)]
|
||||
def mixed(y_true, y_pred):
|
||||
results = []
|
||||
for metric, weight in zip(metrics, weights):
|
||||
results.append(metric(y_true, y_pred) * weight)
|
||||
return tf.reduce_mean(tf.stack(results), 0)
|
||||
mixed.__name__ = '/'.join(m.__name__ for m in metrics)
|
||||
return mixed
|
||||
|
||||
|
||||
class Superposition(MeanMetricWrapper):
|
||||
def __init__(self, metrics, weights=None, dtype=None):
|
||||
self._metrics = metrics
|
||||
self._weights = weights
|
||||
mixed = metrics_superposition(*metrics, weights=weights)
|
||||
super().__init__(mixed, name=mixed.__name__, dtype=dtype)
|
||||
def get_config(self):
|
||||
return dict(metrics=self._metrics,
|
||||
weights=self._weights,
|
||||
**super().get_config())
|
||||
|
||||
class ConfusionMatrix(Metric):
|
||||
def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ from matplotlib import pyplot as plt # for plot_confusion_matrix
|
|||
from .metrics import (
|
||||
soft_dice_loss,
|
||||
weighted_categorical_crossentropy,
|
||||
get as get_metric,
|
||||
metrics_superposition,
|
||||
ConfusionMatrix,
|
||||
)
|
||||
from .models import (
|
||||
|
|
@ -306,6 +308,7 @@ def config_params():
|
|||
if task in ["segmentation", "binarization"]:
|
||||
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
||||
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
|
||||
add_ncc_loss = 0 # Add regression loss for number of connected components. When non-zero, use this as weight for the nCC term.
|
||||
elif task == "classification":
|
||||
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
|
||||
classification_classes_name = None # Dictionary of classification classes names.
|
||||
|
|
@ -416,6 +419,7 @@ def run(_config,
|
|||
thetha=None,
|
||||
is_loss_soft_dice=False,
|
||||
weighted_loss=False,
|
||||
add_ncc_loss=None,
|
||||
## if continue_training
|
||||
index_start=0,
|
||||
dir_of_start_model=None,
|
||||
|
|
@ -605,7 +609,10 @@ def run(_config,
|
|||
elif weighted_loss:
|
||||
loss = weighted_categorical_crossentropy(weights)
|
||||
else:
|
||||
loss = 'categorical_crossentropy'
|
||||
loss = get_metric('categorical_crossentropy')
|
||||
if add_ncc_loss:
|
||||
loss = metrics_superposition(loss, num_connected_components_regression(0.1),
|
||||
weights=[1 - add_ncc_loss, add_ncc_loss])
|
||||
metrics.append(num_connected_components_regression(0.1))
|
||||
metrics.append(MeanIoU(n_classes,
|
||||
name='iou',
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue