From 093030f503e0032c97260540ad42c671f3f0d6a1 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Thu, 28 May 2026 17:37:45 +0200 Subject: [PATCH] =?UTF-8?q?train/models:=20move=20all=20model=20builders?= =?UTF-8?q?=20to=20`models.get=5Fmodel()`=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - models: add new `get_model()`, passing in Sacred config to capture builder function arguments - train: fewer imports - train: no need to pass `custom_objects` if loading with `compile=False` (and we custom-compile later, anyway) --- src/eynollah/training/models.py | 49 ++++++++++++++++++++ src/eynollah/training/train.py | 81 ++++----------------------------- 2 files changed, 57 insertions(+), 73 deletions(-) diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index eb621c6..83058ee 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -499,3 +499,52 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l return Model((inputs, labels), out) +def get_model(config, logger): + from sacred.config import create_captured_function + + task = config['task'] + if task in ["segmentation", "enhancement", "binarization"]: + if config['backbone_type'] == 'nontransformer': + builder = resnet50_unet + else: + num_patches_x, num_patches_y = config['transformer_num_patches_xy'] + num_patches = num_patches_x * num_patches_y + + if config['transformer_cnn_first']: + builder = vit_resnet50_unet + multiple = 32 + else: + builder = vit_resnet50_unet_transformer_before_cnn + multiple = 1 + + assert config['input_height'] == ( + num_patches_y * config['transformer_patchsize_y'] * multiple), ( + "transformer_patchsize_y or transformer_num_patches_xy height value error: " + "input_height should be equal to " + "(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple) + assert config['input_width'] == ( + num_patches_x * config['transformer_patchsize_x'] * multiple), ( + "transformer_patchsize_x or transformer_num_patches_xy width value error: " + "input_width should be equal to " + "(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple) + assert 0 == (config['transformer_projection_dim'] % + (config['transformer_patchsize_y'] * + config['transformer_patchsize_x'])), ( + "transformer_projection_dim error: " + "The remainder when parameter transformer_projection_dim is divided by " + "(transformer_patchsize_y*transformer_patchsize_x) should be zero") + + config['num_patches'] = num_patches + elif task == "cnn-rnn-ocr": + builder = cnn_rnn_ocr_model + elif task=='classification': + builder = resnet50_classifier + elif task=='reading_order': + builder = machine_based_reading_order_model + else: + raise ValueError("unknown model task '%s'" % task) + + builder = create_captured_function(builder) + builder.config = config + builder.logger = logger + return builder() diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 00ed6ee..2cb42b6 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -17,7 +17,6 @@ from tensorflow.keras.layers import StringLookup from tensorflow.keras.utils import image_dataset_from_directory from tensorflow.keras.backend import one_hot from sacred import Experiment -from sacred.config import create_captured_function import numpy as np import cv2 @@ -32,16 +31,9 @@ from .metrics import ( connected_components_loss, ) from .models import ( - PatchEncoder, - Patches, - machine_based_reading_order_model, - resnet50_classifier, - resnet50_unet, - vit_resnet50_unet, - vit_resnet50_unet_transformer_before_cnn, - cnn_rnn_ocr_model, RESNET50_WEIGHTS_PATH, - RESNET50_WEIGHTS_URL + RESNET50_WEIGHTS_URL, + get_model ) from .utils import ( generate_arrays_from_folder_reading_order, @@ -477,58 +469,12 @@ def run(_config, if task == "enhancement": assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply" assert not weighted_loss, "for enhancement, weighted loss does not apply" + if continue_training: - custom_objects = dict() - if is_loss_soft_dice: - custom_objects.update(soft_dice_loss=soft_dice_loss) - elif weighted_loss: - custom_objects.update(loss=weighted_categorical_crossentropy(weights)) - if backbone_type == 'transformer': - custom_objects.update(PatchEncoder=PatchEncoder, - Patches=Patches) - model = load_model(dir_of_start_model, compile=False, - custom_objects=custom_objects) + model = load_model(dir_of_start_model, compile=False) else: index_start = 0 - if backbone_type == 'nontransformer': - model = resnet50_unet(n_classes, - input_height, - input_width, - task, - weight_decay, - pretraining) - else: - num_patches_x = transformer_num_patches_xy[0] - num_patches_y = transformer_num_patches_xy[1] - num_patches = num_patches_x * num_patches_y - - if transformer_cnn_first: - model_builder = vit_resnet50_unet - multiple = 32 - else: - model_builder = vit_resnet50_unet_transformer_before_cnn - multiple = 1 - - assert input_height == ( - num_patches_y * transformer_patchsize_y * multiple), ( - "transformer_patchsize_y or transformer_num_patches_xy height value error: " - "input_height should be equal to " - "(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple) - assert input_width == ( - num_patches_x * transformer_patchsize_x * multiple), ( - "transformer_patchsize_x or transformer_num_patches_xy width value error: " - "input_width should be equal to " - "(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple) - assert 0 == (transformer_projection_dim % - (transformer_patchsize_y * transformer_patchsize_x)), ( - "transformer_projection_dim error: " - "The remainder when parameter transformer_projection_dim is divided by " - "(transformer_patchsize_y*transformer_patchsize_x) should be zero") - - model_builder = create_captured_function(model_builder) - model_builder.config = _config - model_builder.logger = _log - model = model_builder(num_patches) + model = get_model(_config, _log) assert model is not None #if you want to see the model structure just uncomment model summary. @@ -709,10 +655,7 @@ def run(_config, model = load_model(dir_of_start_model) else: index_start = 0 - model = cnn_rnn_ocr_model(image_height=input_height, - image_width=input_width, - n_classes=n_classes, - max_seq=max_len) + model = get_model(_config, _log) #initial_learning_rate = 1e-4 #decay_steps = int (n_epochs * ( len_dataset / n_batch )) #alpha = 0.01 @@ -774,11 +717,7 @@ def run(_config, model = load_model(dir_of_start_model, compile=False) else: index_start = 0 - model = resnet50_classifier(n_classes, - input_height, - input_width, - weight_decay, - pretraining) + model = get_model(_config, _log) model.compile(loss='categorical_crossentropy', optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate? @@ -830,11 +769,7 @@ def run(_config, model = load_model(dir_of_start_model, compile=False) else: index_start = 0 - model = machine_based_reading_order_model(n_classes, - input_height, - input_width, - weight_decay, - pretraining) + model = get_model(_config, _log) #f1score_tot = [0] model.compile(loss="binary_crossentropy",