diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 9c638ea..1e2ab3e 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -394,17 +394,16 @@ def run(_config, if save_interval: callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) - _log.info("training on %d batches in %d epochs", - len(os.listdir(dir_flow_train_imgs)) // n_batch - 1, - n_epochs) - _log.info("validating on %d batches", - len(os.listdir(dir_flow_eval_imgs)) // n_batch - 1) + steps_train = len(os.listdir(dir_flow_train_imgs)) // n_batch # - 1 + steps_val = len(os.listdir(dir_flow_eval_imgs)) // n_batch + _log.info("training on %d batches in %d epochs", steps_train, n_epochs) + _log.info("validating on %d batches", steps_val) model.fit( train_gen, - steps_per_epoch=len(os.listdir(dir_flow_train_imgs)) // n_batch - 1, + steps_per_epoch=steps_train, validation_data=val_gen, #validation_steps=1, # rs: only one batch?? - validation_steps=len(os.listdir(dir_flow_eval_imgs)) // n_batch - 1, + validation_steps=steps_val, epochs=n_epochs, callbacks=callbacks)