diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 92a2f49..63f7717 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -660,11 +660,13 @@ def run(_config, _log.info("training on %d batches in %d epochs", train_steps, n_epochs) _log.info("validating on %d batches", valdn_steps) - callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), write_graph=False), + callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), profile_batch=(10, 20)), SaveWeightsAfterSteps(0, dir_output, _config), ] if save_interval: callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) + train_gen = train_gen.shuffle(train_steps // 1000, reshuffle_each_iteration=True) + valdn_gen = valdn_gen.shuffle(valdn_steps // 10, reshuffle_each_iteration=False) model.fit( train_gen.prefetch(tf.data.AUTOTUNE), steps_per_epoch=train_steps, @@ -731,7 +733,7 @@ def run(_config, if save_interval: callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) model.fit( - train_ds, + train_ds.shuffle(200), validation_data=valdn_ds, verbose=1, epochs=n_epochs,