train/models: move all model builders to models.get_model()

- models: add new `get_model()`, passing in Sacred config
  to capture builder function arguments
- train: fewer imports
- train: no need to pass `custom_objects` if loading with
  `compile=False` (and we custom-compile later, anyway)
This commit is contained in:
Robert Sachunsky 2026-05-28 17:37:45 +02:00
parent faef1967f8
commit 093030f503
2 changed files with 57 additions and 73 deletions

View file

@ -499,3 +499,52 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
return Model((inputs, labels), out) return Model((inputs, labels), out)
def get_model(config, logger):
    """Select the model builder for ``config['task']`` and call it.

    The chosen builder is wrapped with Sacred's ``create_captured_function``
    so that its arguments are filled in from ``config`` by name instead of
    being passed explicitly here.

    Args:
        config: Mapping of configuration values (typically Sacred's
            ``_config``). ``task`` is always read. The segmentation,
            enhancement and binarization tasks also read ``backbone_type``.
            A backbone other than ``'nontransformer'`` additionally reads
            ``transformer_num_patches_xy``, ``transformer_cnn_first``,
            ``input_height``, ``input_width``, ``transformer_patchsize_x``,
            ``transformer_patchsize_y`` and ``transformer_projection_dim``.
            The mapping is never modified.
        logger: Logger attached to the captured builder (available to it
            as ``_log``).

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``config['task']`` is not a known task.
        AssertionError: If the transformer patch geometry does not match
            the input size, or the projection dimension is not divisible
            by the patch area.
    """
    task = config['task']
    if task in ("segmentation", "enhancement", "binarization"):
        if config['backbone_type'] == 'nontransformer':
            builder = resnet50_unet
        else:
            num_patches_x, num_patches_y = config['transformer_num_patches_xy']
            num_patches = num_patches_x * num_patches_y
            if config['transformer_cnn_first']:
                builder = vit_resnet50_unet
                # patch grid must match the input size divided by 32
                # NOTE(review): presumably the CNN encoder's downsampling
                # factor -- confirm against vit_resnet50_unet
                multiple = 32
            else:
                builder = vit_resnet50_unet_transformer_before_cnn
                # patches are cut from the input image at full resolution
                multiple = 1
            assert config['input_height'] == (
                num_patches_y * config['transformer_patchsize_y'] * multiple), (
                    "transformer_patchsize_y or transformer_num_patches_xy height value error: "
                    "input_height should be equal to "
                    "(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
            assert config['input_width'] == (
                num_patches_x * config['transformer_patchsize_x'] * multiple), (
                    "transformer_patchsize_x or transformer_num_patches_xy width value error: "
                    "input_width should be equal to "
                    "(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
            assert 0 == (config['transformer_projection_dim'] %
                         (config['transformer_patchsize_y'] *
                          config['transformer_patchsize_x'])), (
                              "transformer_projection_dim error: "
                              "The remainder when parameter transformer_projection_dim is divided by "
                              "(transformer_patchsize_y*transformer_patchsize_x) should be zero")
            # The builder picks up `num_patches` from the config. Add it to
            # a shallow copy rather than to the caller's mapping: Sacred
            # passes a read-only config by default (item assignment raises),
            # and a plain dict should not be modified behind the caller's back.
            config = dict(config, num_patches=num_patches)
    elif task == "cnn-rnn-ocr":
        builder = cnn_rnn_ocr_model
    elif task == 'classification':
        builder = resnet50_classifier
    elif task == 'reading_order':
        builder = machine_based_reading_order_model
    else:
        raise ValueError("unknown model task '%s'" % task)

    # Imported locally, and only once the task is known to be valid.
    from sacred.config import create_captured_function
    builder = create_captured_function(builder)
    builder.config = config
    builder.logger = logger
    return builder()

View file

@ -17,7 +17,6 @@ from tensorflow.keras.layers import StringLookup
from tensorflow.keras.utils import image_dataset_from_directory from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.backend import one_hot from tensorflow.keras.backend import one_hot
from sacred import Experiment from sacred import Experiment
from sacred.config import create_captured_function
import numpy as np import numpy as np
import cv2 import cv2
@ -32,16 +31,9 @@ from .metrics import (
connected_components_loss, connected_components_loss,
) )
from .models import ( from .models import (
PatchEncoder,
Patches,
machine_based_reading_order_model,
resnet50_classifier,
resnet50_unet,
vit_resnet50_unet,
vit_resnet50_unet_transformer_before_cnn,
cnn_rnn_ocr_model,
RESNET50_WEIGHTS_PATH, RESNET50_WEIGHTS_PATH,
RESNET50_WEIGHTS_URL RESNET50_WEIGHTS_URL,
get_model
) )
from .utils import ( from .utils import (
generate_arrays_from_folder_reading_order, generate_arrays_from_folder_reading_order,
@ -477,58 +469,12 @@ def run(_config,
if task == "enhancement": if task == "enhancement":
assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply" assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply"
assert not weighted_loss, "for enhancement, weighted loss does not apply" assert not weighted_loss, "for enhancement, weighted loss does not apply"
if continue_training: if continue_training:
custom_objects = dict() model = load_model(dir_of_start_model, compile=False)
if is_loss_soft_dice:
custom_objects.update(soft_dice_loss=soft_dice_loss)
elif weighted_loss:
custom_objects.update(loss=weighted_categorical_crossentropy(weights))
if backbone_type == 'transformer':
custom_objects.update(PatchEncoder=PatchEncoder,
Patches=Patches)
model = load_model(dir_of_start_model, compile=False,
custom_objects=custom_objects)
else: else:
index_start = 0 index_start = 0
if backbone_type == 'nontransformer': model = get_model(_config, _log)
model = resnet50_unet(n_classes,
input_height,
input_width,
task,
weight_decay,
pretraining)
else:
num_patches_x = transformer_num_patches_xy[0]
num_patches_y = transformer_num_patches_xy[1]
num_patches = num_patches_x * num_patches_y
if transformer_cnn_first:
model_builder = vit_resnet50_unet
multiple = 32
else:
model_builder = vit_resnet50_unet_transformer_before_cnn
multiple = 1
assert input_height == (
num_patches_y * transformer_patchsize_y * multiple), (
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
"input_height should be equal to "
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
assert input_width == (
num_patches_x * transformer_patchsize_x * multiple), (
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
"input_width should be equal to "
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
assert 0 == (transformer_projection_dim %
(transformer_patchsize_y * transformer_patchsize_x)), (
"transformer_projection_dim error: "
"The remainder when parameter transformer_projection_dim is divided by "
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
model_builder = create_captured_function(model_builder)
model_builder.config = _config
model_builder.logger = _log
model = model_builder(num_patches)
assert model is not None assert model is not None
#if you want to see the model structure just uncomment model summary. #if you want to see the model structure just uncomment model summary.
@ -709,10 +655,7 @@ def run(_config,
model = load_model(dir_of_start_model) model = load_model(dir_of_start_model)
else: else:
index_start = 0 index_start = 0
model = cnn_rnn_ocr_model(image_height=input_height, model = get_model(_config, _log)
image_width=input_width,
n_classes=n_classes,
max_seq=max_len)
#initial_learning_rate = 1e-4 #initial_learning_rate = 1e-4
#decay_steps = int (n_epochs * ( len_dataset / n_batch )) #decay_steps = int (n_epochs * ( len_dataset / n_batch ))
#alpha = 0.01 #alpha = 0.01
@ -774,11 +717,7 @@ def run(_config,
model = load_model(dir_of_start_model, compile=False) model = load_model(dir_of_start_model, compile=False)
else: else:
index_start = 0 index_start = 0
model = resnet50_classifier(n_classes, model = get_model(_config, _log)
input_height,
input_width,
weight_decay,
pretraining)
model.compile(loss='categorical_crossentropy', model.compile(loss='categorical_crossentropy',
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate? optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
@ -830,11 +769,7 @@ def run(_config,
model = load_model(dir_of_start_model, compile=False) model = load_model(dir_of_start_model, compile=False)
else: else:
index_start = 0 index_start = 0
model = machine_based_reading_order_model(n_classes, model = get_model(_config, _log)
input_height,
input_width,
weight_decay,
pretraining)
#f1score_tot = [0] #f1score_tot = [0]
model.compile(loss="binary_crossentropy", model.compile(loss="binary_crossentropy",