def _component_congruence(pair_codes, n_true, n_pred, modulus):
    """
    congruence ratio of a single image's component pairing

    Takes the per-pixel pair codes (``modulus * predicted + true``
    component label) of one image, builds the co-occurrence matrix
    of predicted vs. true component labels and returns the sum of
    its singular values divided by the sum of its entries.
    """
    codes, _, counts = tf.unique_with_counts(tf.reshape(pair_codes, (-1,)))
    # [L] distinct (pred, true) label pairs and their pixel counts
    # decode the pairs again: row = predicted label, column = true label
    pairs = tf.stack([codes // modulus, codes % modulus], axis=1)
    cooc = tf.scatter_nd(pairs, counts, tf.stack([n_pred, n_true]))
    cooc = tf.cast(cooc, tf.float32)
    # [Cpred, Ctrue]
    singular = tf.linalg.svd(cooc, compute_uv=False)
    return tf.reduce_sum(singular) / tf.reduce_sum(cooc)


def connected_components_loss(artificial=0):
    """
    metric/loss function capturing the separability of segmentation maps

    For both sides (true and predicted, resp.), computes
    1. the argmax() of class-wise softmax input (i.e. the segmentation map)
    2. the connected components (i.e. the instance label map)
    3. the max() + 1 (i.e. the size of the label range)

    The original idea was to then calculate a regression formula
    between the two component counts. But it is insufficient to just
    approximate the same number of components, for they might be
    completely different (true components being merged, predicted
    components splitting others). We really want to capture the
    correspondence between those labels, which is localised.

    For that we calculate the (predicted, true) label pairs and their
    pixel counts. Looking at the resulting M,N incidence matrix, we want
    those counts to be distributed orthogonally (ideally). So we compute
    a singular value decomposition and compare the sum total of singular
    values to the sum total of all label counts. The ratio of the two
    is a measure of congruence; the function returns 1 minus its mean
    over the batch (so lower is better).

    Moreover, for the case of artificial boundary segments around
    regions, optionally introduced by the training extractor to
    represent segment identity in the loss (and removed at runtime):
    if ``artificial`` is non-zero, only the first ``artificial``
    channels take part in the argmax, i.e. the channels from that
    index on are ignored.

    Returns a function ``metric(y_true, y_pred)`` named ``nCC``
    with ``_direction`` set to ``'down'``.
    """
    def metric(y_true, y_pred):
        if artificial:
            # drop the artificial border class (and anything behind it),
            # so the argmax below is taken over the remaining classes only
            y_true = y_true[:, :, :, :artificial]
            y_pred = y_pred[:, :, :, :artificial]
        # [B, H, W, C]
        labels_true = tf.math.argmax(y_true, axis=-1)
        labels_pred = tf.math.argmax(y_pred, axis=-1)
        # NOTE(review): argmax and component labelling are piecewise
        # constant, so this term presumably contributes no gradient
        # when used inside a loss -- confirm.
        # [B, H, W]
        comps_true = tf.cast(connected_components(labels_true), tf.int64)
        comps_pred = tf.cast(connected_components(labels_pred), tf.int64)
        # [B, H, W]
        n_true = tf.math.reduce_max(comps_true, (1, 2)) + 1
        n_pred = tf.math.reduce_max(comps_pred, (1, 2)) + 1
        # [B]
        # both labels of a pixel get packed into one int64 code,
        # which only works while each label stays below the modulus
        modulus = tf.constant(2**22, tf.int64)
        tf.debugging.assert_less(n_true, modulus,
                                 message="cannot compare segments: too many connected components in GT")
        tf.debugging.assert_less(n_pred, modulus,
                                 message="cannot compare segments: too many connected components in prediction")
        pair_codes = modulus * comps_pred + comps_true
        tf.debugging.assert_greater_equal(pair_codes, tf.constant(0, tf.int64),
                                          message="overflow pairing components")
        # [B, H, W]
        # tf.unique does not support a batch dim, so handle each image
        # on its own; map over the batch instead of unstacking, because
        # the static batch size (y_true.shape[0]) may be None at trace time
        ratios = tf.map_fn(
            lambda sample: _component_congruence(*sample, modulus),
            (pair_codes, n_true, n_pred),
            fn_output_signature=tf.float32)
        # [B]
        return 1.0 - tf.reduce_mean(ratios, 0)

    metric.__name__ = 'nCC'
    metric._direction = 'down'
    return metric