Mirror of https://github.com/qurator-spk/eynollah.git (synced 2026-02-20 16:32:03 +01:00)
training.train: more config dependencies…

- make more config_params keys dependent on each other
- re-order them accordingly
- in main, initialise the dependent keys as kwargs with defaults, so Sacred actually allows overriding them via a named config file
This commit is contained in:
parent 7562317da5
commit 4a65ee0c67

1 changed file with 36 additions and 31 deletions
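For context, the commit leans on two Sacred features: config scopes (functions decorated with @ex.config, whose local variables become config entries and which Sacred re-evaluates with any user overrides injected) and captured functions (such as the @ex.main entry point, whose parameters are filled from the config). Below is a minimal standalone sketch of the pattern; names are illustrative and this is not eynollah code.

    # Minimal standalone sketch of the Sacred pattern this commit adopts
    # (illustrative names; not eynollah code).
    from sacred import Experiment

    ex = Experiment("sketch", save_git_info=False)

    @ex.config
    def config_params():
        task = "segmentation"
        # Dependent key: only defined for some tasks. If the user overrides
        # `task`, Sacred re-evaluates this scope with the override injected.
        if task in ["segmentation", "binarization"]:
            is_loss_soft_dice = False

    @ex.main
    def run(task, is_loss_soft_dice=False):
        # The default is required: for task="classification" the key does
        # not exist in the config at all.
        print(task, is_loss_soft_dice)

    if __name__ == "__main__":
        # e.g. `python sketch.py with my_config.json` to override from a
        # named config file, or `python sketch.py with task=classification`
        ex.run_commandline()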
@@ -97,7 +97,17 @@ ex = Experiment(save_git_info=False)
 @ex.config
 def config_params():
     task = "segmentation"  # Task of the model: segmentation, binarization, enhancement or classification.
-    backbone_type = None  # Type of image feature map network backbone: either a vision transformer alongside a CNN, which we call "transformer", or only a CNN, which we call "nontransformer".
+    if task in ["segmentation", "binarization", "enhancement"]:
+        backbone_type = "nontransformer"  # Type of image feature map network backbone: either a vision transformer alongside a CNN, which we call "transformer", or only a CNN, which we call "nontransformer".
+        if backbone_type == "transformer":
+            transformer_patchsize_x = None  # Patch size of vision transformer patches in x direction.
+            transformer_patchsize_y = None  # Patch size of vision transformer patches in y direction.
+            transformer_num_patches_xy = None  # Number of patches for the vision transformer in x and y direction respectively.
+            transformer_projection_dim = 64  # Transformer projection dimension. Default value is 64.
+            transformer_mlp_head_units = [128, 64]  # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64].
+            transformer_layers = 8  # Number of transformer layers. Default value is 8.
+            transformer_num_heads = 4  # Number of transformer attention heads. Default value is 4.
+            transformer_cnn_first = True  # We have two types of vision transformers: either the CNN is applied first, followed by the transformer, or the reverse.
     n_classes = None  # Number of classes. In the case of binary classification this should be 2.
     n_epochs = 1  # Number of epochs to train.
     n_batch = 1  # Number of images per batch at each iteration. (Try as large as fits on VRAM.)
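Note that the inner `if backbone_type == "transformer":` branch is not dead code, even though the scope has just assigned "nontransformer": Sacred re-runs the config scope with user overrides fixed, skipping assignments to overridden keys, so the branch fires whenever the user sets backbone_type accordingly. A hypothetical programmatic override, assuming the `ex` experiment object from this file; parameter values are illustrative:

    # Hypothetical override (values illustrative); the same can be done on
    # the command line via Sacred's `with key=value` syntax.
    ex.run(config_updates={
        "task": "segmentation",
        "backbone_type": "transformer",   # suppresses the "nontransformer" assignment
        "transformer_patchsize_x": 1,     # the dependent keys now get defined
        "transformer_patchsize_y": 1,
        "transformer_num_patches_xy": [28, 28],
    })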
@@ -105,10 +115,12 @@
     input_width = 224 * 1  # Width of model's input in pixels.
     weight_decay = 1e-6  # Weight decay of l2 regularization of model layers.
     learning_rate = 1e-4  # Set the learning rate.
-    is_loss_soft_dice = False  # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
-    weighted_loss = False  # Use weighted categorical cross entropy as loss function. When set to true, "is_loss_soft_dice" must be false.
-    f1_threshold_classification = None  # Only models with an evaluation f1 score above this threshold are considered. The selected model weights undergo weights ensembling, and the average ensembled model is written to output.
-    classification_classes_name = None  # Dictionary of classification class names.
+    if task in ["segmentation", "binarization"]:
+        is_loss_soft_dice = False  # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
+        weighted_loss = False  # Use weighted categorical cross entropy as loss function. When set to true, "is_loss_soft_dice" must be false.
+    elif task == "classification":
+        f1_threshold_classification = None  # Only models with an evaluation f1 score above this threshold are considered. The selected model weights undergo weights ensembling, and the average ensembled model is written to output.
+        classification_classes_name = None  # Dictionary of classification class names.
     patches = False  # Divides the input image into smaller patches (the input size of the model) when set to true. For the model to see the full image, as in page extraction, set this to false.
     augmentation = False  # To apply any kind of augmentation, this parameter must be set to true.
     if augmentation:
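Grouping the keys by task means a run only carries the keys relevant to its task. A hypothetical classification run; the class names and threshold are illustrative:

    # Hypothetical classification run; values are illustrative.
    ex.run(config_updates={
        "task": "classification",
        "n_classes": 2,
        "classification_classes_name": {"0": "handwritten", "1": "printed"},
        "f1_threshold_classification": 0.8,
    })

Per the commit message, overriding one of the now-conditional keys from a named config file still works because run() below declares them as kwargs; without that, Sacred would reject the unused key with sacred.utils.ConfigAddedError.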
@@ -163,17 +175,8 @@
     dir_of_start_model = ''  # Directory of model checkpoint to load to continue training. (E.g. if you already trained for 3 epochs, set "dir_of_start_model=dir_output/model_03".)
     index_start = 0  # Epoch counter initial value to continue training. (E.g. if you already trained for 3 epochs, set "index_start=3" to continue naming checkpoints model_04, model_05 etc.)
     data_is_provided = False  # Whether the preprocessed input data (subdirectories "images" and "labels" in both subdirectories "train" and "eval" of "dir_output") has already been generated (in the first epoch of a previous run).
-    if backbone_type == "transformer":
-        transformer_patchsize_x = None  # Patch size of vision transformer patches in x direction.
-        transformer_patchsize_y = None  # Patch size of vision transformer patches in y direction.
-        transformer_num_patches_xy = None  # Number of patches for the vision transformer in x and y direction respectively.
-        transformer_projection_dim = 64  # Transformer projection dimension. Default value is 64.
-        transformer_mlp_head_units = [128, 64]  # Transformer Multilayer Perceptron (MLP) head units. Default value is [128, 64].
-        transformer_layers = 8  # Number of transformer layers. Default value is 8.
-        transformer_num_heads = 4  # Number of transformer attention heads. Default value is 4.
-        transformer_cnn_first = True  # We have two types of vision transformers: either the CNN is applied first, followed by the transformer, or the reverse.
 
-@ex.automain
+@ex.main
 def run(_config,
         _log,
         task,
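Swapping @ex.automain for @ex.main means the experiment no longer launches itself when the module is executed; Sacred then expects an explicit call, presumably (not shown in this diff) something like:

    # Presumed launch code elsewhere in the module (not part of this diff):
    # @ex.main, unlike @ex.automain, does not run the experiment on execution.
    if __name__ == "__main__":
        ex.run_commandline()  # parses `with ...` overrides and named config files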
@@ -187,27 +190,29 @@ def run(_config,
         n_batch,
         input_height,
         input_width,
-        is_loss_soft_dice,
-        weighted_loss,
         weight_decay,
         learning_rate,
         continue_training,
-        index_start,
-        dir_of_start_model,
         save_interval,
         augmentation,
-        thetha,
-        backbone_type,
-        transformer_projection_dim,
-        transformer_mlp_head_units,
-        transformer_layers,
-        transformer_num_heads,
-        transformer_cnn_first,
-        transformer_patchsize_x,
-        transformer_patchsize_y,
-        transformer_num_patches_xy,
-        f1_threshold_classification,
-        classification_classes_name,
+        # dependent config keys need a default,
+        # otherwise yields sacred.utils.ConfigAddedError
+        thetha=None,
+        is_loss_soft_dice=False,
+        weighted_loss=False,
+        index_start=0,
+        dir_of_start_model=None,
+        backbone_type=None,
+        transformer_projection_dim=None,
+        transformer_mlp_head_units=None,
+        transformer_layers=None,
+        transformer_num_heads=None,
+        transformer_cnn_first=None,
+        transformer_patchsize_x=None,
+        transformer_patchsize_y=None,
+        transformer_num_patches_xy=None,
+        f1_threshold_classification=None,
+        classification_classes_name=None,
         ):
 
     if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
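The re-ordering is also a Python requirement: parameters with defaults must follow those without, which is why the dependent keys move to the end of the signature. Per the inline comment, the defaults keep Sacred from raising sacred.utils.ConfigAddedError for keys that only exist conditionally in the config scope. A minimal sketch of that failure mode, based on the inline comment above; names are illustrative, not eynollah code:

    # Minimal sketch of the failure mode the kwarg defaults guard against
    # (illustrative; based on the inline comment above, not eynollah code).
    from sacred import Experiment

    ex = Experiment("sketch", save_git_info=False)

    @ex.config
    def cfg():
        task = "segmentation"
        if task == "classification":
            f1_threshold_classification = None  # absent for other tasks

    @ex.main
    def run(task, f1_threshold_classification=None):
        # `python sketch.py with f1_threshold_classification=0.8` is accepted
        # because this captured function declares the key; if run() did not
        # take it as a parameter, Sacred would reject the unused override
        # with sacred.utils.ConfigAddedError. Without the `=None` default,
        # the default config (task="segmentation") could not fill it at all.
        return task

    if __name__ == "__main__":
        ex.run_commandline()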