training: shuffle tf.data pipelines

This commit is contained in:
Robert Sachunsky 2026-02-28 20:04:32 +01:00
parent 1cff937e72
commit c1d8a72edc

View file

@ -660,11 +660,13 @@ def run(_config,
_log.info("training on %d batches in %d epochs", train_steps, n_epochs) _log.info("training on %d batches in %d epochs", train_steps, n_epochs)
_log.info("validating on %d batches", valdn_steps) _log.info("validating on %d batches", valdn_steps)
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), write_graph=False), callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), profile_batch=(10, 20)),
SaveWeightsAfterSteps(0, dir_output, _config), SaveWeightsAfterSteps(0, dir_output, _config),
] ]
if save_interval: if save_interval:
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
train_gen = train_gen.shuffle(train_steps // 1000, reshuffle_each_iteration=True)
valdn_gen = valdn_gen.shuffle(valdn_steps // 10, reshuffle_each_iteration=False)
model.fit( model.fit(
train_gen.prefetch(tf.data.AUTOTUNE), train_gen.prefetch(tf.data.AUTOTUNE),
steps_per_epoch=train_steps, steps_per_epoch=train_steps,
@ -731,7 +733,7 @@ def run(_config,
if save_interval: if save_interval:
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
model.fit( model.fit(
train_ds, train_ds.shuffle(200),
validation_data=valdn_ds, validation_data=valdn_ds,
verbose=1, verbose=1,
epochs=n_epochs, epochs=n_epochs,