mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: add metric ConfusionMatrix and plot it to TensorBoard
This commit is contained in:
parent
b6d2440ce1
commit
439ca350dd
2 changed files with 132 additions and 9 deletions
|
|
@ -1,5 +1,10 @@
|
||||||
from tensorflow.keras import backend as K
|
import os
|
||||||
|
|
||||||
|
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras import backend as K
|
||||||
|
from tensorflow.keras.metrics import Metric
|
||||||
|
from tensorflow.keras.initializers import Zeros
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -361,3 +366,47 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
||||||
sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
|
sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
|
||||||
jac = (intersection + smooth) / (sum_ - intersection + smooth)
|
jac = (intersection + smooth) / (sum_ - intersection + smooth)
|
||||||
return (1 - jac) * smooth
|
return (1 - jac) * smooth
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrix(Metric):
|
||||||
|
def __init__(self, nlabels=None, nrm="all", name="confusion_matrix", dtype=tf.float32):
|
||||||
|
super().__init__(name=name, dtype=dtype)
|
||||||
|
assert nlabels is not None
|
||||||
|
self._nlabels = nlabels
|
||||||
|
self._shape = (self._nlabels, self._nlabels)
|
||||||
|
self._matrix = self.add_weight(name, shape=self._shape,
|
||||||
|
initializer=Zeros)
|
||||||
|
assert nrm in ("all", "true", "pred", "none")
|
||||||
|
self._nrm = nrm
|
||||||
|
|
||||||
|
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||||
|
y_pred = tf.math.argmax(y_pred, axis=-1)
|
||||||
|
y_true = tf.math.argmax(y_true, axis=-1)
|
||||||
|
|
||||||
|
y_pred = tf.reshape(y_pred, shape=(-1,))
|
||||||
|
y_true = tf.reshape(y_true, shape=(-1,))
|
||||||
|
|
||||||
|
y_pred.shape.assert_is_compatible_with(y_true.shape)
|
||||||
|
confusion = tf.math.confusion_matrix(y_true, y_pred, num_classes=self._nlabels, dtype=self._dtype)
|
||||||
|
|
||||||
|
return self._matrix.assign_add(confusion)
|
||||||
|
|
||||||
|
def result(self):
|
||||||
|
"""normalize"""
|
||||||
|
if self._nrm == "all":
|
||||||
|
denom = tf.math.reduce_sum(self._matrix, axis=(0, 1))
|
||||||
|
elif self._nrm == "true":
|
||||||
|
denom = tf.math.reduce_sum(self._matrix, axis=1, keepdims=True)
|
||||||
|
elif self._nrm == "pred":
|
||||||
|
denom = tf.math.reduce_sum(self._matrix, axis=0, keepdims=True)
|
||||||
|
else:
|
||||||
|
denom = tf.constant(1.0)
|
||||||
|
return tf.math.divide_no_nan(self._matrix, denom)
|
||||||
|
|
||||||
|
def reset_state(self):
|
||||||
|
for v in self.variables:
|
||||||
|
v.assign(tf.zeros(shape=self._shape))
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return dict(nlabels=self._nlabels,
|
||||||
|
**super().get_config())
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
@ -21,10 +22,12 @@ from sacred.config import create_captured_function
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
from matplotlib import pyplot as plt # for plot_confusion_matrix
|
||||||
|
|
||||||
from .metrics import (
|
from .metrics import (
|
||||||
soft_dice_loss,
|
soft_dice_loss,
|
||||||
weighted_categorical_crossentropy
|
weighted_categorical_crossentropy,
|
||||||
|
ConfusionMatrix,
|
||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
PatchEncoder,
|
PatchEncoder,
|
||||||
|
|
@ -151,6 +154,45 @@ def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
||||||
weighted = image * 0.9 + layout * 0.1
|
weighted = image * 0.9 + layout * 0.1
|
||||||
return tf.cast(weighted, tf.uint8)
|
return tf.cast(weighted, tf.uint8)
|
||||||
|
|
||||||
|
def plot_confusion_matrix(cm, name="Confusion Matrix"):
|
||||||
|
"""
|
||||||
|
Plot the confusion matrix with matplotlib and tensorflow
|
||||||
|
"""
|
||||||
|
size = cm.shape[0]
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
||||||
|
ax.figure.colorbar(im, ax=ax)
|
||||||
|
ax.set(xticks=np.arange(cm.shape[1]),
|
||||||
|
yticks=np.arange(cm.shape[0]),
|
||||||
|
xlim=[-0.5, cm.shape[1] - 0.5],
|
||||||
|
ylim=[-0.5, cm.shape[0] - 0.5],
|
||||||
|
#xticklabels=labels,
|
||||||
|
#yticklabels=labels,
|
||||||
|
title=name,
|
||||||
|
ylabel='True class',
|
||||||
|
xlabel='Predicted class')
|
||||||
|
# Rotate the tick labels and set their alignment.
|
||||||
|
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
|
||||||
|
rotation_mode="anchor")
|
||||||
|
# Loop over data dimensions and create text annotations.
|
||||||
|
thresh = cm.max() / 2.
|
||||||
|
for i in range(cm.shape[0]):
|
||||||
|
for j in range(cm.shape[1]):
|
||||||
|
ax.text(j, i, format(cm[i, j], ".2f"),
|
||||||
|
ha="center", va="center",
|
||||||
|
color="white" if cm[i, j] > thresh else "black")
|
||||||
|
fig.tight_layout()
|
||||||
|
# convert to PNG
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format='png')
|
||||||
|
plt.close(fig)
|
||||||
|
buf.seek(0)
|
||||||
|
# Convert PNG buffer to TF image
|
||||||
|
image = tf.image.decode_png(buf.getvalue(), channels=4)
|
||||||
|
# Add the batch dimension
|
||||||
|
image = tf.expand_dims(image, 0)
|
||||||
|
return image
|
||||||
|
|
||||||
# plot predictions on train and test set during every epoch
|
# plot predictions on train and test set during every epoch
|
||||||
class TensorBoardPlotter(TensorBoard):
|
class TensorBoardPlotter(TensorBoard):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
@ -185,6 +227,36 @@ class TensorBoardPlotter(TensorBoard):
|
||||||
# used to be family kwarg for tf.summary.image name prefix
|
# used to be family kwarg for tf.summary.image name prefix
|
||||||
with tf.name_scope(family):
|
with tf.name_scope(family):
|
||||||
tf.summary.image(mode, images, step=step, max_outputs=len(images))
|
tf.summary.image(mode, images, step=step, max_outputs=len(images))
|
||||||
|
def on_train_batch_end(self, batch, logs=None):
|
||||||
|
if logs is not None:
|
||||||
|
logs = dict(logs)
|
||||||
|
# cannot be logged as scalar:
|
||||||
|
logs.pop('confusion_matrix', None)
|
||||||
|
super().on_train_batch_end(batch, logs)
|
||||||
|
def on_test_end(self, logs=None):
|
||||||
|
if logs is not None:
|
||||||
|
logs = dict(logs)
|
||||||
|
# cannot be logged as scalar:
|
||||||
|
logs.pop('confusion_matrix', None)
|
||||||
|
super().on_test_end(logs)
|
||||||
|
def _log_epoch_metrics(self, epoch, logs):
|
||||||
|
if not logs:
|
||||||
|
return
|
||||||
|
logs = dict(logs)
|
||||||
|
# cannot be logged as scalar:
|
||||||
|
train_matrix = logs.pop('confusion_matrix', None)
|
||||||
|
val_matrix = logs.pop('val_confusion_matrix', None)
|
||||||
|
super()._log_epoch_metrics(epoch, logs)
|
||||||
|
# now plot confusion_matrix
|
||||||
|
with tf.summary.record_if(True):
|
||||||
|
if train_matrix is not None:
|
||||||
|
train_image = plot_confusion_matrix(train_matrix)
|
||||||
|
with self._train_writer.as_default():
|
||||||
|
tf.summary.image("confusion_matrix", train_image, step=epoch)
|
||||||
|
if val_matrix is not None:
|
||||||
|
val_image = plot_confusion_matrix(val_matrix)
|
||||||
|
with self._val_writer.as_default():
|
||||||
|
tf.summary.image("confusion_matrix", val_image, step=epoch)
|
||||||
|
|
||||||
def get_dirs_or_files(input_data):
|
def get_dirs_or_files(input_data):
|
||||||
image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
|
image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
|
||||||
|
|
@ -523,6 +595,7 @@ def run(_config,
|
||||||
#if you want to see the model structure just uncomment model summary.
|
#if you want to see the model structure just uncomment model summary.
|
||||||
#model.summary()
|
#model.summary()
|
||||||
|
|
||||||
|
metrics = ['categorical_accuracy']
|
||||||
if task in ["segmentation", "binarization"]:
|
if task in ["segmentation", "binarization"]:
|
||||||
if is_loss_soft_dice:
|
if is_loss_soft_dice:
|
||||||
loss = soft_dice_loss
|
loss = soft_dice_loss
|
||||||
|
|
@ -530,17 +603,18 @@ def run(_config,
|
||||||
loss = weighted_categorical_crossentropy(weights)
|
loss = weighted_categorical_crossentropy(weights)
|
||||||
else:
|
else:
|
||||||
loss = 'categorical_crossentropy'
|
loss = 'categorical_crossentropy'
|
||||||
|
metrics.append(num_connected_components_regression(0.1))
|
||||||
|
metrics.append(MeanIoU(n_classes,
|
||||||
|
name='iou',
|
||||||
|
ignore_class=0,
|
||||||
|
sparse_y_true=False,
|
||||||
|
sparse_y_pred=False))
|
||||||
|
metrics.append(ConfusionMatrix(n_classes))
|
||||||
else: # task == "enhancement"
|
else: # task == "enhancement"
|
||||||
loss = 'mean_squared_error'
|
loss = 'mean_squared_error'
|
||||||
model.compile(loss=loss,
|
model.compile(loss=loss,
|
||||||
optimizer=Adam(learning_rate=learning_rate),
|
optimizer=Adam(learning_rate=learning_rate),
|
||||||
metrics=['accuracy',
|
metrics=metrics)
|
||||||
num_connected_components_regression(0.1),
|
|
||||||
MeanIoU(n_classes,
|
|
||||||
name='iou',
|
|
||||||
ignore_class=0,
|
|
||||||
sparse_y_true=False,
|
|
||||||
sparse_y_pred=False)])
|
|
||||||
|
|
||||||
def _to_cv2float(img):
|
def _to_cv2float(img):
|
||||||
# rgb→bgr and uint8→float, as expected by Eynollah models
|
# rgb→bgr and uint8→float, as expected by Eynollah models
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue