mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: shuffle tf.data pipelines
This commit is contained in:
parent
1cff937e72
commit
c1d8a72edc
1 changed files with 4 additions and 2 deletions
|
|
@ -660,11 +660,13 @@ def run(_config,
|
|||
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
|
||||
_log.info("validating on %d batches", valdn_steps)
|
||||
|
||||
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), profile_batch=(10, 20)),
|
||||
SaveWeightsAfterSteps(0, dir_output, _config),
|
||||
]
|
||||
if save_interval:
|
||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||
train_gen = train_gen.shuffle(train_steps // 1000, reshuffle_each_iteration=True)
|
||||
valdn_gen = valdn_gen.shuffle(valdn_steps // 10, reshuffle_each_iteration=False)
|
||||
model.fit(
|
||||
train_gen.prefetch(tf.data.AUTOTUNE),
|
||||
steps_per_epoch=train_steps,
|
||||
|
|
@ -731,7 +733,7 @@ def run(_config,
|
|||
if save_interval:
|
||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||
model.fit(
|
||||
train_ds,
|
||||
train_ds.shuffle(200),
|
||||
validation_data=valdn_ds,
|
||||
verbose=1,
|
||||
epochs=n_epochs,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue