training: make plotting more efficient (18607e0f)…

- avoid control dependencies in model path
- store only every 3rd sample
This commit is contained in:
Robert Sachunsky 2026-02-27 12:50:37 +01:00
parent 2d5de8e595
commit f8dd5a328c

View file

@@ -115,7 +115,6 @@ def num_connected_components_regression(alpha: float):
metric.__name__ = 'nCC'
return metric
@tf.function
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
"""
Implements training.inference.SBBPredict.visualize_model_output for TF
@@ -158,9 +157,8 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"):
"""
Plot the confusion matrix with matplotlib and tensorflow
"""
size = cm.shape[0]
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
fig, ax = plt.subplots(figsize=(10, 8), dpi=300)
im = ax.imshow(cm, vmin=0.0, vmax=1.0, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
@@ -171,9 +169,6 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"):
title=name,
ylabel='True class',
xlabel='Predicted class')
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
@@ -200,33 +195,39 @@ class TensorBoardPlotter(TensorBoard):
self.model_call = None
def on_epoch_begin(self, epoch, logs=None):
super().on_epoch_begin(epoch, logs=logs)
# override the model's call(), so we don't have to invest extra cycles
# to predict our samples (plotting itself can be neglected)
self.model_call = self.model.call
@tf.function
def new_call(inputs, **kwargs):
outputs = self.model_call(inputs, **kwargs)
images = plot_layout_tf(inputs, outputs)
self.plot(images, training=kwargs.get('training', None), epoch=epoch)
return outputs
with tf.control_dependencies(None):
return outputs
self.model.call = new_call
def on_epoch_end(self, epoch, logs=None):
# re-instate (so ModelCheckpoint does not see our override call)
self.model.call = self.model_call
# force rebuild of tf.function (so Python binding for epoch gets re-evaluated)
self.model.train_function = self.model.make_train_function(True)
self.model.test_function = self.model.make_test_function(True)
def on_epoch_end(self, epoch, logs=None):
    """
    Undo the per-epoch `call` override installed in on_epoch_begin,
    then delegate to the base TensorBoard callback.

    :param epoch: index of the epoch that just finished
    :param logs: metric dict forwarded unchanged to the superclass
    """
    # re-instate (so ModelCheckpoint does not see our override call)
    self.model.call = self.model_call
    super().on_epoch_end(epoch, logs=logs)
def plot(self, images, training=None, epoch=0):
if training:
writer = self._train_writer
mode, step = "train", self._train_step.read_value()
mode, step = "train", self._train_step.value()
else:
writer = self._val_writer
mode, step = "test", self._val_step.read_value()
family = "epoch_%03d" % (1 + epoch)
with writer.as_default():
# used to be family kwarg for tf.summary.image name prefix
with tf.name_scope(family):
tf.summary.image(mode, images, step=step, max_outputs=len(images))
mode, step = "test", self._val_step.value()
# skip most samples, because TF's EncodePNG is so costly,
# and now ends up in the middle of our pipeline, thus causing stalls
# (cannot use max_outputs, as batch size may be too small)
if not tf.cast(step % 3, tf.bool):
with writer.as_default():
# used to be family kwarg for tf.summary.image name prefix
family = "epoch_%03d/" % (1 + epoch)
name = family + mode
tf.summary.image(name, images, step=step, max_outputs=len(images))
def on_train_batch_end(self, batch, logs=None):
if logs is not None:
logs = dict(logs)