training: improve nCC metric/loss - measure localized congruence…

- instead of just comparing the number of connected components,
  calculate the GT/pred label incidence matrix and retrieve the
  share of singular values (i.e. nearly diagonal under reordering)
  over total counts as similarity score
- also, suppress artificial class in that
This commit is contained in:
Robert Sachunsky 2026-02-28 19:44:10 +01:00
parent 7e06ab2c8c
commit 361d40c064

View file

@ -83,7 +83,7 @@ def configuration():
except: except:
print("no GPU device available", file=sys.stderr) print("no GPU device available", file=sys.stderr)
def num_connected_components_regression(alpha: float): def num_connected_components_regression(artificial=0):
""" """
metric/loss function capturing the separability of segmentation maps metric/loss function capturing the separability of segmentation maps
@ -92,29 +92,77 @@ def num_connected_components_regression(alpha: float):
2. the connected components (i.e. the instance label map) 2. the connected components (i.e. the instance label map)
3. the max() (i.e. the highest label = nr of components) 3. the max() (i.e. the highest label = nr of components)
Then calculates a regression formula between those two targets: The original idea was to then calculate a regression formula
- overall mean squared (to incentivise exact fit) between those two targets. But it is insufficient to just
- additive component (to incentivise more over less segments; approximate the same number of components, for they might be
this prevents neighbours of spilling into each other; completely different (true components being merged, predicted
oversegmentation is usually not as bad as undersegmentation) components splitting others). We really want to capture the
correspondence between those labels, which is localised.
For that we now calculate the label pairs and their counts.
Looking at the M,N incidence matrix, we want those counts
to be distributed orthogonally (ideally). So we compute a
singular value decomposition and compare the sum total of
singular values to the sum total of all label counts. The
rate of the two determines a measure of congruence.
Moreover, for the case of artificial boundary segments around
regions, optionally introduced by the training extractor to
represent segment identity in the loss (and removed at runtime):
Reduce this class to background as well.
""" """
def metric(y_true, y_pred): def metric(y_true, y_pred):
if artificial:
# convert artificial border class to background
y_true = y_true[:, :, :, :artificial]
y_pred = y_pred[:, :, :, :artificial]
# [B, H, W, C] # [B, H, W, C]
l_true = tf.math.argmax(y_true, axis=-1) l_true = tf.math.argmax(y_true, axis=-1)
l_pred = tf.math.argmax(y_pred, axis=-1) l_pred = tf.math.argmax(y_pred, axis=-1)
# [B, H, W] # [B, H, W]
c_true = connected_components(l_true) c_true = tf.cast(connected_components(l_true), tf.int64)
c_pred = connected_components(l_pred) c_pred = tf.cast(connected_components(l_pred), tf.int64)
# [B, H, W] # [B, H, W]
n_batch = tf.shape(y_true)[0] #n_batch = tf.shape(y_true)[0]
c_true = tf.reshape(c_true, (n_batch, -1)) n_batch = y_true.shape[0]
c_pred = tf.reshape(c_pred, (n_batch, -1)) C_true = tf.math.reduce_max(c_true, (1, 2)) + 1
# [B, H*W] C_pred = tf.math.reduce_max(c_pred, (1, 2)) + 1
n_true = tf.math.reduce_max(c_true, axis=1) MODULUS = tf.constant(2**22, tf.int64)
n_pred = tf.math.reduce_max(c_pred, axis=1) tf.debugging.assert_less(C_true, MODULUS,
# [B] message="cannot compare segments: too many connected components in GT")
diff = tf.cast(n_true - n_pred, tf.float32) tf.debugging.assert_less(C_pred, MODULUS,
return tf.reduce_mean(tf.math.sqrt(tf.math.square(diff) + alpha * diff), axis=-1) message="cannot compare segments: too many connected components in prediction")
c_comb = MODULUS * c_pred + c_true
tf.debugging.assert_greater_equal(c_comb, tf.constant(0, tf.int64),
message="overflow pairing components")
# [B, H, W]
# tf.unique does not support batch dim, so...
results = []
for c_comb, C_true, C_pred in zip(
tf.unstack(c_comb, num=n_batch),
tf.unstack(C_true, num=n_batch),
tf.unstack(C_pred, num=n_batch),
):
prod, _, count = tf.unique_with_counts(tf.reshape(c_comb, (-1,)))
#tf.print(n_batch, tf.shape(prod), C_true, C_true)
# [L]
#corr = tf.zeros([C_pred, C_true], tf.int32)
#corr[prod // 2**24, prod % 2**24] = count
corr = tf.scatter_nd(tf.stack([prod // MODULUS, prod % MODULUS], axis=1),
count, (C_pred, C_true))
corr = tf.cast(corr, tf.float32)
# [Cpred, Ctrue]
sgv = tf.linalg.svd(corr, compute_uv=False)
results.append(tf.reduce_sum(sgv) / tf.reduce_sum(corr))
return 1.0 - tf.reduce_mean(tf.stack(results), 0)
# c_true = tf.reshape(c_true, (n_batch, -1))
# c_pred = tf.reshape(c_pred, (n_batch, -1))
# # [B, H*W]
# n_true = tf.math.reduce_max(c_true, axis=1)
# n_pred = tf.math.reduce_max(c_pred, axis=1)
# # [B]
# diff = tf.cast(n_true - n_pred, tf.float32)
# return tf.reduce_mean(tf.math.abs(diff) + alpha * diff, axis=-1)
metric.__name__ = 'nCC' metric.__name__ = 'nCC'
return metric return metric
@ -611,9 +659,9 @@ def run(_config,
else: else:
loss = get_metric('categorical_crossentropy') loss = get_metric('categorical_crossentropy')
if add_ncc_loss: if add_ncc_loss:
loss = metrics_superposition(loss, num_connected_components_regression(0.1), loss = metrics_superposition(loss, num_connected_components_regression(n_classes - 1),
weights=[1 - add_ncc_loss, add_ncc_loss]) weights=[1 - add_ncc_loss, add_ncc_loss])
metrics.append(num_connected_components_regression(0.1)) metrics.append(num_connected_components_regression(n_classes - 1))
metrics.append(MeanIoU(n_classes, metrics.append(MeanIoU(n_classes,
name='iou', name='iou',
ignore_class=0, ignore_class=0,