mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
training: shuffle tf.data pipelines
This commit is contained in:
parent
1cff937e72
commit
c1d8a72edc
1 changed files with 4 additions and 2 deletions
|
|
@ -660,11 +660,13 @@ def run(_config,
|
||||||
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
|
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
|
||||||
_log.info("validating on %d batches", valdn_steps)
|
_log.info("validating on %d batches", valdn_steps)
|
||||||
|
|
||||||
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), write_graph=False),
|
callbacks = [TensorBoardPlotter(os.path.join(dir_output, 'logs'), profile_batch=(10, 20)),
|
||||||
SaveWeightsAfterSteps(0, dir_output, _config),
|
SaveWeightsAfterSteps(0, dir_output, _config),
|
||||||
]
|
]
|
||||||
if save_interval:
|
if save_interval:
|
||||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||||
|
train_gen = train_gen.shuffle(train_steps // 1000, reshuffle_each_iteration=True)
|
||||||
|
valdn_gen = valdn_gen.shuffle(valdn_steps // 10, reshuffle_each_iteration=False)
|
||||||
model.fit(
|
model.fit(
|
||||||
train_gen.prefetch(tf.data.AUTOTUNE),
|
train_gen.prefetch(tf.data.AUTOTUNE),
|
||||||
steps_per_epoch=train_steps,
|
steps_per_epoch=train_steps,
|
||||||
|
|
@ -731,7 +733,7 @@ def run(_config,
|
||||||
if save_interval:
|
if save_interval:
|
||||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||||
model.fit(
|
model.fit(
|
||||||
train_ds,
|
train_ds.shuffle(200),
|
||||||
validation_data=valdn_ds,
|
validation_data=valdn_ds,
|
||||||
verbose=1,
|
verbose=1,
|
||||||
epochs=n_epochs,
|
epochs=n_epochs,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue