mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: make plotting 18607e0f more efficient…
- avoid control dependencies in model path
- store only every 3rd sample
This commit is contained in:
parent
2d5de8e595
commit
f8dd5a328c
1 changed file with 20 additions and 19 deletions
|
|
@@ -115,7 +115,6 @@ def num_connected_components_regression(alpha: float):
|
|||
metric.__name__ = 'nCC'
|
||||
return metric
|
||||
|
||||
@tf.function
|
||||
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Implements training.inference.SBBPredict.visualize_model_output for TF
|
||||
|
|
@@ -158,9 +157,8 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"):
|
|||
"""
|
||||
Plot the confusion matrix with matplotlib and tensorflow
|
||||
"""
|
||||
size = cm.shape[0]
|
||||
fig, ax = plt.subplots()
|
||||
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
||||
fig, ax = plt.subplots(figsize=(10, 8), dpi=300)
|
||||
im = ax.imshow(cm, vmin=0.0, vmax=1.0, interpolation='nearest', cmap=plt.cm.Blues)
|
||||
ax.figure.colorbar(im, ax=ax)
|
||||
ax.set(xticks=np.arange(cm.shape[1]),
|
||||
yticks=np.arange(cm.shape[0]),
|
||||
|
|
@@ -171,9 +169,6 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"):
|
|||
title=name,
|
||||
ylabel='True class',
|
||||
xlabel='Predicted class')
|
||||
# Rotate the tick labels and set their alignment.
|
||||
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
|
||||
rotation_mode="anchor")
|
||||
# Loop over data dimensions and create text annotations.
|
||||
thresh = cm.max() / 2.
|
||||
for i in range(cm.shape[0]):
|
||||
|
|
@@ -200,33 +195,39 @@ class TensorBoardPlotter(TensorBoard):
|
|||
self.model_call = None
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
super().on_epoch_begin(epoch, logs=logs)
|
||||
# override the model's call(), so we don't have to invest extra cycles
|
||||
# to predict our samples (plotting itself can be neglected)
|
||||
self.model_call = self.model.call
|
||||
@tf.function
|
||||
def new_call(inputs, **kwargs):
|
||||
outputs = self.model_call(inputs, **kwargs)
|
||||
images = plot_layout_tf(inputs, outputs)
|
||||
self.plot(images, training=kwargs.get('training', None), epoch=epoch)
|
||||
return outputs
|
||||
with tf.control_dependencies(None):
|
||||
return outputs
|
||||
self.model.call = new_call
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
# re-instate (so ModelCheckpoint does not see our override call)
|
||||
self.model.call = self.model_call
|
||||
# force rebuild of tf.function (so Python binding for epoch gets re-evaluated)
|
||||
self.model.train_function = self.model.make_train_function(True)
|
||||
self.model.test_function = self.model.make_test_function(True)
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
# re-instate (so ModelCheckpoint does not see our override call)
|
||||
self.model.call = self.model_call
|
||||
super().on_epoch_end(epoch, logs=logs)
|
||||
def plot(self, images, training=None, epoch=0):
|
||||
if training:
|
||||
writer = self._train_writer
|
||||
mode, step = "train", self._train_step.read_value()
|
||||
mode, step = "train", self._train_step.value()
|
||||
else:
|
||||
writer = self._val_writer
|
||||
mode, step = "test", self._val_step.read_value()
|
||||
family = "epoch_%03d" % (1 + epoch)
|
||||
with writer.as_default():
|
||||
# used to be family kwarg for tf.summary.image name prefix
|
||||
with tf.name_scope(family):
|
||||
tf.summary.image(mode, images, step=step, max_outputs=len(images))
|
||||
mode, step = "test", self._val_step.value()
|
||||
# skip most samples, because TF's EncodePNG is so costly,
|
||||
# and now ends up in the middle of our pipeline, thus causing stalls
|
||||
# (cannot use max_outputs, as batch size may be too small)
|
||||
if not tf.cast(step % 3, tf.bool):
|
||||
with writer.as_default():
|
||||
# used to be family kwarg for tf.summary.image name prefix
|
||||
family = "epoch_%03d/" % (1 + epoch)
|
||||
name = family + mode
|
||||
tf.summary.image(name, images, step=step, max_outputs=len(images))
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
if logs is not None:
|
||||
logs = dict(logs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue