training: plot predictions to TB logs along with training/testing

This commit is contained in:
Robert Sachunsky 2026-02-24 17:00:48 +01:00
parent 56833b3f55
commit 18607e0f48

View file

@@ -74,6 +74,79 @@ def configuration():
    except:
        print("no GPU device available", file=sys.stderr)
@tf.function
def plot_layout_tf(in_: tf.Tensor, out: tf.Tensor) -> tf.Tensor:
    """
    Graph-mode counterpart of training.inference.SBBPredict.visualize_model_output:
    overlay the predicted layout segmentation map on the input image.

    Also converts representations on the way:
    - input side: Eynollah's BGR/float images
    - output side: standard RGB/uint8 images

    :param in_: input batch [B, H, W, 3], BGR channel order, float
                (assumed scaled to [0, 1] — values are multiplied by 255)
    :param out: model output batch [B, H, W, C] of per-class scores
    :return: blended RGB images [B, H, W, 3] as uint8
    """
    # flip BGR -> RGB and rescale to 0..255
    rgb = tf.reverse(in_, axis=[-1]) * 255
    # per-pixel class index: [B, H, W]
    labels = tf.math.argmax(out, axis=-1)
    # fixed palette, one RGB triple per class index
    # NOTE(review): entries 4/8 and 6/9 are identical — presumably matches
    # Eynollah's class color map; confirm before deduplicating
    palette = tf.constant([[255, 255, 255],
                           [255, 0, 0],
                           [255, 125, 0],
                           [255, 0, 125],
                           [125, 125, 125],
                           [125, 125, 0],
                           [0, 125, 255],
                           [0, 125, 0],
                           [125, 125, 125],
                           [0, 125, 255],
                           [125, 0, 125],
                           [0, 255, 0],
                           [0, 0, 255],
                           [0, 255, 255],
                           [255, 125, 125],
                           [255, 0, 255]])
    # colorize the label map: [B, H, W, 3]
    overlay = tf.gather(palette, labels)
    rgb = tf.cast(rgb, tf.float32)
    overlay = tf.cast(overlay, tf.float32)
    # alpha-blend image and colorized labels (a 0.5/0.1 mix proved too dark)
    blended = 0.9 * rgb + 0.1 * overlay
    return tf.cast(blended, tf.uint8)
# plot predictions on train and test set during every epoch
class TensorBoardPlotter(TensorBoard):
    """
    TensorBoard callback that additionally logs rendered model predictions
    (input images blended with the predicted layout via ``plot_layout_tf``)
    for both training and validation batches of every epoch.

    It works by temporarily wrapping ``model.call`` for the duration of each
    epoch, so every forward pass also emits an image summary.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # the original (unwrapped) model.call; saved in on_epoch_begin,
        # restored in on_epoch_end
        self.model_call = None

    def on_epoch_begin(self, epoch, logs=None):
        super().on_epoch_begin(epoch, logs=logs)
        # remember the original forward pass so it can be restored later
        self.model_call = self.model.call
        @tf.function
        def new_call(inputs, **kwargs):
            # run the original forward pass, then plot inputs + predictions
            outputs = self.model_call(inputs, **kwargs)
            images = plot_layout_tf(inputs, outputs)
            # the 'training' kwarg routes the summary to train vs. test logs
            self.plot(images, training=kwargs.get('training', None), epoch=epoch)
            return outputs
        # monkey-patch: every model invocation this epoch also plots
        self.model.call = new_call

    def on_epoch_end(self, epoch, logs=None):
        # re-instate (so ModelCheckpoint does not see our override call)
        self.model.call = self.model_call
        # force rebuild of tf.function (so Python binding for epoch gets re-evaluated)
        # NOTE(review): the positional True presumably means force=True in
        # make_train_function/make_test_function — confirm for the Keras version in use
        self.model.train_function = self.model.make_train_function(True)
        self.model.test_function = self.model.make_test_function(True)
        super().on_epoch_end(epoch, logs=logs)

    def plot(self, images, training=None, epoch=0):
        # Write an image summary to the train or validation log writer,
        # depending on whether the wrapped call ran in training mode.
        # NOTE(review): relies on TensorBoard-callback internals
        # (_train_writer/_val_writer, _train_step/_val_step) — Keras-version sensitive.
        if training:
            writer = self._train_writer
            mode, step = "train", self._train_step.read_value()
        else:
            writer = self._val_writer
            mode, step = "test", self._val_step.read_value()
        # group summaries per epoch (1-based) in TensorBoard's tag hierarchy
        family = "epoch_%03d" % (1 + epoch)
        with writer.as_default():
            # used to be family kwarg for tf.summary.image name prefix
            with tf.name_scope(family):
                tf.summary.image(mode, images, step=step, max_outputs=len(images))
def get_dirs_or_files(input_data):
    image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
@@ -471,14 +544,15 @@ def run(_config,
        lab_gen = lab_gen.map(_to_categorical)
        return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
    train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
SaveWeightsAfterSteps(0, dir_output, _config)]
    valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
    train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
    valdn_steps = len(os.listdir(dir_flow_eval_imgs)) // n_batch
    _log.info("training on %d batches in %d epochs", train_steps, n_epochs)
    _log.info("validating on %d batches", valdn_steps)
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), write_graph=False),
SaveWeightsAfterSteps(0, dir_output, _config),
]
    if save_interval:
        callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
    model.fit(