training.train: more config dependencies…

- make more config_params keys dependent on each other
- re-order accordingly
- in main, initialise them (as kwargs), so Sacred actually
  allows overriding them via a named config file
Robert Sachunsky 2026-02-05 11:53:19 +01:00
parent 7562317da5
commit 4a65ee0c67

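For orientation, a minimal, self-contained sketch of the Sacred pattern this commit adopts (illustrative only, not the actual train.py; f1_threshold_classification stands in for any dependent key): a key defined inside a conditional of a config scope only exists when the condition holds, so a captured main function must supply a default for it.

from sacred import Experiment

ex = Experiment(save_git_info=False)

@ex.config
def config_params():
    task = "segmentation"         # base key, always present
    if task == "classification":  # dependent key: only defined for this task
        f1_threshold_classification = 0.8

@ex.main
def run(task, f1_threshold_classification=None):
    # The default is required: for task="segmentation" the key is absent from
    # the config, so Sacred has nothing to inject for this parameter (the
    # commit's own comment attributes the resulting failure to
    # sacred.utils.ConfigAddedError).
    print(task, f1_threshold_classification)

if __name__ == "__main__":
    ex.run_commandline()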

@@ -97,7 +97,17 @@ ex = Experiment(save_git_info=False)
 @ex.config
 def config_params():
     task = "segmentation" # This parameter defines the task of the model, which can be segmentation, binarization, enhancement or classification.
-    backbone_type = None # Type of image feature map network backbone. Either a vision transformer alongside a CNN we call "transformer", or only a CNN which we call "nontransformer"
+    if task in ["segmentation", "binarization", "enhancement"]:
+        backbone_type = "nontransformer" # Type of image feature map network backbone. Either a vision transformer alongside a CNN we call "transformer", or only a CNN which we call "nontransformer"
+        if backbone_type == "transformer":
+            transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
+            transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
+            transformer_num_patches_xy = None # Number of patches for vision transformer in x and y direction respectively.
+            transformer_projection_dim = 64 # Transformer projection dimension. Default value is 64.
+            transformer_mlp_head_units = [128, 64] # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64]
+            transformer_layers = 8 # Transformer layers. Default value is 8.
+            transformer_num_heads = 4 # Transformer number of heads. Default value is 4.
+            transformer_cnn_first = True # We have two types of vision transformers: either the CNN is applied first, followed by the transformer, or reversed.
     n_classes = None # Number of classes. In the case of binary classification this should be 2.
     n_epochs = 1 # Number of epochs to train.
     n_batch = 1 # Number of images per batch at each iteration. (Try as large as fits on VRAM.)
@@ -105,8 +115,10 @@ def config_params():
     input_width = 224 * 1 # Width of model's input in pixels.
     weight_decay = 1e-6 # Weight decay of l2 regularization of model layers.
     learning_rate = 1e-4 # Set the learning rate.
+    if task in ["segmentation", "binarization"]:
         is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
         weighted_loss = False # Use weighted categorical cross entropy as loss function. When set to true, "is_loss_soft_dice" must be false.
+    elif task == "classification":
         f1_threshold_classification = None # Only models with an evaluation f1 score above this threshold are considered; their weights are ensembled, and the averaged model is written to output.
         classification_classes_name = None # Dictionary of classification class names.
     patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
@@ -163,17 +175,8 @@ def config_params():
     dir_of_start_model = '' # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
     index_start = 0 # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
     data_is_provided = False # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
-    if backbone_type == "transformer":
-        transformer_patchsize_x = None # Patch size of vision transformer patches in x direction.
-        transformer_patchsize_y = None # Patch size of vision transformer patches in y direction.
-        transformer_num_patches_xy = None # Number of patches for vision transformer in x and y direction respectively.
-        transformer_projection_dim = 64 # Transformer projection dimension. Default value is 64.
-        transformer_mlp_head_units = [128, 64] # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64]
-        transformer_layers = 8 # Transformer layers. Default value is 8.
-        transformer_num_heads = 4 # Transformer number of heads. Default value is 4.
-        transformer_cnn_first = True # We have two types of vision transformers: either the CNN is applied first, followed by the transformer, or reversed.
 
-@ex.automain
+@ex.main
 def run(_config,
         _log,
         task,
@@ -187,27 +190,29 @@ def run(_config,
         n_batch,
         input_height,
         input_width,
-        is_loss_soft_dice,
-        weighted_loss,
         weight_decay,
         learning_rate,
         continue_training,
-        index_start,
-        dir_of_start_model,
         save_interval,
         augmentation,
-        thetha,
-        backbone_type,
-        transformer_projection_dim,
-        transformer_mlp_head_units,
-        transformer_layers,
-        transformer_num_heads,
-        transformer_cnn_first,
-        transformer_patchsize_x,
-        transformer_patchsize_y,
-        transformer_num_patches_xy,
-        f1_threshold_classification,
-        classification_classes_name,
+        # dependent config keys need a default,
+        # otherwise yields sacred.utils.ConfigAddedError
+        thetha=None,
+        is_loss_soft_dice=False,
+        weighted_loss=False,
+        index_start=0,
+        dir_of_start_model=None,
+        backbone_type=None,
+        transformer_projection_dim=None,
+        transformer_mlp_head_units=None,
+        transformer_layers=None,
+        transformer_num_heads=None,
+        transformer_cnn_first=None,
+        transformer_patchsize_x=None,
+        transformer_patchsize_y=None,
+        transformer_num_patches_xy=None,
+        f1_threshold_classification=None,
+        classification_classes_name=None,
         ):
     if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
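Note on the decorator change: @ex.automain both registers the main function and runs the experiment when the script is executed, whereas plain @ex.main requires an explicit ex.run_commandline() call, presumably added elsewhere in the file (not visible in these hunks). Once the dependent keys have kwarg defaults, they can be overridden either directly on the command line or via a config file passed as a named config; a hypothetical invocation (the JSON file name is illustrative):

python train.py with task=classification f1_threshold_classification=0.75
python train.py with config_classification.json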