training: add metric ConfusionMatrix and plot it to TensorBoard

This commit is contained in:
Robert Sachunsky 2026-02-26 13:55:37 +01:00
parent b6d2440ce1
commit 439ca350dd
2 changed files with 132 additions and 9 deletions

View file

@@ -1,5 +1,10 @@
from tensorflow.keras import backend as K
import os
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.metrics import Metric
from tensorflow.keras.initializers import Zeros
import numpy as np
@@ -361,3 +366,47 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100):
sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
jac = (intersection + smooth) / (sum_ - intersection + smooth)
return (1 - jac) * smooth
class ConfusionMatrix(Metric):
    """Streaming confusion matrix over dense (one-hot) labels and predictions.

    Accumulates an (nlabels, nlabels) count matrix across batches; rows are
    true classes and columns are predicted classes (the
    ``tf.math.confusion_matrix`` convention). ``result()`` returns the matrix
    normalized according to ``nrm``.
    """
    def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32):
        """
        Args:
            nlabels: number of classes (required).
            nrm: normalization mode applied in ``result()``:
                 ``"all"``  — divide by the grand total,
                 ``"true"`` — divide each row by its sum (recall-like),
                 ``"pred"`` — divide each column by its sum (precision-like),
                 ``"none"`` — raw counts.
            name: metric name (also used as the weight name).
            dtype: dtype of the accumulator variable.

        Raises:
            ValueError: if ``nlabels`` is missing or ``nrm`` is unknown.
        """
        super().__init__(name=name, dtype=dtype)
        # raise instead of assert: asserts are stripped under `python -O`
        if nlabels is None:
            raise ValueError("nlabels must be given")
        if nrm not in ("all", "true", "pred", "none"):
            raise ValueError("nrm must be one of 'all', 'true', 'pred', 'none'")
        self._nlabels = nlabels
        self._nrm = nrm
        self._shape = (self._nlabels, self._nlabels)
        self._matrix = self.add_weight(name, shape=self._shape,
                                       initializer=Zeros)
    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add this batch's confusion counts to the running matrix.

        Expects dense (one-hot / probability) tensors with the class axis
        last; both are argmax-ed over the last axis and flattened before
        counting. ``sample_weight`` is accepted for API compatibility but
        ignored.
        """
        y_pred = tf.math.argmax(y_pred, axis=-1)
        y_true = tf.math.argmax(y_true, axis=-1)
        y_pred = tf.reshape(y_pred, shape=(-1,))
        y_true = tf.reshape(y_true, shape=(-1,))
        y_pred.shape.assert_is_compatible_with(y_true.shape)
        confusion = tf.math.confusion_matrix(y_true, y_pred,
                                             num_classes=self._nlabels,
                                             dtype=self._dtype)
        return self._matrix.assign_add(confusion)
    def result(self):
        """Return the accumulated matrix, normalized according to ``nrm``."""
        if self._nrm == "all":
            denom = tf.math.reduce_sum(self._matrix, axis=(0, 1))
        elif self._nrm == "true":
            denom = tf.math.reduce_sum(self._matrix, axis=1, keepdims=True)
        elif self._nrm == "pred":
            denom = tf.math.reduce_sum(self._matrix, axis=0, keepdims=True)
        else:  # "none": raw counts
            denom = tf.constant(1.0)
        # divide_no_nan keeps empty rows/columns at 0 instead of NaN
        return tf.math.divide_no_nan(self._matrix, denom)
    def reset_state(self):
        """Zero the accumulator (called between epochs/evaluations)."""
        for v in self.variables:
            v.assign(tf.zeros(shape=self._shape))
    def get_config(self):
        """Serialization config for saving/reloading the metric."""
        # BUG FIX: `nrm` was previously dropped from the config, so a
        # deserialized metric silently reverted to the default ("all").
        return dict(nlabels=self._nlabels,
                    nrm=self._nrm,
                    **super().get_config())

View file

@@ -1,5 +1,6 @@
import os
import sys
import io
import json
from tqdm import tqdm
@@ -21,10 +22,12 @@ from sacred.config import create_captured_function
import numpy as np
import cv2
from matplotlib import pyplot as plt # for plot_confusion_matrix
from .metrics import (
soft_dice_loss,
weighted_categorical_crossentropy
weighted_categorical_crossentropy,
ConfusionMatrix,
)
from .models import (
PatchEncoder,
@@ -151,6 +154,45 @@ def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
weighted = image * 0.9 + layout * 0.1
return tf.cast(weighted, tf.uint8)
def plot_confusion_matrix(cm, name="Confusion Matrix"):
    """Render a confusion matrix as a TensorBoard-loggable image tensor.

    Args:
        cm: 2-d array of (normalized) confusion values; rows = true class,
            columns = predicted class.
        name: plot title.

    Returns:
        uint8 tensor of shape (1, height, width, 4): the PNG-rendered RGBA
        figure with a leading batch dimension, ready for ``tf.summary.image``.
    """
    # (removed unused local `size = cm.shape[0]`)
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xlim=[-0.5, cm.shape[1] - 0.5],
           ylim=[-0.5, cm.shape[0] - 0.5],
           #xticklabels=labels,
           #yticklabels=labels,
           title=name,
           ylabel='True class',
           xlabel='Predicted class')
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    # Annotate each cell; switch to white text on dark cells for contrast.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], ".2f"),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    # Serialize the figure to PNG entirely in memory (no tempfile).
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # release the figure; this runs once per epoch
    buf.seek(0)
    # Convert PNG buffer to TF image
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension expected by tf.summary.image
    image = tf.expand_dims(image, 0)
    return image
# plot predictions on train and test set during every epoch
class TensorBoardPlotter(TensorBoard):
def __init__(self, *args, **kwargs):
@@ -185,6 +227,36 @@ class TensorBoardPlotter(TensorBoard):
# used to be family kwarg for tf.summary.image name prefix
with tf.name_scope(family):
tf.summary.image(mode, images, step=step, max_outputs=len(images))
def on_train_batch_end(self, batch, logs=None):
    """Forward per-batch train logs to TensorBoard, omitting the
    matrix-valued 'confusion_matrix' entry (cannot be logged as scalar)."""
    if logs is None:
        super().on_train_batch_end(batch, logs)
        return
    filtered = {key: val for key, val in logs.items()
                if key != 'confusion_matrix'}
    super().on_train_batch_end(batch, filtered)
def on_test_end(self, logs=None):
    """Forward evaluation logs to TensorBoard, omitting the matrix-valued
    'confusion_matrix' entry (cannot be logged as scalar)."""
    if logs is None:
        super().on_test_end(logs)
        return
    filtered = {key: val for key, val in logs.items()
                if key != 'confusion_matrix'}
    super().on_test_end(filtered)
def _log_epoch_metrics(self, epoch, logs):
    """Log scalar epoch metrics via the base class, then render any
    train/validation confusion matrices as TensorBoard images."""
    if not logs:
        return
    remaining = dict(logs)
    # matrix-valued entries cannot be logged as scalars — pull them out
    pairs = (('_train_writer', remaining.pop('confusion_matrix', None)),
             ('_val_writer', remaining.pop('val_confusion_matrix', None)))
    super()._log_epoch_metrics(epoch, remaining)
    # now plot each confusion matrix into its respective summary writer
    with tf.summary.record_if(True):
        for writer_attr, matrix in pairs:
            if matrix is None:
                continue
            image = plot_confusion_matrix(matrix)
            # NOTE: look up the writer only when needed, since the base
            # class may create these writers lazily on attribute access
            with getattr(self, writer_attr).as_default():
                tf.summary.image("confusion_matrix", image, step=epoch)
def get_dirs_or_files(input_data):
image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
@@ -523,6 +595,7 @@ def run(_config,
#if you want to see the model structure just uncomment model summary.
#model.summary()
metrics = ['categorical_accuracy']
if task in ["segmentation", "binarization"]:
if is_loss_soft_dice:
loss = soft_dice_loss
@@ -530,17 +603,18 @@
loss = weighted_categorical_crossentropy(weights)
else:
loss = 'categorical_crossentropy'
metrics.append(num_connected_components_regression(0.1))
metrics.append(MeanIoU(n_classes,
name='iou',
ignore_class=0,
sparse_y_true=False,
sparse_y_pred=False))
metrics.append(ConfusionMatrix(n_classes))
else: # task == "enhancement"
loss = 'mean_squared_error'
model.compile(loss=loss,
optimizer=Adam(learning_rate=learning_rate),
metrics=['accuracy',
num_connected_components_regression(0.1),
MeanIoU(n_classes,
name='iou',
ignore_class=0,
sparse_y_true=False,
sparse_y_pred=False)])
metrics=metrics)
def _to_cv2float(img):
# rgb→bgr and uint8→float, as expected by Eynollah models