mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: make plotting 18607e0f more efficient…
- avoid control dependencies in model path
- store only every 3rd sample
This commit is contained in:
parent
2d5de8e595
commit
f8dd5a328c
1 changed file with 20 additions and 19 deletions
|
|
@@ -115,7 +115,6 @@ def num_connected_components_regression(alpha: float):
|
||||||
metric.__name__ = 'nCC'
|
metric.__name__ = 'nCC'
|
||||||
return metric
|
return metric
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Implements training.inference.SBBPredict.visualize_model_output for TF
|
Implements training.inference.SBBPredict.visualize_model_output for TF
|
||||||
|
|
@@ -158,9 +157,8 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"):
|
||||||
"""
|
"""
|
||||||
Plot the confusion matrix with matplotlib and tensorflow
|
Plot the confusion matrix with matplotlib and tensorflow
|
||||||
"""
|
"""
|
||||||
size = cm.shape[0]
|
fig, ax = plt.subplots(figsize=(10, 8), dpi=300)
|
||||||
fig, ax = plt.subplots()
|
im = ax.imshow(cm, vmin=0.0, vmax=1.0, interpolation='nearest', cmap=plt.cm.Blues)
|
||||||
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
|
||||||
ax.figure.colorbar(im, ax=ax)
|
ax.figure.colorbar(im, ax=ax)
|
||||||
ax.set(xticks=np.arange(cm.shape[1]),
|
ax.set(xticks=np.arange(cm.shape[1]),
|
||||||
yticks=np.arange(cm.shape[0]),
|
yticks=np.arange(cm.shape[0]),
|
||||||
|
|
@@ -171,9 +169,6 @@ def plot_confusion_matrix(cm, name="Confusion Matrix"):
|
||||||
title=name,
|
title=name,
|
||||||
ylabel='True class',
|
ylabel='True class',
|
||||||
xlabel='Predicted class')
|
xlabel='Predicted class')
|
||||||
# Rotate the tick labels and set their alignment.
|
|
||||||
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
|
|
||||||
rotation_mode="anchor")
|
|
||||||
# Loop over data dimensions and create text annotations.
|
# Loop over data dimensions and create text annotations.
|
||||||
thresh = cm.max() / 2.
|
thresh = cm.max() / 2.
|
||||||
for i in range(cm.shape[0]):
|
for i in range(cm.shape[0]):
|
||||||
|
|
@@ -200,33 +195,39 @@ class TensorBoardPlotter(TensorBoard):
|
||||||
self.model_call = None
|
self.model_call = None
|
||||||
def on_epoch_begin(self, epoch, logs=None):
|
def on_epoch_begin(self, epoch, logs=None):
|
||||||
super().on_epoch_begin(epoch, logs=logs)
|
super().on_epoch_begin(epoch, logs=logs)
|
||||||
|
# override the model's call(), so we don't have to invest extra cycles
|
||||||
|
# to predict our samples (plotting itself can be neglected)
|
||||||
self.model_call = self.model.call
|
self.model_call = self.model.call
|
||||||
@tf.function
|
|
||||||
def new_call(inputs, **kwargs):
|
def new_call(inputs, **kwargs):
|
||||||
outputs = self.model_call(inputs, **kwargs)
|
outputs = self.model_call(inputs, **kwargs)
|
||||||
images = plot_layout_tf(inputs, outputs)
|
images = plot_layout_tf(inputs, outputs)
|
||||||
self.plot(images, training=kwargs.get('training', None), epoch=epoch)
|
self.plot(images, training=kwargs.get('training', None), epoch=epoch)
|
||||||
return outputs
|
with tf.control_dependencies(None):
|
||||||
|
return outputs
|
||||||
self.model.call = new_call
|
self.model.call = new_call
|
||||||
def on_epoch_end(self, epoch, logs=None):
|
|
||||||
# re-instate (so ModelCheckpoint does not see our override call)
|
|
||||||
self.model.call = self.model_call
|
|
||||||
# force rebuild of tf.function (so Python binding for epoch gets re-evaluated)
|
# force rebuild of tf.function (so Python binding for epoch gets re-evaluated)
|
||||||
self.model.train_function = self.model.make_train_function(True)
|
self.model.train_function = self.model.make_train_function(True)
|
||||||
self.model.test_function = self.model.make_test_function(True)
|
self.model.test_function = self.model.make_test_function(True)
|
||||||
|
def on_epoch_end(self, epoch, logs=None):
|
||||||
|
# re-instate (so ModelCheckpoint does not see our override call)
|
||||||
|
self.model.call = self.model_call
|
||||||
super().on_epoch_end(epoch, logs=logs)
|
super().on_epoch_end(epoch, logs=logs)
|
||||||
def plot(self, images, training=None, epoch=0):
|
def plot(self, images, training=None, epoch=0):
|
||||||
if training:
|
if training:
|
||||||
writer = self._train_writer
|
writer = self._train_writer
|
||||||
mode, step = "train", self._train_step.read_value()
|
mode, step = "train", self._train_step.value()
|
||||||
else:
|
else:
|
||||||
writer = self._val_writer
|
writer = self._val_writer
|
||||||
mode, step = "test", self._val_step.read_value()
|
mode, step = "test", self._val_step.value()
|
||||||
family = "epoch_%03d" % (1 + epoch)
|
# skip most samples, because TF's EncodePNG is so costly,
|
||||||
with writer.as_default():
|
# and now ends up in the middle of our pipeline, thus causing stalls
|
||||||
# used to be family kwarg for tf.summary.image name prefix
|
# (cannot use max_outputs, as batch size may be too small)
|
||||||
with tf.name_scope(family):
|
if not tf.cast(step % 3, tf.bool):
|
||||||
tf.summary.image(mode, images, step=step, max_outputs=len(images))
|
with writer.as_default():
|
||||||
|
# used to be family kwarg for tf.summary.image name prefix
|
||||||
|
family = "epoch_%03d/" % (1 + epoch)
|
||||||
|
name = family + mode
|
||||||
|
tf.summary.image(name, images, step=step, max_outputs=len(images))
|
||||||
def on_train_batch_end(self, batch, logs=None):
|
def on_train_batch_end(self, batch, logs=None):
|
||||||
if logs is not None:
|
if logs is not None:
|
||||||
logs = dict(logs)
|
logs = dict(logs)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue