training.train: more config dependencies…

- make more config_params keys dependent on each other
- re-order accordingly
- in main, initialise them (as kwarg), so sacred actually
  allows overriding them by named config file
This commit is contained in:
Robert Sachunsky 2026-02-05 11:53:19 +01:00
parent 7562317da5
commit 4a65ee0c67

View file

@ -97,7 +97,17 @@ ex = Experiment(save_git_info=False)
@ex.config
def config_params():
task = "segmentation" # This parameter defines task of model which can be segmentation, enhancement or classification.
backbone_type = None # Type of image feature map network backbone. Either a vision transformer alongside a CNN we call "transformer", or only a CNN which we call "nontransformer"
if task in ["segmentation", "binarization", "enhancement"]:
backbone_type = "nontransformer" # Type of image feature map network backbone. Either a vision transformer alongside a CNN we call "transformer", or only a CNN which we call "nontransformer"
if backbone_type == "transformer":
transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
transformer_num_patches_xy = None # Number of patches for vision transformer in x and y direction respectively.
transformer_projection_dim = 64 # Transformer projection dimension. Default value is 64.
transformer_mlp_head_units = [128, 64] # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64]
transformer_layers = 8 # transformer layers. Default value is 8.
transformer_num_heads = 4 # Transformer number of heads. Default value is 4.
transformer_cnn_first = True # We have two types of vision transformers: either the CNN is applied first, followed by the transformer, or reversed.
n_classes = None # Number of classes. In the case of binary classification this should be 2.
n_epochs = 1 # Number of epochs to train.
n_batch = 1 # Number of images per batch at each iteration. (Try as large as fits on VRAM.)
@ -105,10 +115,12 @@ def config_params():
input_width = 224 * 1 # Width of model's input in pixels.
weight_decay = 1e-6 # Weight decay of l2 regularization of model layers.
learning_rate = 1e-4 # Set the learning rate.
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
classification_classes_name = None # Dictionary of classification classes names.
if task in ["segmentation", "binarization"]:
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
elif task == "classification":
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
classification_classes_name = None # Dictionary of classification classes names.
patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
augmentation = False # To apply any kind of augmentation, this parameter must be set to true.
if augmentation:
@ -163,17 +175,8 @@ def config_params():
dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
if backbone_type == "transformer":
transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
transformer_num_patches_xy = None # Number of patches for vision transformer in x and y direction respectively.
transformer_projection_dim = 64 # Transformer projection dimension. Default value is 64.
transformer_mlp_head_units = [128, 64] # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64]
transformer_layers = 8 # transformer layers. Default value is 8.
transformer_num_heads = 4 # Transformer number of heads. Default value is 4.
transformer_cnn_first = True # We have two types of vision transformers: either the CNN is applied first, followed by the transformer, or reversed.
@ex.automain
@ex.main
def run(_config,
_log,
task,
@ -187,27 +190,29 @@ def run(_config,
n_batch,
input_height,
input_width,
is_loss_soft_dice,
weighted_loss,
weight_decay,
learning_rate,
continue_training,
index_start,
dir_of_start_model,
save_interval,
augmentation,
thetha,
backbone_type,
transformer_projection_dim,
transformer_mlp_head_units,
transformer_layers,
transformer_num_heads,
transformer_cnn_first,
transformer_patchsize_x,
transformer_patchsize_y,
transformer_num_patches_xy,
f1_threshold_classification,
classification_classes_name,
# dependent config keys need a default,
# otherwise yields sacred.utils.ConfigAddedError
thetha=None,
is_loss_soft_dice=False,
weighted_loss=False,
index_start=0,
dir_of_start_model=None,
backbone_type=None,
transformer_projection_dim=None,
transformer_mlp_head_units=None,
transformer_layers=None,
transformer_num_heads=None,
transformer_cnn_first=None,
transformer_patchsize_x=None,
transformer_patchsize_y=None,
transformer_num_patches_xy=None,
f1_threshold_classification=None,
classification_classes_name=None,
):
if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):