mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
training: add metric for (same) number of connected components
(in trying to capture region instance separability)
This commit is contained in:
parent
18607e0f48
commit
abf111de76
2 changed files with 45 additions and 5 deletions
|
|
@ -14,6 +14,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
||||||
from tensorflow.keras.layers import StringLookup
|
from tensorflow.keras.layers import StringLookup
|
||||||
from tensorflow.keras.utils import image_dataset_from_directory
|
from tensorflow.keras.utils import image_dataset_from_directory
|
||||||
from tensorflow.keras.backend import one_hot
|
from tensorflow.keras.backend import one_hot
|
||||||
|
from tensorflow_addons.image import connected_components
|
||||||
from sacred import Experiment
|
from sacred import Experiment
|
||||||
from sacred.config import create_captured_function
|
from sacred.config import create_captured_function
|
||||||
|
|
||||||
|
|
@ -74,6 +75,42 @@ def configuration():
|
||||||
except:
|
except:
|
||||||
print("no GPU device available", file=sys.stderr)
|
print("no GPU device available", file=sys.stderr)
|
||||||
|
|
||||||
|
def num_connected_components_regression(alpha: float):
    """Build a metric/loss function capturing the separability of segmentation maps.

    For both sides (true and predicted, respectively), computes
    1. the argmax() of class-wise softmax input (i.e. the segmentation map),
    2. the connected components (i.e. the instance label map),
    3. the max() (i.e. the highest label = number of components).

    Then calculates a regression formula between those two targets:
    - overall mean squared (to incentivise an exact fit),
    - additive component (to incentivise more over fewer segments;
      this keeps neighbours from spilling into each other —
      oversegmentation is usually not as bad as undersegmentation).

    Arguments:
        alpha: weight of the sign-sensitive additive term. Must satisfy
            ``abs(alpha) <= 1``: for an integer component-count difference
            ``d``, the radicand ``d**2 + alpha*d = |d| * (|d| ± alpha)``
            stays non-negative only then; larger magnitudes would make
            ``sqrt()`` produce NaN for ``|d| == 1``.

    Returns:
        A Keras-compatible metric function (named ``'nCC'``) mapping
        ``(y_true, y_pred)`` of shape ``[B, H, W, C]`` to a scalar.

    Raises:
        ValueError: if ``abs(alpha) > 1`` (would yield NaN losses).
    """
    # Validate eagerly at factory time rather than producing NaNs mid-training.
    if abs(alpha) > 1:
        raise ValueError(
            "alpha must satisfy abs(alpha) <= 1 to keep the sqrt() "
            "argument non-negative, got %r" % alpha)

    def metric(y_true, y_pred):
        # [B, H, W, C] class scores -> per-pixel label maps [B, H, W]
        l_true = tf.math.argmax(y_true, axis=-1)
        l_pred = tf.math.argmax(y_pred, axis=-1)
        # [B, H, W] instance label maps
        c_true = connected_components(l_true)
        c_pred = connected_components(l_pred)
        # flatten spatial dims -> [B, H*W]
        n_batch = tf.shape(y_true)[0]
        c_true = tf.reshape(c_true, (n_batch, -1))
        c_pred = tf.reshape(c_pred, (n_batch, -1))
        # highest label per image = number of components -> [B]
        # NOTE(review): assumes connected_components assigns labels 1..N
        # per image (0 = background) — confirm against the
        # tfa.image.connected_components documentation
        n_true = tf.math.reduce_max(c_true, axis=1)
        n_pred = tf.math.reduce_max(c_pred, axis=1)
        # [B] signed difference; positive means the prediction has FEWER
        # components than ground truth (undersegmentation)
        diff = tf.cast(n_true - n_pred, tf.float32)
        # sqrt(d^2 + alpha*d): symmetric squared error plus a
        # sign-sensitive term that (for alpha > 0) penalises
        # undersegmentation more than oversegmentation
        return tf.reduce_mean(tf.math.sqrt(tf.math.square(diff) + alpha * diff), axis=-1)

    metric.__name__ = 'nCC'
    return metric
|
||||||
|
|
||||||
@tf.function
|
@tf.function
|
||||||
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
@ -502,7 +539,9 @@ def run(_config,
|
||||||
loss = 'mean_squared_error'
|
loss = 'mean_squared_error'
|
||||||
model.compile(loss=loss,
|
model.compile(loss=loss,
|
||||||
optimizer=Adam(learning_rate=learning_rate),
|
optimizer=Adam(learning_rate=learning_rate),
|
||||||
metrics=['accuracy', MeanIoU(n_classes,
|
metrics=['accuracy',
|
||||||
|
num_connected_components_regression(0.1),
|
||||||
|
MeanIoU(n_classes,
|
||||||
name='iou',
|
name='iou',
|
||||||
ignore_class=0,
|
ignore_class=0,
|
||||||
sparse_y_true=False,
|
sparse_y_true=False,
|
||||||
|
|
|
||||||
|
|
@ -4,3 +4,4 @@ numpy
|
||||||
tqdm
|
tqdm
|
||||||
imutils
|
imutils
|
||||||
scipy
|
scipy
|
||||||
|
tensorflow-addons # for connected_components
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue