training: plot only ~1000 training and ~100 validation images

This commit is contained in:
Robert Sachunsky 2026-03-30 13:34:05 +02:00
parent a8556f5210
commit 04da66ed73

View file

@ -158,9 +158,10 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"):
# plot predictions on train and test set during every epoch
class TensorBoardPlotter(TensorBoard):
def __init__(self, *args, **kwargs):
def __init__(self, plot_freqs, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_call = None
self.plot_frequency_train, self.plot_frequency_val = plot_freqs
def on_epoch_begin(self, epoch, logs=None):
super().on_epoch_begin(epoch, logs=logs)
# override the model's call(), so we don't have to invest extra cycles
@ -183,14 +184,16 @@ class TensorBoardPlotter(TensorBoard):
def plot(self, images, training=None, epoch=0):
if training:
writer = self._train_writer
freq = self.plot_frequency_train
mode, step = "train", self._train_step.value()
else:
writer = self._val_writer
freq = self.plot_frequency_val
mode, step = "test", self._val_step.value()
# skip most samples, because TF's EncodePNG is so costly,
# and now ends up in the middle of our pipeline, thus causing stalls
# (cannot use max_outputs, as batch size may be too small)
if not tf.cast(step % 3, tf.bool):
if not tf.cast(step % freq, tf.bool):
with writer.as_default():
# used to be family kwarg for tf.summary.image name prefix
family = "epoch_%03d/" % (1 + epoch)
@ -634,7 +637,10 @@ def run(_config,
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
_log.info("validating on %d batches", valdn_steps)
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), profile_batch=(10, 20)),
callbacks = [TensorBoardPlotter((max(1, train_steps * n_batch // 1000),
max(1, valdn_steps * n_batch // 100)),
os.path.join(dir_output, 'logs'),
profile_batch=(10, 20)),
SaveWeightsAfterSteps(0, dir_output, _config),
]
if save_interval: