mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-31 01:59:27 +02:00
train/models: move all model builders to models.get_model()…
- models: add new `get_model()`, passing in the Sacred config to capture builder function arguments
- train: fewer imports
- train: no need to pass `custom_objects` if loading with `compile=False` (and we custom-compile later, anyway)
This commit is contained in:
parent
faef1967f8
commit
093030f503
2 changed files with 57 additions and 73 deletions
|
|
@ -499,3 +499,52 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
|||
|
||||
return Model((inputs, labels), out)
|
||||
|
||||
def get_model(config, logger):
    """Select the model builder for ``config['task']`` and call it.

    The chosen builder is wrapped with Sacred's ``create_captured_function``
    so that its arguments are taken from the configuration instead of being
    passed explicitly.

    Args:
        config: mapping with the training configuration. ``'task'`` is always
            read. For the segmentation / enhancement / binarization tasks
            ``'backbone_type'`` is read as well, and for any backbone other
            than ``'nontransformer'`` also ``'transformer_num_patches_xy'``,
            ``'transformer_cnn_first'``, ``'transformer_patchsize_x'``,
            ``'transformer_patchsize_y'``, ``'transformer_projection_dim'``,
            ``'input_height'`` and ``'input_width'``.
            The mapping itself is never modified.
        logger: logger object attached to the captured builder.

    Returns:
        Whatever the selected builder returns (presumably a Keras model --
        the builders themselves are defined elsewhere in this module).

    Raises:
        ValueError: if ``config['task']`` is not one of the known tasks.
        AssertionError: if the transformer patch geometry is inconsistent
            with the input size or the projection dimension.
    """
    # Imported lazily so that Sacred is only required when a model is built.
    from sacred.config import create_captured_function

    task = config['task']
    if task in ["segmentation", "enhancement", "binarization"]:
        if config['backbone_type'] == 'nontransformer':
            builder = resnet50_unet
        else:
            num_patches_x, num_patches_y = config['transformer_num_patches_xy']
            num_patches = num_patches_x * num_patches_y

            if config['transformer_cnn_first']:
                builder = vit_resnet50_unet
                multiple = 32
            else:
                builder = vit_resnet50_unet_transformer_before_cnn
                multiple = 1

            # NOTE(review): these checks disappear under ``python -O``; they
            # are kept as assertions so the raised exception type is unchanged.
            assert config['input_height'] == (
                num_patches_y * config['transformer_patchsize_y'] * multiple), (
                "transformer_patchsize_y or transformer_num_patches_xy height value error: "
                "input_height should be equal to "
                "(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
            assert config['input_width'] == (
                num_patches_x * config['transformer_patchsize_x'] * multiple), (
                "transformer_patchsize_x or transformer_num_patches_xy width value error: "
                "input_width should be equal to "
                "(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
            assert 0 == (config['transformer_projection_dim'] %
                         (config['transformer_patchsize_y'] *
                          config['transformer_patchsize_x'])), (
                "transformer_projection_dim error: "
                "The remainder when parameter transformer_projection_dim is divided by "
                "(transformer_patchsize_y*transformer_patchsize_x) should be zero")

            # Hand the derived value to the builder through a private shallow
            # copy: the caller's config may be a read-only Sacred mapping, and
            # even when it is writable it should not silently gain a key.
            config = dict(config)
            config['num_patches'] = num_patches
    elif task == "cnn-rnn-ocr":
        builder = cnn_rnn_ocr_model
    elif task == 'classification':
        builder = resnet50_classifier
    elif task == 'reading_order':
        builder = machine_based_reading_order_model
    else:
        raise ValueError("unknown model task '%s'" % task)

    # Let Sacred fill the builder's parameters from the configuration.
    builder = create_captured_function(builder)
    builder.config = config
    builder.logger = logger
    return builder()
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from tensorflow.keras.layers import StringLookup
|
|||
from tensorflow.keras.utils import image_dataset_from_directory
|
||||
from tensorflow.keras.backend import one_hot
|
||||
from sacred import Experiment
|
||||
from sacred.config import create_captured_function
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
|
@ -32,16 +31,9 @@ from .metrics import (
|
|||
connected_components_loss,
|
||||
)
|
||||
from .models import (
|
||||
PatchEncoder,
|
||||
Patches,
|
||||
machine_based_reading_order_model,
|
||||
resnet50_classifier,
|
||||
resnet50_unet,
|
||||
vit_resnet50_unet,
|
||||
vit_resnet50_unet_transformer_before_cnn,
|
||||
cnn_rnn_ocr_model,
|
||||
RESNET50_WEIGHTS_PATH,
|
||||
RESNET50_WEIGHTS_URL
|
||||
RESNET50_WEIGHTS_URL,
|
||||
get_model
|
||||
)
|
||||
from .utils import (
|
||||
generate_arrays_from_folder_reading_order,
|
||||
|
|
@ -477,58 +469,12 @@ def run(_config,
|
|||
if task == "enhancement":
|
||||
assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply"
|
||||
assert not weighted_loss, "for enhancement, weighted loss does not apply"
|
||||
|
||||
if continue_training:
|
||||
custom_objects = dict()
|
||||
if is_loss_soft_dice:
|
||||
custom_objects.update(soft_dice_loss=soft_dice_loss)
|
||||
elif weighted_loss:
|
||||
custom_objects.update(loss=weighted_categorical_crossentropy(weights))
|
||||
if backbone_type == 'transformer':
|
||||
custom_objects.update(PatchEncoder=PatchEncoder,
|
||||
Patches=Patches)
|
||||
model = load_model(dir_of_start_model, compile=False,
|
||||
custom_objects=custom_objects)
|
||||
model = load_model(dir_of_start_model, compile=False)
|
||||
else:
|
||||
index_start = 0
|
||||
if backbone_type == 'nontransformer':
|
||||
model = resnet50_unet(n_classes,
|
||||
input_height,
|
||||
input_width,
|
||||
task,
|
||||
weight_decay,
|
||||
pretraining)
|
||||
else:
|
||||
num_patches_x = transformer_num_patches_xy[0]
|
||||
num_patches_y = transformer_num_patches_xy[1]
|
||||
num_patches = num_patches_x * num_patches_y
|
||||
|
||||
if transformer_cnn_first:
|
||||
model_builder = vit_resnet50_unet
|
||||
multiple = 32
|
||||
else:
|
||||
model_builder = vit_resnet50_unet_transformer_before_cnn
|
||||
multiple = 1
|
||||
|
||||
assert input_height == (
|
||||
num_patches_y * transformer_patchsize_y * multiple), (
|
||||
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
|
||||
"input_height should be equal to "
|
||||
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
|
||||
assert input_width == (
|
||||
num_patches_x * transformer_patchsize_x * multiple), (
|
||||
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
|
||||
"input_width should be equal to "
|
||||
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
|
||||
assert 0 == (transformer_projection_dim %
|
||||
(transformer_patchsize_y * transformer_patchsize_x)), (
|
||||
"transformer_projection_dim error: "
|
||||
"The remainder when parameter transformer_projection_dim is divided by "
|
||||
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
||||
|
||||
model_builder = create_captured_function(model_builder)
|
||||
model_builder.config = _config
|
||||
model_builder.logger = _log
|
||||
model = model_builder(num_patches)
|
||||
model = get_model(_config, _log)
|
||||
|
||||
assert model is not None
|
||||
#if you want to see the model structure just uncomment model summary.
|
||||
|
|
@ -709,10 +655,7 @@ def run(_config,
|
|||
model = load_model(dir_of_start_model)
|
||||
else:
|
||||
index_start = 0
|
||||
model = cnn_rnn_ocr_model(image_height=input_height,
|
||||
image_width=input_width,
|
||||
n_classes=n_classes,
|
||||
max_seq=max_len)
|
||||
model = get_model(_config, _log)
|
||||
#initial_learning_rate = 1e-4
|
||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||
#alpha = 0.01
|
||||
|
|
@ -774,11 +717,7 @@ def run(_config,
|
|||
model = load_model(dir_of_start_model, compile=False)
|
||||
else:
|
||||
index_start = 0
|
||||
model = resnet50_classifier(n_classes,
|
||||
input_height,
|
||||
input_width,
|
||||
weight_decay,
|
||||
pretraining)
|
||||
model = get_model(_config, _log)
|
||||
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
|
||||
|
|
@ -830,11 +769,7 @@ def run(_config,
|
|||
model = load_model(dir_of_start_model, compile=False)
|
||||
else:
|
||||
index_start = 0
|
||||
model = machine_based_reading_order_model(n_classes,
|
||||
input_height,
|
||||
input_width,
|
||||
weight_decay,
|
||||
pretraining)
|
||||
model = get_model(_config, _log)
|
||||
|
||||
#f1score_tot = [0]
|
||||
model.compile(loss="binary_crossentropy",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue