mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: move nCC metric/loss to .metrics and rename…
- rename `num_connected_components_regression` → `connected_components_loss`
- move it from training.train to training.metrics
This commit is contained in:
parent
361d40c064
commit
e47653f684
2 changed files with 88 additions and 87 deletions
|
|
@ -5,6 +5,7 @@ import tensorflow as tf
|
||||||
from tensorflow.keras import backend as K
|
from tensorflow.keras import backend as K
|
||||||
from tensorflow.keras.metrics import Metric, MeanMetricWrapper, get
|
from tensorflow.keras.metrics import Metric, MeanMetricWrapper, get
|
||||||
from tensorflow.keras.initializers import Zeros
|
from tensorflow.keras.initializers import Zeros
|
||||||
|
from tensorflow_addons.image import connected_components
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -439,3 +440,87 @@ class ConfusionMatrix(Metric):
|
||||||
def get_config(self):
    """Return the parent class's config extended by this metric's ``nlabels``."""
    nlabels = self._nlabels
    parent_config = super().get_config()
    # keyword expansion (rather than a dict merge) keeps the original
    # behaviour of raising if the parent config already has ``nlabels``
    return dict(nlabels=nlabels, **parent_config)
|
||||||
|
|
||||||
|
def connected_components_loss(artificial=0):
    """
    Create a metric/loss function capturing the separability of segmentation maps.

    The returned function takes ``y_true`` and ``y_pred`` (class-wise scores
    in [B, H, W, C] layout) and, for each side, derives

    1. the argmax() over the class axis (i.e. the segmentation map),
    2. its connected components (i.e. the instance label map),
    3. the highest component label plus one (i.e. the number of labels).

    Only comparing the number of components on both sides would be
    insufficient: equal counts can hide true components being merged
    while predicted components split others. What matters is the
    localised correspondence between the labels.

    So every pixel is assigned the pair (predicted label, true label),
    and the pixel count of each distinct pair is entered into an
    incidence matrix of shape [C_pred, C_true]. Ideally these counts
    are distributed orthogonally. This is measured by the ratio of the
    sum of the matrix's singular values to the sum of all its counts.
    The function returns 1 minus the batch mean of that ratio, so
    lower is better.

    If ``artificial`` is non-zero, both inputs are cut down to their
    first ``artificial`` class channels before the argmax. This is meant
    for the artificial boundary class around regions, which the training
    extractor may introduce to represent segment identity in the loss
    (and which is removed at runtime).

    (An earlier variant regressed on the difference between the two
    component counts; it has been superseded by the above.)
    """
    def metric(y_true, y_pred):
        if artificial:
            # cut off the artificial border class channel(s)
            y_true = y_true[:, :, :, :artificial]
            y_pred = y_pred[:, :, :, :artificial]

        def instance_labels(scores):
            # [B, H, W, C] -> [B, H, W]
            class_map = tf.math.argmax(scores, axis=-1)
            return tf.cast(connected_components(class_map), tf.int64)

        labels_true = instance_labels(y_true)
        labels_pred = instance_labels(y_pred)
        # read from the static shape (needed to unstack in Python below);
        # presumably this requires a fixed batch size -- TODO confirm
        batch_size = y_true.shape[0]
        # [B]: highest label + 1 per image
        num_true = tf.math.reduce_max(labels_true, (1, 2)) + 1
        num_pred = tf.math.reduce_max(labels_pred, (1, 2)) + 1

        # each (pred, true) label pair is packed into a single int64 code:
        # code = modulus * pred + true, hence both must stay below modulus
        modulus = tf.constant(2**22, tf.int64)
        tf.debugging.assert_less(
            num_true, modulus,
            message="cannot compare segments: too many connected components in GT")
        tf.debugging.assert_less(
            num_pred, modulus,
            message="cannot compare segments: too many connected components in prediction")
        pair_codes = modulus * labels_pred + labels_true
        tf.debugging.assert_greater_equal(
            pair_codes, tf.constant(0, tf.int64),
            message="overflow pairing components")
        # [B, H, W]

        def congruence(codes, n_true, n_pred):
            # distinct pair codes and the number of pixels carrying each
            unique_codes, _, pixel_counts = tf.unique_with_counts(
                tf.reshape(codes, (-1,)))
            # unpack the codes into (pred, true) matrix coordinates
            cells = tf.stack([unique_codes // modulus,
                              unique_codes % modulus], axis=1)
            # [n_pred, n_true]
            incidence = tf.cast(
                tf.scatter_nd(cells, pixel_counts, (n_pred, n_true)),
                tf.float32)
            singular_values = tf.linalg.svd(incidence, compute_uv=False)
            return tf.reduce_sum(singular_values) / tf.reduce_sum(incidence)

        # tf.unique does not support a batch dimension, so go image by image
        per_image = [
            congruence(codes, n_true, n_pred)
            for codes, n_true, n_pred in zip(
                tf.unstack(pair_codes, num=batch_size),
                tf.unstack(num_true, num=batch_size),
                tf.unstack(num_pred, num=batch_size),
            )
        ]
        return 1.0 - tf.reduce_mean(tf.stack(per_image), 0)

    metric.__name__ = 'nCC'
    # NOTE(review): presumably tells consumers that lower values are
    # better; the reader of this attribute is not visible here
    metric._direction = 'down'
    return metric
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStoppi
|
||||||
from tensorflow.keras.layers import StringLookup
|
from tensorflow.keras.layers import StringLookup
|
||||||
from tensorflow.keras.utils import image_dataset_from_directory
|
from tensorflow.keras.utils import image_dataset_from_directory
|
||||||
from tensorflow.keras.backend import one_hot
|
from tensorflow.keras.backend import one_hot
|
||||||
from tensorflow_addons.image import connected_components
|
|
||||||
from sacred import Experiment
|
from sacred import Experiment
|
||||||
from sacred.config import create_captured_function
|
from sacred.config import create_captured_function
|
||||||
|
|
||||||
|
|
@ -30,6 +29,7 @@ from .metrics import (
|
||||||
get as get_metric,
|
get as get_metric,
|
||||||
metrics_superposition,
|
metrics_superposition,
|
||||||
ConfusionMatrix,
|
ConfusionMatrix,
|
||||||
|
connected_components_loss,
|
||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
PatchEncoder,
|
PatchEncoder,
|
||||||
|
|
@ -83,90 +83,6 @@ def configuration():
|
||||||
except:
|
except:
|
||||||
print("no GPU device available", file=sys.stderr)
|
print("no GPU device available", file=sys.stderr)
|
||||||
|
|
||||||
def num_connected_components_regression(artificial=0):
    """
    Create a metric/loss function capturing the separability of segmentation maps.

    For both sides (true and predicted, resp.), the returned function
    computes from class-wise scores in [B, H, W, C] layout

    1. the argmax() over the class axis (i.e. the segmentation map),
    2. the connected components (i.e. the instance label map),
    3. the highest component label plus one (i.e. the number of labels).

    Despite the name, no regression between the two component counts
    is calculated anymore: equal counts may still stem from completely
    different components (true ones merged, predicted ones splitting
    others). Instead, the localised correspondence of labels is
    captured. The pixel counts of all (predicted, true) label pairs
    fill an incidence matrix of shape [C_pred, C_true], and the sum of
    its singular values is divided by the sum of its counts. The
    result is 1 minus the batch mean of that ratio.

    If ``artificial`` is non-zero, both inputs get reduced to their
    first ``artificial`` class channels before the argmax. This is meant
    for the artificial boundary class around regions, optionally
    introduced by the training extractor to represent segment identity
    in the loss (and removed at runtime).
    """
    def metric(y_true, y_pred):
        if artificial:
            # cut off the artificial border class channel(s)
            y_true = y_true[:, :, :, :artificial]
            y_pred = y_pred[:, :, :, :artificial]
        # [B, H, W, C] -> [B, H, W]
        seg_true = tf.math.argmax(y_true, axis=-1)
        seg_pred = tf.math.argmax(y_pred, axis=-1)
        inst_true = tf.cast(connected_components(seg_true), tf.int64)
        inst_pred = tf.cast(connected_components(seg_pred), tf.int64)
        # taken from the static shape, as the Python loop below needs it;
        # presumably this requires a fixed batch size -- TODO confirm
        batch_size = y_true.shape[0]
        # [B]: highest label + 1 per image
        count_true = tf.math.reduce_max(inst_true, (1, 2)) + 1
        count_pred = tf.math.reduce_max(inst_pred, (1, 2)) + 1
        # pairs get packed as base * pred + true, so both need to be < base
        base = tf.constant(2**22, tf.int64)
        tf.debugging.assert_less(
            count_true, base,
            message="cannot compare segments: too many connected components in GT")
        tf.debugging.assert_less(
            count_pred, base,
            message="cannot compare segments: too many connected components in prediction")
        packed = base * inst_pred + inst_true
        tf.debugging.assert_greater_equal(
            packed, tf.constant(0, tf.int64),
            message="overflow pairing components")
        # [B, H, W]
        # tf.unique does not support a batch dimension, so iterate images
        samples = zip(
            tf.unstack(packed, num=batch_size),
            tf.unstack(count_true, num=batch_size),
            tf.unstack(count_pred, num=batch_size),
        )
        ratios = []
        for sample_packed, sample_true, sample_pred in samples:
            flat = tf.reshape(sample_packed, (-1,))
            # distinct pairs and their pixel counts
            pairs, _, npixels = tf.unique_with_counts(flat)
            rows = pairs // base
            cols = pairs % base
            # [sample_pred, sample_true]
            matrix = tf.scatter_nd(tf.stack([rows, cols], axis=1),
                                   npixels, (sample_pred, sample_true))
            matrix = tf.cast(matrix, tf.float32)
            spectrum = tf.linalg.svd(matrix, compute_uv=False)
            ratios.append(tf.reduce_sum(spectrum) / tf.reduce_sum(matrix))
        return 1.0 - tf.reduce_mean(tf.stack(ratios), 0)

    metric.__name__ = 'nCC'
    return metric
|
|
||||||
|
|
||||||
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Implements training.inference.SBBPredict.visualize_model_output for TF
|
Implements training.inference.SBBPredict.visualize_model_output for TF
|
||||||
|
|
@ -659,9 +575,9 @@ def run(_config,
|
||||||
else:
|
else:
|
||||||
loss = get_metric('categorical_crossentropy')
|
loss = get_metric('categorical_crossentropy')
|
||||||
if add_ncc_loss:
|
if add_ncc_loss:
|
||||||
loss = metrics_superposition(loss, num_connected_components_regression(n_classes - 1),
|
loss = metrics_superposition(loss, connected_components_loss(n_classes - 1),
|
||||||
weights=[1 - add_ncc_loss, add_ncc_loss])
|
weights=[1 - add_ncc_loss, add_ncc_loss])
|
||||||
metrics.append(num_connected_components_regression(n_classes - 1))
|
metrics.append(connected_components_loss(n_classes - 1))
|
||||||
metrics.append(MeanIoU(n_classes,
|
metrics.append(MeanIoU(n_classes,
|
||||||
name='iou',
|
name='iou',
|
||||||
ignore_class=0,
|
ignore_class=0,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue