training: fix epoch size calculation

This commit is contained in:
Robert Sachunsky 2026-01-29 03:01:14 +01:00
parent 29a0f19cee
commit d1e8a02fd4

View file

@ -394,17 +394,16 @@ def run(_config,
if save_interval: if save_interval:
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
_log.info("training on %d batches in %d epochs", steps_train = len(os.listdir(dir_flow_train_imgs)) // n_batch # - 1
len(os.listdir(dir_flow_train_imgs)) // n_batch - 1, steps_val = len(os.listdir(dir_flow_eval_imgs)) // n_batch
n_epochs) _log.info("training on %d batches in %d epochs", steps_train, n_epochs)
_log.info("validating on %d batches", _log.info("validating on %d batches", steps_val)
len(os.listdir(dir_flow_eval_imgs)) // n_batch - 1)
model.fit( model.fit(
train_gen, train_gen,
steps_per_epoch=len(os.listdir(dir_flow_train_imgs)) // n_batch - 1, steps_per_epoch=steps_train,
validation_data=val_gen, validation_data=val_gen,
#validation_steps=1, # rs: only one batch?? #validation_steps=1, # rs: only one batch??
validation_steps=len(os.listdir(dir_flow_eval_imgs)) // n_batch - 1, validation_steps=steps_val,
epochs=n_epochs, epochs=n_epochs,
callbacks=callbacks) callbacks=callbacks)