diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 4542297..39dac1d 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -158,9 +158,10 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"): # plot predictions on train and test set during every epoch class TensorBoardPlotter(TensorBoard): - def __init__(self, *args, **kwargs): + def __init__(self, plot_freqs, *args, **kwargs): super().__init__(*args, **kwargs) self.model_call = None + self.plot_frequency_train, self.plot_frequency_val = plot_freqs def on_epoch_begin(self, epoch, logs=None): super().on_epoch_begin(epoch, logs=logs) # override the model's call(), so we don't have to invest extra cycles @@ -183,14 +184,16 @@ class TensorBoardPlotter(TensorBoard): def plot(self, images, training=None, epoch=0): if training: writer = self._train_writer + freq = self.plot_frequency_train mode, step = "train", self._train_step.value() else: writer = self._val_writer + freq = self.plot_frequency_val mode, step = "test", self._val_step.value() # skip most samples, because TF's EncodePNG is so costly, # and now ends up in the middle of our pipeline, thus causing stalls # (cannot use max_outputs, as batch size may be too small) - if not tf.cast(step % 3, tf.bool): + if not tf.cast(step % freq, tf.bool): with writer.as_default(): # used to be family kwarg for tf.summary.image name prefix family = "epoch_%03d/" % (1 + epoch) @@ -634,7 +637,10 @@ def run(_config, _log.info("training on %d batches in %d epochs", train_steps, n_epochs) _log.info("validating on %d batches", valdn_steps) - callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), profile_batch=(10, 20)), + callbacks = [TensorBoardPlotter((max(1, train_steps * n_batch // 1000), + max(1, valdn_steps * n_batch // 100)), + os.path.join(dir_output, 'logs'), + profile_batch=(10, 20)), SaveWeightsAfterSteps(0, dir_output, _config), ] if save_interval: