From f8dd5a328c130a82d6d06dec70a162937e78a729 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Fri, 27 Feb 2026 12:50:37 +0100 Subject: [PATCH] =?UTF-8?q?training:=20make=20plotting=2018607e0f=20more?= =?UTF-8?q?=20efficient=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - avoid control dependencies in model path - store only every 3rd sample --- src/eynollah/training/train.py | 39 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 0c624c3..30e30cb 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -115,7 +115,6 @@ def num_connected_components_regression(alpha: float): metric.__name__ = 'nCC' return metric -@tf.function def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor: """ Implements training.inference.SBBPredict.visualize_model_output for TF @@ -158,9 +157,8 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"): """ Plot the confusion matrix with matplotlib and tensorflow """ - size = cm.shape[0] - fig, ax = plt.subplots() - im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + fig, ax = plt.subplots(figsize=(10, 8), dpi=300) + im = ax.imshow(cm, vmin=0.0, vmax=1.0, interpolation='nearest', cmap=plt.cm.Blues) ax.figure.colorbar(im, ax=ax) ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]), @@ -171,9 +169,6 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"): title=name, ylabel='True class', xlabel='Predicted class') - # Rotate the tick labels and set their alignment. - plt.setp(ax.get_xticklabels(), rotation=45, ha="right", - rotation_mode="anchor") # Loop over data dimensions and create text annotations. thresh = cm.max() / 2. for i in range(cm.shape[0]): @@ -200,33 +195,39 @@ class TensorBoardPlotter(TensorBoard): self.model_call = None def on_epoch_begin(self, epoch, logs=None): super().on_epoch_begin(epoch, logs=logs) + # override the model's call(), so we don't have to invest extra cycles + # to predict our samples (plotting itself can be neglected) self.model_call = self.model.call - @tf.function def new_call(inputs, **kwargs): outputs = self.model_call(inputs, **kwargs) images = plot_layout_tf(inputs, outputs) self.plot(images, training=kwargs.get('training', None), epoch=epoch) - return outputs + with tf.control_dependencies(None): + return outputs self.model.call = new_call - def on_epoch_end(self, epoch, logs=None): - # re-instate (so ModelCheckpoint does not see our override call) - self.model.call = self.model_call # force rebuild of tf.function (so Python binding for epoch gets re-evaluated) self.model.train_function = self.model.make_train_function(True) self.model.test_function = self.model.make_test_function(True) + def on_epoch_end(self, epoch, logs=None): + # re-instate (so ModelCheckpoint does not see our override call) + self.model.call = self.model_call super().on_epoch_end(epoch, logs=logs) def plot(self, images, training=None, epoch=0): if training: writer = self._train_writer - mode, step = "train", self._train_step.read_value() + mode, step = "train", self._train_step.value() else: writer = self._val_writer - mode, step = "test", self._val_step.read_value() - family = "epoch_%03d" % (1 + epoch) - with writer.as_default(): - # used to be family kwarg for tf.summary.image name prefix - with tf.name_scope(family): - tf.summary.image(mode, images, step=step, max_outputs=len(images)) + mode, step = "test", self._val_step.value() + # skip most samples, because TF's EncodePNG is so costly, + # and now ends up in the middle of our pipeline, thus causing stalls + # (cannot use max_outputs, as batch size may be too small) + if not tf.cast(step % 3, tf.bool): + with writer.as_default(): + # used to be family kwarg for tf.summary.image name prefix + family = "epoch_%03d/" % (1 + epoch) + name = family + mode + tf.summary.image(name, images, step=step, max_outputs=len(images)) def on_train_batch_end(self, batch, logs=None): if logs is not None: logs = dict(logs)