From 3b56fa2a5b56bead190dda896c11cc8e6666f789 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Sat, 28 Feb 2026 20:08:10 +0100 Subject: [PATCH] training: plot GT/prediction and metrics before training (commented) --- src/eynollah/training/train.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index f06c35b..ff6865b 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -641,6 +641,33 @@ def run(_config, callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) train_gen = train_gen.shuffle(train_steps // 1000, reshuffle_each_iteration=True) valdn_gen = valdn_gen.shuffle(valdn_steps // 10, reshuffle_each_iteration=False) + # from matplotlib import pyplot as plt + # from tensorflow_addons.image import connected_components + # def plot(x, ytrue): + # ypred = model.call(x) + # gt = plot_layout_tf(x, ytrue) + # dt = plot_layout_tf(x, ypred) + # segtrue = tf.math.argmax(ytrue, axis=-1) + # segpred = tf.math.argmax(ypred, axis=-1) + # cctrue = connected_components(segtrue) + # ccpred = connected_components(segpred) + # cc = connected_components_loss(n_classes-1)(ytrue, ypred) + # sd = soft_dice_loss(ytrue, ypred) + # return gt, dt, cctrue, ccpred, cc, sd + # for gt, dt, gtcc, dtcc, cc, sd in train_gen.take(15).rebatch(1).map(plot).as_numpy_iterator(): + # plt.subplot(2, 2, 1) + # plt.imshow(np.squeeze(gt)) + # plt.title('GT') + # plt.subplot(2, 2, 3) + # plt.imshow(np.squeeze(gtcc)) + # plt.title('GT CC') + # plt.subplot(2, 2, 4) + # plt.imshow(np.squeeze(dtcc)) + # plt.title('prediction CC') + # plt.subplot(2, 2, 2) + # plt.imshow(np.squeeze(dt)) + # plt.title(f'prediction (nCC={cc} soft dice={sd:.3f})') + # plt.show() model.fit( train_gen.prefetch(tf.data.AUTOTUNE), steps_per_epoch=train_steps,