mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
training: add metric ConfusionMatrix and plot it to TensorBoard
This commit is contained in:
parent
b6d2440ce1
commit
439ca350dd
2 changed files with 132 additions and 9 deletions
|
|
@ -1,5 +1,10 @@
|
|||
from tensorflow.keras import backend as K
|
||||
import os
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import backend as K
|
||||
from tensorflow.keras.metrics import Metric
|
||||
from tensorflow.keras.initializers import Zeros
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
|
@ -361,3 +366,47 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
|||
sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
|
||||
jac = (intersection + smooth) / (sum_ - intersection + smooth)
|
||||
return (1 - jac) * smooth
|
||||
|
||||
|
||||
class ConfusionMatrix(Metric):
|
||||
def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
assert nlabels is not None
|
||||
self._nlabels = nlabels
|
||||
self._shape = (self._nlabels, self._nlabels)
|
||||
self._matrix = self.add_weight(name, shape=self._shape,
|
||||
initializer=Zeros)
|
||||
assert nrm in ("all", "true", "pred", "none")
|
||||
self._nrm = nrm
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
y_pred = tf.math.argmax(y_pred, axis=-1)
|
||||
y_true = tf.math.argmax(y_true, axis=-1)
|
||||
|
||||
y_pred = tf.reshape(y_pred, shape=(-1,))
|
||||
y_true = tf.reshape(y_true, shape=(-1,))
|
||||
|
||||
y_pred.shape.assert_is_compatible_with(y_true.shape)
|
||||
confusion = tf.math.confusion_matrix(y_true, y_pred, num_classes=self._nlabels, dtype=self._dtype)
|
||||
|
||||
return self._matrix.assign_add(confusion)
|
||||
|
||||
def result(self):
|
||||
"""normalize"""
|
||||
if self._nrm == "all":
|
||||
denom = tf.math.reduce_sum(self._matrix, axis=(0, 1))
|
||||
elif self._nrm == "true":
|
||||
denom = tf.math.reduce_sum(self._matrix, axis=1, keepdims=True)
|
||||
elif self._nrm == "pred":
|
||||
denom = tf.math.reduce_sum(self._matrix, axis=0, keepdims=True)
|
||||
else:
|
||||
denom = tf.constant(1.0)
|
||||
return tf.math.divide_no_nan(self._matrix, denom)
|
||||
|
||||
def reset_state(self):
|
||||
for v in self.variables:
|
||||
v.assign(tf.zeros(shape=self._shape))
|
||||
|
||||
def get_config(self):
|
||||
return dict(nlabels=self._nlabels,
|
||||
**super().get_config())
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import sys
|
||||
import io
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
|
|
@ -21,10 +22,12 @@ from sacred.config import create_captured_function
|
|||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from matplotlib import pyplot as plt # for plot_confusion_matrix
|
||||
|
||||
from .metrics import (
|
||||
soft_dice_loss,
|
||||
weighted_categorical_crossentropy
|
||||
weighted_categorical_crossentropy,
|
||||
ConfusionMatrix,
|
||||
)
|
||||
from .models import (
|
||||
PatchEncoder,
|
||||
|
|
@ -151,6 +154,45 @@ def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
|||
weighted = image * 0.9 + layout * 0.1
|
||||
return tf.cast(weighted, tf.uint8)
|
||||
|
||||
def plot_confusion_matrix(cm, name="Confusion Matrix"):
|
||||
"""
|
||||
Plot the confusion matrix with matplotlib and tensorflow
|
||||
"""
|
||||
size = cm.shape[0]
|
||||
fig, ax = plt.subplots()
|
||||
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
||||
ax.figure.colorbar(im, ax=ax)
|
||||
ax.set(xticks=np.arange(cm.shape[1]),
|
||||
yticks=np.arange(cm.shape[0]),
|
||||
xlim=[-0.5, cm.shape[1] - 0.5],
|
||||
ylim=[-0.5, cm.shape[0] - 0.5],
|
||||
#xticklabels=labels,
|
||||
#yticklabels=labels,
|
||||
title=name,
|
||||
ylabel='True class',
|
||||
xlabel='Predicted class')
|
||||
# Rotate the tick labels and set their alignment.
|
||||
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
|
||||
rotation_mode="anchor")
|
||||
# Loop over data dimensions and create text annotations.
|
||||
thresh = cm.max() / 2.
|
||||
for i in range(cm.shape[0]):
|
||||
for j in range(cm.shape[1]):
|
||||
ax.text(j, i, format(cm[i, j], ".2f"),
|
||||
ha="center", va="center",
|
||||
color="white" if cm[i, j] > thresh else "black")
|
||||
fig.tight_layout()
|
||||
# convert to PNG
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format='png')
|
||||
plt.close(fig)
|
||||
buf.seek(0)
|
||||
# Convert PNG buffer to TF image
|
||||
image = tf.image.decode_png(buf.getvalue(), channels=4)
|
||||
# Add the batch dimension
|
||||
image = tf.expand_dims(image, 0)
|
||||
return image
|
||||
|
||||
# plot predictions on train and test set during every epoch
|
||||
class TensorBoardPlotter(TensorBoard):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -185,6 +227,36 @@ class TensorBoardPlotter(TensorBoard):
|
|||
# used to be family kwarg for tf.summary.image name prefix
|
||||
with tf.name_scope(family):
|
||||
tf.summary.image(mode, images, step=step, max_outputs=len(images))
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
if logs is not None:
|
||||
logs = dict(logs)
|
||||
# cannot be logged as scalar:
|
||||
logs.pop('confusion_matrix', None)
|
||||
super().on_train_batch_end(batch, logs)
|
||||
def on_test_end(self, logs=None):
|
||||
if logs is not None:
|
||||
logs = dict(logs)
|
||||
# cannot be logged as scalar:
|
||||
logs.pop('confusion_matrix', None)
|
||||
super().on_test_end(logs)
|
||||
def _log_epoch_metrics(self, epoch, logs):
|
||||
if not logs:
|
||||
return
|
||||
logs = dict(logs)
|
||||
# cannot be logged as scalar:
|
||||
train_matrix = logs.pop('confusion_matrix', None)
|
||||
val_matrix = logs.pop('val_confusion_matrix', None)
|
||||
super()._log_epoch_metrics(epoch, logs)
|
||||
# now plot confusion_matrix
|
||||
with tf.summary.record_if(True):
|
||||
if train_matrix is not None:
|
||||
train_image = plot_confusion_matrix(train_matrix)
|
||||
with self._train_writer.as_default():
|
||||
tf.summary.image("confusion_matrix", train_image, step=epoch)
|
||||
if val_matrix is not None:
|
||||
val_image = plot_confusion_matrix(val_matrix)
|
||||
with self._val_writer.as_default():
|
||||
tf.summary.image("confusion_matrix", val_image, step=epoch)
|
||||
|
||||
def get_dirs_or_files(input_data):
|
||||
image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
|
||||
|
|
@ -523,6 +595,7 @@ def run(_config,
|
|||
#if you want to see the model structure just uncomment model summary.
|
||||
#model.summary()
|
||||
|
||||
metrics = ['categorical_accuracy']
|
||||
if task in ["segmentation", "binarization"]:
|
||||
if is_loss_soft_dice:
|
||||
loss = soft_dice_loss
|
||||
|
|
@ -530,17 +603,18 @@ def run(_config,
|
|||
loss = weighted_categorical_crossentropy(weights)
|
||||
else:
|
||||
loss = 'categorical_crossentropy'
|
||||
metrics.append(num_connected_components_regression(0.1))
|
||||
metrics.append(MeanIoU(n_classes,
|
||||
name='iou',
|
||||
ignore_class=0,
|
||||
sparse_y_true=False,
|
||||
sparse_y_pred=False))
|
||||
metrics.append(ConfusionMatrix(n_classes))
|
||||
else: # task == "enhancement"
|
||||
loss = 'mean_squared_error'
|
||||
model.compile(loss=loss,
|
||||
optimizer=Adam(learning_rate=learning_rate),
|
||||
metrics=['accuracy',
|
||||
num_connected_components_regression(0.1),
|
||||
MeanIoU(n_classes,
|
||||
name='iou',
|
||||
ignore_class=0,
|
||||
sparse_y_true=False,
|
||||
sparse_y_pred=False)])
|
||||
metrics=metrics)
|
||||
|
||||
def _to_cv2float(img):
|
||||
# rgb→bgr and uint8→float, as expected by Eynollah models
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue