mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: plot predictions to TB logs along with training/testing
This commit is contained in:
parent
56833b3f55
commit
18607e0f48
1 changed file with 76 additions and 2 deletions
|
|
@ -74,6 +74,79 @@ def configuration():
|
|||
except:
|
||||
print("no GPU device available", file=sys.stderr)
|
||||
|
||||
@tf.function
|
||||
def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Implements training.inference.SBBPredict.visualize_model_output for TF
|
||||
(effectively plotting the layout segmentation map on the input image).
|
||||
|
||||
In doing so, also converts:
|
||||
- from Eynollah's BGR/float on the input side
|
||||
- to std RGB/int format on the output side
|
||||
"""
|
||||
# in_: [B, H, W, 3] (BGR float)
|
||||
image = in_[..., ::-1] * 255
|
||||
# out: [B, H, W, C]
|
||||
lab = tf.math.argmax(out, axis=-1)
|
||||
# lab: [B, H, W]
|
||||
colors = tf.constant([[255, 255, 255],
|
||||
[255, 0, 0],
|
||||
[255, 125, 0],
|
||||
[255, 0, 125],
|
||||
[125, 125, 125],
|
||||
[125, 125, 0],
|
||||
[0, 125, 255],
|
||||
[0, 125, 0],
|
||||
[125, 125, 125],
|
||||
[0, 125, 255],
|
||||
[125, 0, 125],
|
||||
[0, 255, 0],
|
||||
[0, 0, 255],
|
||||
[0, 255, 255],
|
||||
[255, 125, 125],
|
||||
[255, 0, 255]])
|
||||
layout = tf.gather(colors, lab)
|
||||
# layout: [B, H, W, 3]
|
||||
image = tf.cast(image, tf.float32)
|
||||
layout = tf.cast(layout, tf.float32)
|
||||
#weighted = image * 0.5 + layout * 0.1 (too dark)
|
||||
weighted = image * 0.9 + layout * 0.1
|
||||
return tf.cast(weighted, tf.uint8)
|
||||
|
||||
# plot predictions on train and test set during every epoch
|
||||
class TensorBoardPlotter(TensorBoard):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_call = None
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
super().on_epoch_begin(epoch, logs=logs)
|
||||
self.model_call = self.model.call
|
||||
@tf.function
|
||||
def new_call(inputs, **kwargs):
|
||||
outputs = self.model_call(inputs, **kwargs)
|
||||
images = plot_layout_tf(inputs, outputs)
|
||||
self.plot(images, training=kwargs.get('training', None), epoch=epoch)
|
||||
return outputs
|
||||
self.model.call = new_call
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
# re-instate (so ModelCheckpoint does not see our override call)
|
||||
self.model.call = self.model_call
|
||||
# force rebuild of tf.function (so Python binding for epoch gets re-evaluated)
|
||||
self.model.train_function = self.model.make_train_function(True)
|
||||
self.model.test_function = self.model.make_test_function(True)
|
||||
super().on_epoch_end(epoch, logs=logs)
|
||||
def plot(self, images, training=None, epoch=0):
|
||||
if training:
|
||||
writer = self._train_writer
|
||||
mode, step = "train", self._train_step.read_value()
|
||||
else:
|
||||
writer = self._val_writer
|
||||
mode, step = "test", self._val_step.read_value()
|
||||
family = "epoch_%03d" % (1 + epoch)
|
||||
with writer.as_default():
|
||||
# used to be family kwarg for tf.summary.image name prefix
|
||||
with tf.name_scope(family):
|
||||
tf.summary.image(mode, images, step=step, max_outputs=len(images))
|
||||
|
||||
def get_dirs_or_files(input_data):
|
||||
image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
|
||||
|
|
@ -471,14 +544,15 @@ def run(_config,
|
|||
lab_gen = lab_gen.map(_to_categorical)
|
||||
return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
|
||||
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
|
||||
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||
SaveWeightsAfterSteps(0, dir_output, _config)]
|
||||
valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
||||
train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
|
||||
valdn_steps = len(os.listdir(dir_flow_eval_imgs)) // n_batch
|
||||
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
|
||||
_log.info("validating on %d batches", valdn_steps)
|
||||
|
||||
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||
SaveWeightsAfterSteps(0, dir_output, _config),
|
||||
]
|
||||
if save_interval:
|
||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||
model.fit(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue