From c1d8a72edc3125159396ccca6db45ff8a69c06de Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Sat, 28 Feb 2026 20:04:32 +0100 Subject: [PATCH] training: shuffle tf.data pipelines --- src/eynollah/training/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 92a2f49..63f7717 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -660,11 +660,13 @@ def run(_config, _log.info("training on %d batches in %d epochs", train_steps, n_epochs) _log.info("validating on %d batches", valdn_steps) - callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), write_graph=False), + callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), profile_batch=(10, 20)), SaveWeightsAfterSteps(0, dir_output, _config), ] if save_interval: callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) + train_gen = train_gen.shuffle(train_steps // 1000, reshuffle_each_iteration=True) + valdn_gen = valdn_gen.shuffle(valdn_steps // 10, reshuffle_each_iteration=False) model.fit( train_gen.prefetch(tf.data.AUTOTUNE), steps_per_epoch=train_steps, @@ -731,7 +733,7 @@ def run(_config, if save_interval: callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) model.fit( - train_ds, + train_ds.shuffle(200), validation_data=valdn_ds, verbose=1, epochs=n_epochs,