From 361d40c064d4201a3ecefccab00cf08ee95e1013 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Sat, 28 Feb 2026 19:44:10 +0100 Subject: [PATCH] =?UTF-8?q?training:=20improve=20nCC=20metric/loss=20-=20m?= =?UTF-8?q?easure=20localized=20congruence=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - instead of just comparing the number of connected components, calculate the GT/pred label incidence matrix and retrieve the share of singular values (i.e. nearly diagonal under reordering) over total counts as similarity score - also, suppress artificial class in that --- src/eynollah/training/train.py | 86 ++++++++++++++++++++++++++-------- 1 file changed, 67 insertions(+), 19 deletions(-) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index a12b9c7..efaa96e 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -83,7 +83,7 @@ def configuration(): except: print("no GPU device available", file=sys.stderr) -def num_connected_components_regression(alpha: float): +def num_connected_components_regression(artificial=0): """ metric/loss function capturing the separability of segmentation maps @@ -92,29 +92,77 @@ def num_connected_components_regression(alpha: float): 2. the connected components (i.e. the instance label map) 3. the max() (i.e. the highest label = nr of components) - Then calculates a regression formula between those two targets: - - overall mean squared (to incentivise exact fit) - - additive component (to incentivise more over less segments; - this prevents neighbours of spilling into each other; - oversegmentation is usually not as bad as undersegmentation) + The original idea was to then calculate a regression formula + between those two targets. But it is insufficient to just + approximate the same number of components, for they might be + completely different (true components being merged, predicted + components splitting others). We really want to capture the + correspondence between those labels, which is localised. + + For that we now calculate the label pairs and their counts. + Looking at the M,N incidence matrix, we want those counts + to be distributed orthogonally (ideally). So we compute a + singular value decomposition and compare the sum total of + singular values to the sum total of all label counts. The + rate of the two determines a measure of congruence. + + Moreover, for the case of artificial boundary segments around + regions, optionally introduced by the training extractor to + represent segment identity in the loss (and removed at runtime): + Reduce this class to background as well. """ def metric(y_true, y_pred): + if artificial: + # convert artificial border class to background + y_true = y_true[:, :, :, :artificial] + y_pred = y_pred[:, :, :, :artificial] # [B, H, W, C] l_true = tf.math.argmax(y_true, axis=-1) l_pred = tf.math.argmax(y_pred, axis=-1) # [B, H, W] - c_true = connected_components(l_true) - c_pred = connected_components(l_pred) + c_true = tf.cast(connected_components(l_true), tf.int64) + c_pred = tf.cast(connected_components(l_pred), tf.int64) # [B, H, W] - n_batch = tf.shape(y_true)[0] - c_true = tf.reshape(c_true, (n_batch, -1)) - c_pred = tf.reshape(c_pred, (n_batch, -1)) - # [B, H*W] - n_true = tf.math.reduce_max(c_true, axis=1) - n_pred = tf.math.reduce_max(c_pred, axis=1) - # [B] - diff = tf.cast(n_true - n_pred, tf.float32) - return tf.reduce_mean(tf.math.sqrt(tf.math.square(diff) + alpha * diff), axis=-1) + #n_batch = tf.shape(y_true)[0] + n_batch = y_true.shape[0] + C_true = tf.math.reduce_max(c_true, (1, 2)) + 1 + C_pred = tf.math.reduce_max(c_pred, (1, 2)) + 1 + MODULUS = tf.constant(2**22, tf.int64) + tf.debugging.assert_less(C_true, MODULUS, + message="cannot compare segments: too many connected components in GT") + tf.debugging.assert_less(C_pred, MODULUS, + message="cannot compare segments: too many connected components in prediction") + c_comb = MODULUS * c_pred + c_true + tf.debugging.assert_greater_equal(c_comb, tf.constant(0, tf.int64), + message="overflow pairing components") + # [B, H, W] + # tf.unique does not support batch dim, so... + results = [] + for c_comb, C_true, C_pred in zip( + tf.unstack(c_comb, num=n_batch), + tf.unstack(C_true, num=n_batch), + tf.unstack(C_pred, num=n_batch), + ): + prod, _, count = tf.unique_with_counts(tf.reshape(c_comb, (-1,))) + #tf.print(n_batch, tf.shape(prod), C_true, C_true) + # [L] + #corr = tf.zeros([C_pred, C_true], tf.int32) + #corr[prod // 2**24, prod % 2**24] = count + corr = tf.scatter_nd(tf.stack([prod // MODULUS, prod % MODULUS], axis=1), + count, (C_pred, C_true)) + corr = tf.cast(corr, tf.float32) + # [Cpred, Ctrue] + sgv = tf.linalg.svd(corr, compute_uv=False) + results.append(tf.reduce_sum(sgv) / tf.reduce_sum(corr)) + return 1.0 - tf.reduce_mean(tf.stack(results), 0) + # c_true = tf.reshape(c_true, (n_batch, -1)) + # c_pred = tf.reshape(c_pred, (n_batch, -1)) + # # [B, H*W] + # n_true = tf.math.reduce_max(c_true, axis=1) + # n_pred = tf.math.reduce_max(c_pred, axis=1) + # # [B] + # diff = tf.cast(n_true - n_pred, tf.float32) + # return tf.reduce_mean(tf.math.abs(diff) + alpha * diff, axis=-1) metric.__name__ = 'nCC' return metric @@ -611,9 +659,9 @@ def run(_config, else: loss = get_metric('categorical_crossentropy') if add_ncc_loss: - loss = metrics_superposition(loss, num_connected_components_regression(0.1), + loss = metrics_superposition(loss, num_connected_components_regression(n_classes - 1), weights=[1 - add_ncc_loss, add_ncc_loss]) - metrics.append(num_connected_components_regression(0.1)) + metrics.append(num_connected_components_regression(n_classes - 1)) metrics.append(MeanIoU(n_classes, name='iou', ignore_class=0,