training: plot GT/prediction and metrics before training (commented)

This commit is contained in:
Robert Sachunsky 2026-02-28 20:08:10 +01:00
parent e47653f684
commit 3b56fa2a5b

View file

@ -641,6 +641,33 @@ def run(_config,
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
train_gen = train_gen.shuffle(train_steps // 1000, reshuffle_each_iteration=True)
valdn_gen = valdn_gen.shuffle(valdn_steps // 10, reshuffle_each_iteration=False)
# from matplotlib import pyplot as plt
# from tensorflow_addons.image import connected_components
# def plot(x, ytrue):
# ypred = model.call(x)
# gt = plot_layout_tf(x, ytrue)
# dt = plot_layout_tf(x, ypred)
# segtrue = tf.math.argmax(ytrue, axis=-1)
# segpred = tf.math.argmax(ypred, axis=-1)
# cctrue = connected_components(segtrue)
# ccpred = connected_components(segpred)
# cc = connected_components_loss(n_classes-1)(ytrue, ypred)
# sd = soft_dice_loss(ytrue, ypred)
# return gt, dt, cctrue, ccpred, cc, sd
# for gt, dt, gtcc, dtcc, cc, sd in train_gen.take(15).rebatch(1).map(plot).as_numpy_iterator():
# plt.subplot(2, 2, 1)
# plt.imshow(np.squeeze(gt))
# plt.title('GT')
# plt.subplot(2, 2, 3)
# plt.imshow(np.squeeze(gtcc))
# plt.title('GT CC')
# plt.subplot(2, 2, 4)
# plt.imshow(np.squeeze(dtcc))
# plt.title('prediction CC')
# plt.subplot(2, 2, 2)
# plt.imshow(np.squeeze(dt))
# plt.title(f'prediction (nCC={cc} soft dice={sd:.3f})')
# plt.show()
model.fit(
train_gen.prefetch(tf.data.AUTOTUNE),
steps_per_epoch=train_steps,