diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 5305ee3..233c6a4 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -14,6 +14,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
 from tensorflow.keras.layers import StringLookup
 from tensorflow.keras.utils import image_dataset_from_directory
 from tensorflow.keras.backend import one_hot
+from tensorflow_addons.image import connected_components
 from sacred import Experiment
 from sacred.config import create_captured_function
 
@@ -74,6 +75,52 @@ def configuration():
     except:
         print("no GPU device available", file=sys.stderr)
 
+def num_connected_components_regression(alpha: float):
+    """
+    Metric/loss function capturing the separability of segmentation maps.
+
+    For both sides (true and predicted, resp.), computes
+    1. the argmax() of class-wise softmax input (i.e. the segmentation map)
+    2. the connected components (i.e. the instance label map)
+    3. the per-image number of components
+
+    Then calculates a regression formula between those two targets:
+    - overall mean squared (to incentivise exact fit)
+    - additive component (to incentivise more over fewer segments;
+      this prevents neighbours from spilling into each other;
+      oversegmentation is usually not as bad as undersegmentation)
+    """
+    def num_components(labels):
+        # [B, H, W] integer label map -> [B] number of connected components.
+        # NB: tfa's connected_components numbers components consecutively
+        # *across the whole batch* (zero entries stay 0), so a per-image
+        # reduce_max yields a cumulative count; the per-image count is the
+        # increment of the running maximum over the batch dimension.
+        components = connected_components(labels)
+        # [B, H, W]
+        n_batch = tf.shape(labels)[0]
+        flat = tf.reshape(components, (n_batch, -1))
+        # [B, H*W]
+        maxes = tf.math.reduce_max(flat, axis=1)
+        # [B]; running max also covers images without any component (max 0)
+        running = tf.scan(tf.maximum, maxes)
+        previous = tf.concat([tf.zeros([1], running.dtype), running[:-1]], axis=0)
+        return running - previous
+
+    def metric(y_true, y_pred):
+        # [B, H, W, C] class scores -> [B, H, W] segmentation maps
+        l_true = tf.math.argmax(y_true, axis=-1)
+        l_pred = tf.math.argmax(y_pred, axis=-1)
+        # [B] per-image component counts
+        n_true = num_components(l_true)
+        n_pred = num_components(l_pred)
+        diff = tf.cast(n_true - n_pred, tf.float32)
+        # clamp the radicand: for alpha >= 1 it can go negative
+        # (e.g. diff == -1 gives 1 - alpha), which would yield NaN
+        return tf.reduce_mean(tf.math.sqrt(tf.nn.relu(tf.math.square(diff) + alpha * diff)), axis=-1)
+
+    metric.__name__ = 'nCC'
+    return metric
+
 @tf.function
 def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
     """
@@ -502,11 +539,13 @@ def run(_config,
             loss = 'mean_squared_error'
         model.compile(loss=loss,
                       optimizer=Adam(learning_rate=learning_rate),
-                      metrics=['accuracy', MeanIoU(n_classes,
-                                                   name='iou',
-                                                   ignore_class=0,
-                                                   sparse_y_true=False,
-                                                   sparse_y_pred=False)])
+                      metrics=['accuracy',
+                               num_connected_components_regression(0.1),
+                               MeanIoU(n_classes,
+                                       name='iou',
+                                       ignore_class=0,
+                                       sparse_y_true=False,
+                                       sparse_y_pred=False)])
 
 def _to_cv2float(img):
     # rgb→bgr and uint8→float, as expected by Eynollah models
diff --git a/train/requirements.txt b/train/requirements.txt
index 8ad884d..6f23d76 100644
--- a/train/requirements.txt
+++ b/train/requirements.txt
@@ -4,3 +4,4 @@ numpy
 tqdm
 imutils
 scipy
+tensorflow-addons # for connected_components (NB: TFA is EOL, supports TF<=2.14)