mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-04-14 19:31:57 +02:00
training: plot only ~ 1000 training and ~ 100 validation images
This commit is contained in:
parent
a8556f5210
commit
04da66ed73
1 changed file with 9 additions and 3 deletions
|
|
@ -158,9 +158,10 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"):
|
|||
|
||||
# plot predictions on train and test set during every epoch
|
||||
class TensorBoardPlotter(TensorBoard):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, plot_freqs, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_call = None
|
||||
self.plot_frequency_train, self.plot_frequency_val = plot_freqs
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
super().on_epoch_begin(epoch, logs=logs)
|
||||
# override the model's call(), so we don't have to invest extra cycles
|
||||
|
|
@ -183,14 +184,16 @@ class TensorBoardPlotter(TensorBoard):
|
|||
def plot(self, images, training=None, epoch=0):
|
||||
if training:
|
||||
writer = self._train_writer
|
||||
freq = self.plot_frequency_train
|
||||
mode, step = "train", self._train_step.value()
|
||||
else:
|
||||
writer = self._val_writer
|
||||
freq = self.plot_frequency_val
|
||||
mode, step = "test", self._val_step.value()
|
||||
# skip most samples, because TF's EncodePNG is so costly,
|
||||
# and now ends up in the middle of our pipeline, thus causing stalls
|
||||
# (cannot use max_outputs, as batch size may be too small)
|
||||
if not tf.cast(step % 3, tf.bool):
|
||||
if not tf.cast(step % freq, tf.bool):
|
||||
with writer.as_default():
|
||||
# used to be family kwarg for tf.summary.image name prefix
|
||||
family = "epoch_%03d/" % (1 + epoch)
|
||||
|
|
@ -634,7 +637,10 @@ def run(_config,
|
|||
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
|
||||
_log.info("validating on %d batches", valdn_steps)
|
||||
|
||||
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), profile_batch=(10, 20)),
|
||||
callbacks = [TensorBoardPlotter((max(1, train_steps * n_batch // 1000),
|
||||
max(1, valdn_steps * n_batch // 100)),
|
||||
os.path.join(dir_output, 'logs'),
|
||||
profile_batch=(10, 20)),
|
||||
SaveWeightsAfterSteps(0, dir_output, _config),
|
||||
]
|
||||
if save_interval:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue