training: improve nCC metric/loss - measure localized congruence…

- instead of just comparing the number of connected components,
  calculate the GT/pred label incidence matrix (pixel counts per
  pair of GT and predicted component) and use the ratio of the sum
  of its singular values to the sum of all its counts as similarity
  score (close to 1 when the matrix is nearly diagonal under reordering)
- also, suppress the artificial boundary class in that calculation
  (reduce it to background before extracting connected components)
This commit is contained in:
Robert Sachunsky 2026-02-28 19:44:10 +01:00
parent 7e06ab2c8c
commit 361d40c064

View file

@ -83,7 +83,7 @@ def configuration():
except:
print("no GPU device available", file=sys.stderr)
def num_connected_components_regression(alpha: float):
def num_connected_components_regression(artificial=0):
"""
metric/loss function capturing the separability of segmentation maps
@ -92,29 +92,77 @@ def num_connected_components_regression(alpha: float):
2. the connected components (i.e. the instance label map)
3. the max() (i.e. the highest label = nr of components)
Then calculates a regression formula between those two targets:
- overall mean squared (to incentivise exact fit)
- additive component (to incentivise more over less segments;
this prevents neighbours of spilling into each other;
oversegmentation is usually not as bad as undersegmentation)
The original idea was to then calculate a regression formula
between those two targets. But it is insufficient to just
approximate the same number of components, for they might be
completely different (true components being merged, predicted
components splitting others). We really want to capture the
correspondence between those labels, which is localised.
For that we now calculate the label pairs and their counts.
Looking at the M,N incidence matrix, we want those counts
to be distributed orthogonally (ideally). So we compute a
singular value decomposition and compare the sum total of
singular values to the sum total of all label counts. The
rate of the two determines a measure of congruence.
Moreover, for the case of artificial boundary segments around
regions, optionally introduced by the training extractor to
represent segment identity in the loss (and removed at runtime):
Reduce this class to background as well.
"""
def metric(y_true, y_pred):
if artificial:
# convert artificial border class to background
y_true = y_true[:, :, :, :artificial]
y_pred = y_pred[:, :, :, :artificial]
# [B, H, W, C]
l_true = tf.math.argmax(y_true, axis=-1)
l_pred = tf.math.argmax(y_pred, axis=-1)
# [B, H, W]
c_true = connected_components(l_true)
c_pred = connected_components(l_pred)
c_true = tf.cast(connected_components(l_true), tf.int64)
c_pred = tf.cast(connected_components(l_pred), tf.int64)
# [B, H, W]
n_batch = tf.shape(y_true)[0]
c_true = tf.reshape(c_true, (n_batch, -1))
c_pred = tf.reshape(c_pred, (n_batch, -1))
# [B, H*W]
n_true = tf.math.reduce_max(c_true, axis=1)
n_pred = tf.math.reduce_max(c_pred, axis=1)
# [B]
diff = tf.cast(n_true - n_pred, tf.float32)
return tf.reduce_mean(tf.math.sqrt(tf.math.square(diff) + alpha * diff), axis=-1)
#n_batch = tf.shape(y_true)[0]
n_batch = y_true.shape[0]
C_true = tf.math.reduce_max(c_true, (1, 2)) + 1
C_pred = tf.math.reduce_max(c_pred, (1, 2)) + 1
MODULUS = tf.constant(2**22, tf.int64)
tf.debugging.assert_less(C_true, MODULUS,
message="cannot compare segments: too many connected components in GT")
tf.debugging.assert_less(C_pred, MODULUS,
message="cannot compare segments: too many connected components in prediction")
c_comb = MODULUS * c_pred + c_true
tf.debugging.assert_greater_equal(c_comb, tf.constant(0, tf.int64),
message="overflow pairing components")
# [B, H, W]
# tf.unique does not support batch dim, so...
results = []
for c_comb, C_true, C_pred in zip(
tf.unstack(c_comb, num=n_batch),
tf.unstack(C_true, num=n_batch),
tf.unstack(C_pred, num=n_batch),
):
prod, _, count = tf.unique_with_counts(tf.reshape(c_comb, (-1,)))
#tf.print(n_batch, tf.shape(prod), C_true, C_true)
# [L]
#corr = tf.zeros([C_pred, C_true], tf.int32)
#corr[prod // 2**24, prod % 2**24] = count
corr = tf.scatter_nd(tf.stack([prod // MODULUS, prod % MODULUS], axis=1),
count, (C_pred, C_true))
corr = tf.cast(corr, tf.float32)
# [Cpred, Ctrue]
sgv = tf.linalg.svd(corr, compute_uv=False)
results.append(tf.reduce_sum(sgv) / tf.reduce_sum(corr))
return 1.0 - tf.reduce_mean(tf.stack(results), 0)
# c_true = tf.reshape(c_true, (n_batch, -1))
# c_pred = tf.reshape(c_pred, (n_batch, -1))
# # [B, H*W]
# n_true = tf.math.reduce_max(c_true, axis=1)
# n_pred = tf.math.reduce_max(c_pred, axis=1)
# # [B]
# diff = tf.cast(n_true - n_pred, tf.float32)
# return tf.reduce_mean(tf.math.abs(diff) + alpha * diff, axis=-1)
metric.__name__ = 'nCC'
return metric
@ -611,9 +659,9 @@ def run(_config,
else:
loss = get_metric('categorical_crossentropy')
if add_ncc_loss:
loss = metrics_superposition(loss, num_connected_components_regression(0.1),
loss = metrics_superposition(loss, num_connected_components_regression(n_classes - 1),
weights=[1 - add_ncc_loss, add_ncc_loss])
metrics.append(num_connected_components_regression(0.1))
metrics.append(num_connected_components_regression(n_classes - 1))
metrics.append(MeanIoU(n_classes,
name='iou',
ignore_class=0,