diff --git a/src/eynollah/training/metrics.py b/src/eynollah/training/metrics.py index a8f47d7..56dc732 100644 --- a/src/eynollah/training/metrics.py +++ b/src/eynollah/training/metrics.py @@ -1,5 +1,10 @@ -from tensorflow.keras import backend as K +import os + +os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 import tensorflow as tf +from tensorflow.keras import backend as K +from tensorflow.keras.metrics import Metric +from tensorflow.keras.initializers import Zeros import numpy as np @@ -361,3 +366,47 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100): sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) jac = (intersection + smooth) / (sum_ - intersection + smooth) return (1 - jac) * smooth + + +class ConfusionMatrix(Metric): + def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32): + super().__init__(name=name, dtype=dtype) + assert nlabels is not None + self._nlabels = nlabels + self._shape = (self._nlabels, self._nlabels) + self._matrix = self.add_weight(name, shape=self._shape, + initializer=Zeros) + assert nrm in ("all", "true", "pred", "none") + self._nrm = nrm + + def update_state(self, y_true, y_pred, sample_weight=None): + y_pred = tf.math.argmax(y_pred, axis=-1) + y_true = tf.math.argmax(y_true, axis=-1) + + y_pred = tf.reshape(y_pred, shape=(-1,)) + y_true = tf.reshape(y_true, shape=(-1,)) + + y_pred.shape.assert_is_compatible_with(y_true.shape) + confusion = tf.math.confusion_matrix(y_true, y_pred, num_classes=self._nlabels, dtype=self._dtype) + + return self._matrix.assign_add(confusion) + + def result(self): + """normalize""" + if self._nrm == "all": + denom = tf.math.reduce_sum(self._matrix, axis=(0, 1)) + elif self._nrm == "true": + denom = tf.math.reduce_sum(self._matrix, axis=1, keepdims=True) + elif self._nrm == "pred": + denom = tf.math.reduce_sum(self._matrix, axis=0, keepdims=True) + else: + denom = tf.constant(1.0) + return tf.math.divide_no_nan(self._matrix, denom) + + def 
reset_state(self): + for v in self.variables: + v.assign(tf.zeros(shape=self._shape)) + + def get_config(self): + return dict(nlabels=self._nlabels, nrm=self._nrm, + **super().get_config()) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 74a7a90..0c624c3 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -1,5 +1,6 @@ import os import sys +import io import json from tqdm import tqdm @@ -21,10 +22,12 @@ from sacred.config import create_captured_function import numpy as np import cv2 +from matplotlib import pyplot as plt # for plot_confusion_matrix from .metrics import ( soft_dice_loss, - weighted_categorical_crossentropy + weighted_categorical_crossentropy, + ConfusionMatrix, ) from .models import ( PatchEncoder, @@ -151,6 +154,45 @@ def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor: weighted = image * 0.9 + layout * 0.1 return tf.cast(weighted, tf.uint8) +def plot_confusion_matrix(cm, name="Confusion Matrix"): + """ + Plot the confusion matrix with matplotlib and tensorflow + """ + size = cm.shape[0] + fig, ax = plt.subplots() + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + ax.set(xticks=np.arange(cm.shape[1]), + yticks=np.arange(cm.shape[0]), + xlim=[-0.5, cm.shape[1] - 0.5], + ylim=[-0.5, cm.shape[0] - 0.5], + #xticklabels=labels, + #yticklabels=labels, + title=name, + ylabel='True class', + xlabel='Predicted class') + # Rotate the tick labels and set their alignment. + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", + rotation_mode="anchor") + # Loop over data dimensions and create text annotations. + thresh = cm.max() / 2. 
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], ".2f"), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + fig.tight_layout() + # convert to PNG + buf = io.BytesIO() + fig.savefig(buf, format='png') + plt.close(fig) + buf.seek(0) + # Convert PNG buffer to TF image + image = tf.image.decode_png(buf.getvalue(), channels=4) + # Add the batch dimension + image = tf.expand_dims(image, 0) + return image + # plot predictions on train and test set during every epoch class TensorBoardPlotter(TensorBoard): def __init__(self, *args, **kwargs): @@ -185,6 +227,36 @@ class TensorBoardPlotter(TensorBoard): # used to be family kwarg for tf.summary.image name prefix with tf.name_scope(family): tf.summary.image(mode, images, step=step, max_outputs=len(images)) + def on_train_batch_end(self, batch, logs=None): + if logs is not None: + logs = dict(logs) + # cannot be logged as scalar: + logs.pop('confusion_matrix', None) + super().on_train_batch_end(batch, logs) + def on_test_end(self, logs=None): + if logs is not None: + logs = dict(logs) + # cannot be logged as scalar: + logs.pop('confusion_matrix', None) + super().on_test_end(logs) + def _log_epoch_metrics(self, epoch, logs): + if not logs: + return + logs = dict(logs) + # cannot be logged as scalar: + train_matrix = logs.pop('confusion_matrix', None) + val_matrix = logs.pop('val_confusion_matrix', None) + super()._log_epoch_metrics(epoch, logs) + # now plot confusion_matrix + with tf.summary.record_if(True): + if train_matrix is not None: + train_image = plot_confusion_matrix(train_matrix) + with self._train_writer.as_default(): + tf.summary.image("confusion_matrix", train_image, step=epoch) + if val_matrix is not None: + val_image = plot_confusion_matrix(val_matrix) + with self._val_writer.as_default(): + tf.summary.image("confusion_matrix", val_image, step=epoch) def get_dirs_or_files(input_data): image_input, labels_input = os.path.join(input_data, 
'images/'), os.path.join(input_data, 'labels/') @@ -523,6 +595,7 @@ def run(_config, #if you want to see the model structure just uncomment model summary. #model.summary() + metrics = ['categorical_accuracy'] if task in ["segmentation", "binarization"]: if is_loss_soft_dice: loss = soft_dice_loss @@ -530,17 +603,18 @@ def run(_config, loss = weighted_categorical_crossentropy(weights) else: loss = 'categorical_crossentropy' + metrics.append(num_connected_components_regression(0.1)) + metrics.append(MeanIoU(n_classes, + name='iou', + ignore_class=0, + sparse_y_true=False, + sparse_y_pred=False)) + metrics.append(ConfusionMatrix(n_classes)) else: # task == "enhancement" loss = 'mean_squared_error' model.compile(loss=loss, optimizer=Adam(learning_rate=learning_rate), - metrics=['accuracy', - num_connected_components_regression(0.1), - MeanIoU(n_classes, - name='iou', - ignore_class=0, - sparse_y_true=False, - sparse_y_pred=False)]) + metrics=metrics) def _to_cv2float(img): # rgb→bgr and uint8→float, as expected by Eynollah models