diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index ecf70b4..73d5e0b 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -275,6 +275,9 @@ def run(_config, run configured experiment via sacred """ + if continue_training: + assert n_epochs > index_start, "with continue_training, n_epochs must be greater than index_start" + if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH): _log.info("downloading RESNET50 pretrained weights to %s", RESNET50_WEIGHTS_PATH) download_file(RESNET50_WEIGHTS_URL, RESNET50_WEIGHTS_PATH)