mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: plot predictions to TB logs along with training/testing
This commit is contained in:
parent
56833b3f55
commit
18607e0f48
1 changed files with 76 additions and 2 deletions
|
|
@ -74,6 +74,79 @@ def configuration():
|
||||||
except:
|
except:
|
||||||
print("no GPU device available", file=sys.stderr)
|
print("no GPU device available", file=sys.stderr)
|
||||||
|
|
||||||
|
@tf.function
def plot_layout_tf(in_: tf.Tensor, out: tf.Tensor) -> tf.Tensor:
    """
    Implements training.inference.SBBPredict.visualize_model_output for TF
    (effectively plotting the layout segmentation map on the input image).

    In doing so, also converts:
    - from Eynollah's BGR/float on the input side
    - to std RGB/int format on the output side
    """
    # Fixed per-class RGB palette (row index = predicted class id).
    palette = tf.constant(
        [[255, 255, 255], [255, 0, 0], [255, 125, 0], [255, 0, 125],
         [125, 125, 125], [125, 125, 0], [0, 125, 255], [0, 125, 0],
         [125, 125, 125], [0, 125, 255], [125, 0, 125], [0, 255, 0],
         [0, 0, 255], [0, 255, 255], [255, 125, 125], [255, 0, 255]])
    # in_: [B, H, W, 3] BGR float — reverse channels to RGB and scale by 255
    # (assumes float input normalized to [0, 1] — TODO confirm with caller)
    rgb = tf.cast(in_[..., ::-1] * 255, tf.float32)
    # out: [B, H, W, C] class scores -> per-pixel class ids [B, H, W]
    class_ids = tf.math.argmax(out, axis=-1)
    # colorize the label map: [B, H, W] -> [B, H, W, 3]
    overlay = tf.cast(tf.gather(palette, class_ids), tf.float32)
    # blend the palette overlay onto the image
    # (0.5 image weight was too dark, hence 0.9/0.1)
    blended = rgb * 0.9 + overlay * 0.1
    return tf.cast(blended, tf.uint8)
|
|
||||||
|
# plot predictions on train and test set during every epoch
class TensorBoardPlotter(TensorBoard):
    """TensorBoard callback that also logs model predictions as images.

    At the start of each epoch the model's ``call`` is temporarily wrapped
    so that every forward pass additionally renders the predicted layout
    (via ``plot_layout_tf``) and writes it to the TensorBoard image log;
    the wrap is removed again at epoch end.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # original (unwrapped) model.call, saved so it can be restored
        self.model_call = None

    def on_epoch_begin(self, epoch, logs=None):
        """Wrap ``model.call`` so each forward pass also plots its output."""
        super().on_epoch_begin(epoch, logs=logs)
        self.model_call = self.model.call
        @tf.function
        def new_call(inputs, **kwargs):
            # run the real forward pass, then log a visualization of it
            outputs = self.model_call(inputs, **kwargs)
            images = plot_layout_tf(inputs, outputs)
            # 'training' distinguishes the train writer from the val writer
            self.plot(images, training=kwargs.get('training', None), epoch=epoch)
            return outputs
        self.model.call = new_call

    def on_epoch_end(self, epoch, logs=None):
        """Restore the original ``model.call`` and force function rebuilds."""
        # re-instate (so ModelCheckpoint does not see our override call)
        self.model.call = self.model_call
        # force rebuild of tf.function (so Python binding for epoch gets re-evaluated)
        self.model.train_function = self.model.make_train_function(True)
        self.model.test_function = self.model.make_test_function(True)
        super().on_epoch_end(epoch, logs=logs)

    def plot(self, images, training=None, epoch=0):
        """Write ``images`` to the train or validation TensorBoard writer.

        NOTE(review): relies on private TensorBoard-callback attributes
        (``_train_writer``/``_val_writer``/``_train_step``/``_val_step``) —
        these may change between Keras versions; confirm against the
        installed version.
        """
        if training:
            writer = self._train_writer
            mode, step = "train", self._train_step.read_value()
        else:
            writer = self._val_writer
            mode, step = "test", self._val_step.read_value()
        # group all images of one epoch under a common name-scope prefix
        family = "epoch_%03d" % (1 + epoch)
        with writer.as_default():
            # used to be family kwarg for tf.summary.image name prefix
            with tf.name_scope(family):
                tf.summary.image(mode, images, step=step, max_outputs=len(images))
|
|
||||||
def get_dirs_or_files(input_data):
|
def get_dirs_or_files(input_data):
|
||||||
image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
|
image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
|
||||||
|
|
@ -471,14 +544,15 @@ def run(_config,
|
||||||
lab_gen = lab_gen.map(_to_categorical)
|
lab_gen = lab_gen.map(_to_categorical)
|
||||||
return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
|
return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
|
||||||
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
|
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
|
||||||
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
|
||||||
SaveWeightsAfterSteps(0, dir_output, _config)]
|
|
||||||
valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
||||||
train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
|
train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
|
||||||
valdn_steps = len(os.listdir(dir_flow_eval_imgs)) // n_batch
|
valdn_steps = len(os.listdir(dir_flow_eval_imgs)) // n_batch
|
||||||
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
|
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
|
||||||
_log.info("validating on %d batches", valdn_steps)
|
_log.info("validating on %d batches", valdn_steps)
|
||||||
|
|
||||||
|
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||||
|
SaveWeightsAfterSteps(0, dir_output, _config),
|
||||||
|
]
|
||||||
if save_interval:
|
if save_interval:
|
||||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||||
model.fit(
|
model.fit(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue