training: move nCC metric/loss to .metrics and rename…

- `num_connected_components_regression` → `connected_components_loss`
- move from training.train to training.metrics
This commit is contained in:
Robert Sachunsky 2026-02-28 20:01:49 +01:00
parent 361d40c064
commit e47653f684
2 changed files with 88 additions and 87 deletions

View file

@@ -5,6 +5,7 @@ import tensorflow as tf
from tensorflow.keras import backend as K from tensorflow.keras import backend as K
from tensorflow.keras.metrics import Metric, MeanMetricWrapper, get from tensorflow.keras.metrics import Metric, MeanMetricWrapper, get
from tensorflow.keras.initializers import Zeros from tensorflow.keras.initializers import Zeros
from tensorflow_addons.image import connected_components
import numpy as np import numpy as np
@@ -439,3 +440,87 @@ class ConfusionMatrix(Metric):
def get_config(self): def get_config(self):
return dict(nlabels=self._nlabels, return dict(nlabels=self._nlabels,
**super().get_config()) **super().get_config())
def connected_components_loss(artificial=0):
    """
    Build a metric/loss function capturing the separability of segmentation maps.

    For both sides (true and predicted, resp.), computes
    1. the argmax() of class-wise softmax input (i.e. the segmentation map)
    2. the connected components (i.e. the instance label map)
    3. the counts of co-occurring (true, predicted) component label pairs.

    The original idea was to regress the mere number of components on both
    sides against each other. But it is insufficient to just approximate
    the same number of components, for they might be completely different
    (true components being merged, predicted components splitting others).
    We really want to capture the correspondence between those labels,
    which is localised. Looking at the M,N incidence matrix of label pair
    counts, we want those counts to be distributed orthogonally (ideally).
    So we compute a singular value decomposition and compare the sum total
    of singular values to the sum total of all label counts. The ratio of
    the two determines a measure of congruence.

    Moreover, for the case of artificial boundary segments around regions,
    optionally introduced by the training extractor to represent segment
    identity in the loss (and removed at runtime): reduce this class to
    background as well.

    Args:
        artificial: if non-zero, index of the artificial border class;
            channels from that index on are dropped before the argmax
            (affected pixels then fall to the best remaining class).

    Returns:
        a Keras-compatible metric function named ``nCC`` mapping
        (y_true, y_pred) of shape [B, H, W, C] to a scalar,
        where 0.0 means perfectly congruent components (lower is better).
    """
    def metric(y_true, y_pred):
        if artificial:
            # convert artificial border class to background:
            # drop its channel(s) so argmax picks among the real classes
            y_true = y_true[:, :, :, :artificial]
            y_pred = y_pred[:, :, :, :artificial]
        # [B, H, W, C] -> segmentation maps [B, H, W]
        l_true = tf.math.argmax(y_true, axis=-1)
        l_pred = tf.math.argmax(y_pred, axis=-1)
        # instance label maps [B, H, W] (0 = background)
        c_true = tf.cast(connected_components(l_true), tf.int64)
        c_pred = tf.cast(connected_components(l_pred), tf.int64)
        # NOTE(review): requires a statically known batch size
        # (tf.unstack below needs num; confirm dataset is drop_remainder)
        n_batch = y_true.shape[0]
        # number of labels per batch item, including background
        C_true = tf.math.reduce_max(c_true, (1, 2)) + 1
        C_pred = tf.math.reduce_max(c_pred, (1, 2)) + 1
        # pair both label maps into one int64 without collisions:
        # 2**22 * 2**22 = 2**44 still fits int64
        MODULUS = tf.constant(2**22, tf.int64)
        tf.debugging.assert_less(C_true, MODULUS,
                                 message="cannot compare segments: too many connected components in GT")
        tf.debugging.assert_less(C_pred, MODULUS,
                                 message="cannot compare segments: too many connected components in prediction")
        c_comb = MODULUS * c_pred + c_true
        tf.debugging.assert_greater_equal(c_comb, tf.constant(0, tf.int64),
                                          message="overflow pairing components")
        # [B, H, W]
        # tf.unique does not support batch dim, so iterate batch items
        results = []
        for comb, n_true, n_pred in zip(
                tf.unstack(c_comb, num=n_batch),
                tf.unstack(C_true, num=n_batch),
                tf.unstack(C_pred, num=n_batch),
        ):
            pair, _, count = tf.unique_with_counts(tf.reshape(comb, (-1,)))
            # [L] unique label pairs with their pixel counts;
            # scatter into the dense incidence matrix, i.e.
            # corr[pair // MODULUS, pair % MODULUS] = count
            corr = tf.scatter_nd(tf.stack([pair // MODULUS, pair % MODULUS], axis=1),
                                 count, (n_pred, n_true))
            corr = tf.cast(corr, tf.float32)
            # [Cpred, Ctrue]
            sgv = tf.linalg.svd(corr, compute_uv=False)
            # nuclear norm over total mass: 1.0 iff rows/cols are orthogonal
            results.append(tf.reduce_sum(sgv) / tf.reduce_sum(corr))
        return 1.0 - tf.reduce_mean(tf.stack(results), 0)
    metric.__name__ = 'nCC'
    metric._direction = 'down'
    return metric

View file

@@ -16,7 +16,6 @@ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStoppi
from tensorflow.keras.layers import StringLookup from tensorflow.keras.layers import StringLookup
from tensorflow.keras.utils import image_dataset_from_directory from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.backend import one_hot from tensorflow.keras.backend import one_hot
from tensorflow_addons.image import connected_components
from sacred import Experiment from sacred import Experiment
from sacred.config import create_captured_function from sacred.config import create_captured_function
@@ -30,6 +29,7 @@ from .metrics import (
get as get_metric, get as get_metric,
metrics_superposition, metrics_superposition,
ConfusionMatrix, ConfusionMatrix,
connected_components_loss,
) )
from .models import ( from .models import (
PatchEncoder, PatchEncoder,
@@ -83,90 +83,6 @@ def configuration():
except: except:
print("no GPU device available", file=sys.stderr) print("no GPU device available", file=sys.stderr)
def num_connected_components_regression(artificial=0):
    """
    Build a metric/loss function capturing the separability of segmentation maps.

    For both sides (true and predicted, resp.), computes
    1. the argmax() of class-wise softmax input (i.e. the segmentation map)
    2. the connected components (i.e. the instance label map)
    3. the counts of co-occurring (true, predicted) component label pairs.

    The original idea was to regress the mere number of components on both
    sides against each other. But it is insufficient to just approximate
    the same number of components, for they might be completely different
    (true components being merged, predicted components splitting others).
    We really want to capture the correspondence between those labels,
    which is localised. Looking at the M,N incidence matrix of label pair
    counts, we want those counts to be distributed orthogonally (ideally).
    So we compute a singular value decomposition and compare the sum total
    of singular values to the sum total of all label counts. The ratio of
    the two determines a measure of congruence.

    Moreover, for the case of artificial boundary segments around regions,
    optionally introduced by the training extractor to represent segment
    identity in the loss (and removed at runtime): reduce this class to
    background as well.

    Args:
        artificial: if non-zero, index of the artificial border class;
            channels from that index on are dropped before the argmax
            (affected pixels then fall to the best remaining class).

    Returns:
        a Keras-compatible metric function named ``nCC`` mapping
        (y_true, y_pred) of shape [B, H, W, C] to a scalar,
        where 0.0 means perfectly congruent components (lower is better).
    """
    def metric(y_true, y_pred):
        if artificial:
            # convert artificial border class to background:
            # drop its channel(s) so argmax picks among the real classes
            y_true = y_true[:, :, :, :artificial]
            y_pred = y_pred[:, :, :, :artificial]
        # [B, H, W, C] -> segmentation maps [B, H, W]
        l_true = tf.math.argmax(y_true, axis=-1)
        l_pred = tf.math.argmax(y_pred, axis=-1)
        # instance label maps [B, H, W] (0 = background)
        c_true = tf.cast(connected_components(l_true), tf.int64)
        c_pred = tf.cast(connected_components(l_pred), tf.int64)
        # NOTE(review): requires a statically known batch size
        # (tf.unstack below needs num; confirm dataset is drop_remainder)
        n_batch = y_true.shape[0]
        # number of labels per batch item, including background
        C_true = tf.math.reduce_max(c_true, (1, 2)) + 1
        C_pred = tf.math.reduce_max(c_pred, (1, 2)) + 1
        # pair both label maps into one int64 without collisions:
        # 2**22 * 2**22 = 2**44 still fits int64
        MODULUS = tf.constant(2**22, tf.int64)
        tf.debugging.assert_less(C_true, MODULUS,
                                 message="cannot compare segments: too many connected components in GT")
        tf.debugging.assert_less(C_pred, MODULUS,
                                 message="cannot compare segments: too many connected components in prediction")
        c_comb = MODULUS * c_pred + c_true
        tf.debugging.assert_greater_equal(c_comb, tf.constant(0, tf.int64),
                                          message="overflow pairing components")
        # [B, H, W]
        # tf.unique does not support batch dim, so iterate batch items
        results = []
        for comb, n_true, n_pred in zip(
                tf.unstack(c_comb, num=n_batch),
                tf.unstack(C_true, num=n_batch),
                tf.unstack(C_pred, num=n_batch),
        ):
            pair, _, count = tf.unique_with_counts(tf.reshape(comb, (-1,)))
            # [L] unique label pairs with their pixel counts;
            # scatter into the dense incidence matrix, i.e.
            # corr[pair // MODULUS, pair % MODULUS] = count
            corr = tf.scatter_nd(tf.stack([pair // MODULUS, pair % MODULUS], axis=1),
                                 count, (n_pred, n_true))
            corr = tf.cast(corr, tf.float32)
            # [Cpred, Ctrue]
            sgv = tf.linalg.svd(corr, compute_uv=False)
            # nuclear norm over total mass: 1.0 iff rows/cols are orthogonal
            results.append(tf.reduce_sum(sgv) / tf.reduce_sum(corr))
        return 1.0 - tf.reduce_mean(tf.stack(results), 0)
    metric.__name__ = 'nCC'
    return metric
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor: def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
""" """
Implements training.inference.SBBPredict.visualize_model_output for TF Implements training.inference.SBBPredict.visualize_model_output for TF
@@ -659,9 +575,9 @@ def run(_config,
else: else:
loss = get_metric('categorical_crossentropy') loss = get_metric('categorical_crossentropy')
if add_ncc_loss: if add_ncc_loss:
loss = metrics_superposition(loss, num_connected_components_regression(n_classes - 1), loss = metrics_superposition(loss, connected_components_loss(n_classes - 1),
weights=[1 - add_ncc_loss, add_ncc_loss]) weights=[1 - add_ncc_loss, add_ncc_loss])
metrics.append(num_connected_components_regression(n_classes - 1)) metrics.append(connected_components_loss(n_classes - 1))
metrics.append(MeanIoU(n_classes, metrics.append(MeanIoU(n_classes,
name='iou', name='iou',
ignore_class=0, ignore_class=0,