diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 4d0b317..5305ee3 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -74,6 +74,79 @@ def configuration():
     except:
         print("no GPU device available", file=sys.stderr)
 
+@tf.function
+def plot_layout_tf(in_: tf.Tensor, out:tf.Tensor) -> tf.Tensor:
+    """
+    Implements training.inference.SBBPredict.visualize_model_output for TF
+    (effectively plotting the layout segmentation map on the input image).
+
+    In doing so, also converts:
+    - from Eynollah's BGR/float on the input side
+    - to std RGB/int format on the output side
+    """
+    # in_: [B, H, W, 3] (BGR float)
+    image = in_[..., ::-1] * 255
+    # out: [B, H, W, C]
+    lab = tf.math.argmax(out, axis=-1)
+    # lab: [B, H, W]
+    colors = tf.constant([[255, 255, 255],
+                          [255, 0, 0],
+                          [255, 125, 0],
+                          [255, 0, 125],
+                          [125, 125, 125],
+                          [125, 125, 0],
+                          [0, 125, 255],
+                          [0, 125, 0],
+                          [125, 125, 125],
+                          [0, 125, 255],
+                          [125, 0, 125],
+                          [0, 255, 0],
+                          [0, 0, 255],
+                          [0, 255, 255],
+                          [255, 125, 125],
+                          [255, 0, 255]])
+    layout = tf.gather(colors, lab)
+    # layout: [B, H, W, 3]
+    image = tf.cast(image, tf.float32)
+    layout = tf.cast(layout, tf.float32)
+    #weighted = image * 0.5 + layout * 0.1 (too dark)
+    weighted = image * 0.9 + layout * 0.1
+    return tf.cast(weighted, tf.uint8)
+
+# plot predictions on train and test set during every epoch
+class TensorBoardPlotter(TensorBoard):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model_call = None
+    def on_epoch_begin(self, epoch, logs=None):
+        super().on_epoch_begin(epoch, logs=logs)
+        self.model_call = self.model.call
+        @tf.function
+        def new_call(inputs, **kwargs):
+            outputs = self.model_call(inputs, **kwargs)
+            images = plot_layout_tf(inputs, outputs)
+            self.plot(images, training=kwargs.get('training', None), epoch=epoch)
+            return outputs
+        self.model.call = new_call
+    def on_epoch_end(self, epoch, logs=None):
+        # re-instate (so ModelCheckpoint does not see our override call)
+        self.model.call = self.model_call
+        # force rebuild of tf.function (so Python binding for epoch gets re-evaluated)
+        self.model.train_function = self.model.make_train_function(True)
+        self.model.test_function = self.model.make_test_function(True)
+        super().on_epoch_end(epoch, logs=logs)
+    def plot(self, images, training=None, epoch=0):
+        if training:
+            writer = self._train_writer
+            mode, step = "train", self._train_step.read_value()
+        else:
+            writer = self._val_writer
+            mode, step = "test", self._val_step.read_value()
+        family = "epoch_%03d" % (1 + epoch)
+        with writer.as_default():
+            # used to be family kwarg for tf.summary.image name prefix
+            with tf.name_scope(family):
+                tf.summary.image(mode, images, step=step, max_outputs=len(images))
 
 def get_dirs_or_files(input_data):
     image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
@@ -471,14 +544,15 @@ def run(_config,
         lab_gen = lab_gen.map(_to_categorical)
         return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
     train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
-    callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
-                 SaveWeightsAfterSteps(0, dir_output, _config)]
     valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
     train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
     valdn_steps = len(os.listdir(dir_flow_eval_imgs)) // n_batch
     _log.info("training on %d batches in %d epochs", train_steps, n_epochs)
     _log.info("validating on %d batches", valdn_steps)
 
+    callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), write_graph=False),
+                 SaveWeightsAfterSteps(0, dir_output, _config),
+                 ]
     if save_interval:
         callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
     model.fit(