mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: plot GT/prediction and metrics before training (commented)
This commit is contained in:
parent
e47653f684
commit
3b56fa2a5b
1 changed files with 27 additions and 0 deletions
|
|
@ -641,6 +641,33 @@ def run(_config,
|
|||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||
train_gen = train_gen.shuffle(train_steps // 1000, reshuffle_each_iteration=True)
|
||||
valdn_gen = valdn_gen.shuffle(valdn_steps // 10, reshuffle_each_iteration=False)
|
||||
# from matplotlib import pyplot as plt
|
||||
# from tensorflow_addons.image import connected_components
|
||||
# def plot(x, ytrue):
|
||||
# ypred = model.call(x)
|
||||
# gt = plot_layout_tf(x, ytrue)
|
||||
# dt = plot_layout_tf(x, ypred)
|
||||
# segtrue = tf.math.argmax(ytrue, axis=-1)
|
||||
# segpred = tf.math.argmax(ypred, axis=-1)
|
||||
# cctrue = connected_components(segtrue)
|
||||
# ccpred = connected_components(segpred)
|
||||
# cc = connected_components_loss(n_classes-1)(ytrue, ypred)
|
||||
# sd = soft_dice_loss(ytrue, ypred)
|
||||
# return gt, dt, cctrue, ccpred, cc, sd
|
||||
# for gt, dt, gtcc, dtcc, cc, sd in train_gen.take(15).rebatch(1).map(plot).as_numpy_iterator():
|
||||
# plt.subplot(2, 2, 1)
|
||||
# plt.imshow(np.squeeze(gt))
|
||||
# plt.title('GT')
|
||||
# plt.subplot(2, 2, 3)
|
||||
# plt.imshow(np.squeeze(gtcc))
|
||||
# plt.title('GT CC')
|
||||
# plt.subplot(2, 2, 4)
|
||||
# plt.imshow(np.squeeze(dtcc))
|
||||
# plt.title('prediction CC')
|
||||
# plt.subplot(2, 2, 2)
|
||||
# plt.imshow(np.squeeze(dt))
|
||||
# plt.title(f'prediction (nCC={cc} soft dice={sd:.3f})')
|
||||
# plt.show()
|
||||
model.fit(
|
||||
train_gen.prefetch(tf.data.AUTOTUNE),
|
||||
steps_per_epoch=train_steps,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue