mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-31 01:59:27 +02:00
train/models: move all model builders to models.get_model()…
- models: add new `get_model()`, passing in Sacred config to capture builder function arguments - train: fewer imports - train: no need to pass `custom_objects` if loading with `compile=False` (and we custom-compile later, anyway)
This commit is contained in:
parent
faef1967f8
commit
093030f503
2 changed files with 57 additions and 73 deletions
|
|
@ -499,3 +499,52 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
||||||
|
|
||||||
return Model((inputs, labels), out)
|
return Model((inputs, labels), out)
|
||||||
|
|
||||||
|
def get_model(config, logger):
|
||||||
|
from sacred.config import create_captured_function
|
||||||
|
|
||||||
|
task = config['task']
|
||||||
|
if task in ["segmentation", "enhancement", "binarization"]:
|
||||||
|
if config['backbone_type'] == 'nontransformer':
|
||||||
|
builder = resnet50_unet
|
||||||
|
else:
|
||||||
|
num_patches_x, num_patches_y = config['transformer_num_patches_xy']
|
||||||
|
num_patches = num_patches_x * num_patches_y
|
||||||
|
|
||||||
|
if config['transformer_cnn_first']:
|
||||||
|
builder = vit_resnet50_unet
|
||||||
|
multiple = 32
|
||||||
|
else:
|
||||||
|
builder = vit_resnet50_unet_transformer_before_cnn
|
||||||
|
multiple = 1
|
||||||
|
|
||||||
|
assert config['input_height'] == (
|
||||||
|
num_patches_y * config['transformer_patchsize_y'] * multiple), (
|
||||||
|
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
|
||||||
|
"input_height should be equal to "
|
||||||
|
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
|
||||||
|
assert config['input_width'] == (
|
||||||
|
num_patches_x * config['transformer_patchsize_x'] * multiple), (
|
||||||
|
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
|
||||||
|
"input_width should be equal to "
|
||||||
|
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
|
||||||
|
assert 0 == (config['transformer_projection_dim'] %
|
||||||
|
(config['transformer_patchsize_y'] *
|
||||||
|
config['transformer_patchsize_x'])), (
|
||||||
|
"transformer_projection_dim error: "
|
||||||
|
"The remainder when parameter transformer_projection_dim is divided by "
|
||||||
|
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
||||||
|
|
||||||
|
config['num_patches'] = num_patches
|
||||||
|
elif task == "cnn-rnn-ocr":
|
||||||
|
builder = cnn_rnn_ocr_model
|
||||||
|
elif task=='classification':
|
||||||
|
builder = resnet50_classifier
|
||||||
|
elif task=='reading_order':
|
||||||
|
builder = machine_based_reading_order_model
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown model task '%s'" % task)
|
||||||
|
|
||||||
|
builder = create_captured_function(builder)
|
||||||
|
builder.config = config
|
||||||
|
builder.logger = logger
|
||||||
|
return builder()
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ from tensorflow.keras.layers import StringLookup
|
||||||
from tensorflow.keras.utils import image_dataset_from_directory
|
from tensorflow.keras.utils import image_dataset_from_directory
|
||||||
from tensorflow.keras.backend import one_hot
|
from tensorflow.keras.backend import one_hot
|
||||||
from sacred import Experiment
|
from sacred import Experiment
|
||||||
from sacred.config import create_captured_function
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
|
@ -32,16 +31,9 @@ from .metrics import (
|
||||||
connected_components_loss,
|
connected_components_loss,
|
||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
PatchEncoder,
|
|
||||||
Patches,
|
|
||||||
machine_based_reading_order_model,
|
|
||||||
resnet50_classifier,
|
|
||||||
resnet50_unet,
|
|
||||||
vit_resnet50_unet,
|
|
||||||
vit_resnet50_unet_transformer_before_cnn,
|
|
||||||
cnn_rnn_ocr_model,
|
|
||||||
RESNET50_WEIGHTS_PATH,
|
RESNET50_WEIGHTS_PATH,
|
||||||
RESNET50_WEIGHTS_URL
|
RESNET50_WEIGHTS_URL,
|
||||||
|
get_model
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
generate_arrays_from_folder_reading_order,
|
generate_arrays_from_folder_reading_order,
|
||||||
|
|
@ -477,58 +469,12 @@ def run(_config,
|
||||||
if task == "enhancement":
|
if task == "enhancement":
|
||||||
assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply"
|
assert not is_loss_soft_dice, "for enhancement, soft_dice loss does not apply"
|
||||||
assert not weighted_loss, "for enhancement, weighted loss does not apply"
|
assert not weighted_loss, "for enhancement, weighted loss does not apply"
|
||||||
|
|
||||||
if continue_training:
|
if continue_training:
|
||||||
custom_objects = dict()
|
model = load_model(dir_of_start_model, compile=False)
|
||||||
if is_loss_soft_dice:
|
|
||||||
custom_objects.update(soft_dice_loss=soft_dice_loss)
|
|
||||||
elif weighted_loss:
|
|
||||||
custom_objects.update(loss=weighted_categorical_crossentropy(weights))
|
|
||||||
if backbone_type == 'transformer':
|
|
||||||
custom_objects.update(PatchEncoder=PatchEncoder,
|
|
||||||
Patches=Patches)
|
|
||||||
model = load_model(dir_of_start_model, compile=False,
|
|
||||||
custom_objects=custom_objects)
|
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
if backbone_type == 'nontransformer':
|
model = get_model(_config, _log)
|
||||||
model = resnet50_unet(n_classes,
|
|
||||||
input_height,
|
|
||||||
input_width,
|
|
||||||
task,
|
|
||||||
weight_decay,
|
|
||||||
pretraining)
|
|
||||||
else:
|
|
||||||
num_patches_x = transformer_num_patches_xy[0]
|
|
||||||
num_patches_y = transformer_num_patches_xy[1]
|
|
||||||
num_patches = num_patches_x * num_patches_y
|
|
||||||
|
|
||||||
if transformer_cnn_first:
|
|
||||||
model_builder = vit_resnet50_unet
|
|
||||||
multiple = 32
|
|
||||||
else:
|
|
||||||
model_builder = vit_resnet50_unet_transformer_before_cnn
|
|
||||||
multiple = 1
|
|
||||||
|
|
||||||
assert input_height == (
|
|
||||||
num_patches_y * transformer_patchsize_y * multiple), (
|
|
||||||
"transformer_patchsize_y or transformer_num_patches_xy height value error: "
|
|
||||||
"input_height should be equal to "
|
|
||||||
"(transformer_num_patches_xy height value * transformer_patchsize_y * %d)" % multiple)
|
|
||||||
assert input_width == (
|
|
||||||
num_patches_x * transformer_patchsize_x * multiple), (
|
|
||||||
"transformer_patchsize_x or transformer_num_patches_xy width value error: "
|
|
||||||
"input_width should be equal to "
|
|
||||||
"(transformer_num_patches_xy width value * transformer_patchsize_x * %d)" % multiple)
|
|
||||||
assert 0 == (transformer_projection_dim %
|
|
||||||
(transformer_patchsize_y * transformer_patchsize_x)), (
|
|
||||||
"transformer_projection_dim error: "
|
|
||||||
"The remainder when parameter transformer_projection_dim is divided by "
|
|
||||||
"(transformer_patchsize_y*transformer_patchsize_x) should be zero")
|
|
||||||
|
|
||||||
model_builder = create_captured_function(model_builder)
|
|
||||||
model_builder.config = _config
|
|
||||||
model_builder.logger = _log
|
|
||||||
model = model_builder(num_patches)
|
|
||||||
|
|
||||||
assert model is not None
|
assert model is not None
|
||||||
#if you want to see the model structure just uncomment model summary.
|
#if you want to see the model structure just uncomment model summary.
|
||||||
|
|
@ -709,10 +655,7 @@ def run(_config,
|
||||||
model = load_model(dir_of_start_model)
|
model = load_model(dir_of_start_model)
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
model = cnn_rnn_ocr_model(image_height=input_height,
|
model = get_model(_config, _log)
|
||||||
image_width=input_width,
|
|
||||||
n_classes=n_classes,
|
|
||||||
max_seq=max_len)
|
|
||||||
#initial_learning_rate = 1e-4
|
#initial_learning_rate = 1e-4
|
||||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||||
#alpha = 0.01
|
#alpha = 0.01
|
||||||
|
|
@ -774,11 +717,7 @@ def run(_config,
|
||||||
model = load_model(dir_of_start_model, compile=False)
|
model = load_model(dir_of_start_model, compile=False)
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
model = resnet50_classifier(n_classes,
|
model = get_model(_config, _log)
|
||||||
input_height,
|
|
||||||
input_width,
|
|
||||||
weight_decay,
|
|
||||||
pretraining)
|
|
||||||
|
|
||||||
model.compile(loss='categorical_crossentropy',
|
model.compile(loss='categorical_crossentropy',
|
||||||
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
|
optimizer=Adam(learning_rate=0.001), # rs: why not learning_rate?
|
||||||
|
|
@ -830,11 +769,7 @@ def run(_config,
|
||||||
model = load_model(dir_of_start_model, compile=False)
|
model = load_model(dir_of_start_model, compile=False)
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
model = machine_based_reading_order_model(n_classes,
|
model = get_model(_config, _log)
|
||||||
input_height,
|
|
||||||
input_width,
|
|
||||||
weight_decay,
|
|
||||||
pretraining)
|
|
||||||
|
|
||||||
#f1score_tot = [0]
|
#f1score_tot = [0]
|
||||||
model.compile(loss="binary_crossentropy",
|
model.compile(loss="binary_crossentropy",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue