Mirror of https://github.com/qurator-spk/eynollah.git (synced 2026-02-20 16:32:03 +01:00)
training.train: simplify config args for model builder
commit 5c7801a1d6
parent 4a65ee0c67
2 changed files with 63 additions and 37 deletions
@@ -400,9 +400,21 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation",
     return model


-def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
-    if mlp_head_units is None:
-        mlp_head_units = [128, 64]
+def vit_resnet50_unet(num_patches,
+                      n_classes,
+                      transformer_patchsize_x,
+                      transformer_patchsize_y,
+                      transformer_mlp_head_units=None,
+                      transformer_layers=8,
+                      transformer_num_heads=4,
+                      transformer_projection_dim=64,
+                      input_height=224,
+                      input_width=224,
+                      task="segmentation",
+                      weight_decay=1e-6,
+                      pretraining=False):
+    if transformer_mlp_head_units is None:
+        transformer_mlp_head_units = [128, 64]
     inputs = layers.Input(shape=(input_height, input_width, 3))

     #transformer_units = [
@@ -449,30 +461,30 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None,

     #num_patches = x.shape[1]*x.shape[2]

-    #patch_size_y = input_height / x.shape[1]
-    #patch_size_x = input_width / x.shape[2]
-    #patch_size = patch_size_x * patch_size_y
-    patches = Patches(patch_size_x, patch_size_y)(x)
+    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(x)
     # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)

     for _ in range(transformer_layers):
         # Layer normalization 1.
         x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
         # Create a multi-head attention layer.
         attention_output = layers.MultiHeadAttention(
-            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
         )(x1, x1)
         # Skip connection 1.
         x2 = layers.Add()([attention_output, encoded_patches])
         # Layer normalization 2.
         x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
         # MLP.
-        x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
+        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
         # Skip connection 2.
         encoded_patches = layers.Add()([x3, x2])

-    encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
+    encoded_patches = tf.reshape(encoded_patches,
+                                 [-1, x.shape[1], x.shape[2],
+                                  transformer_projection_dim // (transformer_patchsize_x *
+                                                                 transformer_patchsize_y)])

     v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
     v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
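An aside on the reshape change above: switching from int(projection_dim / ...) to floor division is exact because the projection dimension must be divisible by the patch area, which train.py asserts later in this diff. A minimal sketch of the shape arithmetic, using made-up example values that are not part of the commit:

# Illustrative shape check (not repo code): why the reshape divides the projection
# dimension by the patch area. Example values assumed: a 14x14 CNN feature map,
# 2x2 patches, transformer_projection_dim = 64.
h, w = 14, 14                            # x.shape[1], x.shape[2] after the ResNet50
ps_x, ps_y = 2, 2                        # transformer_patchsize_x / _y
proj = 64                                # transformer_projection_dim

num_patches = (h // ps_y) * (w // ps_x)  # 7 * 7 = 49 non-overlapping patches
per_pixel = proj // (ps_x * ps_y)        # 64 // 4 = 16 channels per spatial position

# encoded_patches: (batch, num_patches, proj) is reshaped to (batch, h, w, per_pixel);
# the element counts match because num_patches * proj == h * w * per_pixel.
assert num_patches * proj == h * w * per_pixel   # 49*64 == 14*14*16 == 3136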
@@ -524,9 +536,21 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None,

     return model

-def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
-    if mlp_head_units is None:
-        mlp_head_units = [128, 64]
+def vit_resnet50_unet_transformer_before_cnn(num_patches,
+                                             n_classes,
+                                             transformer_patchsize_x,
+                                             transformer_patchsize_y,
+                                             transformer_mlp_head_units=None,
+                                             transformer_layers=8,
+                                             transformer_num_heads=4,
+                                             transformer_projection_dim=64,
+                                             input_height=224,
+                                             input_width=224,
+                                             task="segmentation",
+                                             weight_decay=1e-6,
+                                             pretraining=False):
+    if transformer_mlp_head_units is None:
+        transformer_mlp_head_units = [128, 64]
     inputs = layers.Input(shape=(input_height, input_width, 3))

     ##transformer_units = [
@@ -536,27 +560,32 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y,
     IMAGE_ORDERING = 'channels_last'
     bn_axis=3

-    patches = Patches(patch_size_x, patch_size_y)(inputs)
+    patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
     # Encode patches.
-    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+    encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)

     for _ in range(transformer_layers):
         # Layer normalization 1.
         x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
         # Create a multi-head attention layer.
         attention_output = layers.MultiHeadAttention(
-            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+            num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
         )(x1, x1)
         # Skip connection 1.
         x2 = layers.Add()([attention_output, encoded_patches])
         # Layer normalization 2.
         x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
         # MLP.
-        x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
+        x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
         # Skip connection 2.
         encoded_patches = layers.Add()([x3, x2])

-    encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
+    encoded_patches = tf.reshape(encoded_patches,
+                                 [-1,
+                                  input_height,
+                                  input_width,
+                                  transformer_projection_dim // (transformer_patchsize_x *
+                                                                 transformer_patchsize_y)])

     encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)

@@ -38,6 +38,7 @@ from tensorflow.keras.metrics import MeanIoU
 from tensorflow.keras.models import load_model
 from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
 from sacred import Experiment
+from sacred.config import create_captured_function
 from tqdm import tqdm
 from sklearn.metrics import f1_score

@@ -318,7 +319,7 @@ def run(_config,
                              task,
                              weight_decay,
                              pretraining)
-    elif backbone_type == 'transformer':
+    else:
        num_patches_x = transformer_num_patches_xy[0]
        num_patches_y = transformer_num_patches_xy[1]
        num_patches = num_patches_x * num_patches_y
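Note on the values computed above: the asserts in the next hunk require the input geometry to factor as patch grid times patch size, scaled by 32 when the ResNet50 runs before the transformer (the multiple_of_32 case). A small worked example with illustrative numbers, not taken from the commit:

# Illustrative consistency check (example values, not repo code) for the
# relationship the asserts below enforce between input size, patch grid and patch size.
transformer_num_patches_xy = (7, 7)    # (num_patches_x, num_patches_y)
transformer_patchsize_x = 2
transformer_patchsize_y = 2
multiple_of_32 = True                  # CNN-first variant: feature map is input/32

scale = 32 if multiple_of_32 else 1
num_patches = transformer_num_patches_xy[0] * transformer_num_patches_xy[1]       # 49
input_height = transformer_num_patches_xy[1] * transformer_patchsize_y * scale    # 7*2*32 = 448
input_width = transformer_num_patches_xy[0] * transformer_patchsize_x * scale     # 7*2*32 = 448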
@@ -330,35 +331,31 @@ def run(_config,
            model_builder = vit_resnet50_unet_transformer_before_cnn
            multiple_of_32 = False

-        assert input_height == num_patches_y * transformer_patchsize_y * (32 if multiple_of_32 else 1), \
+        assert input_height == (num_patches_y *
+                                transformer_patchsize_y *
+                                (32 if multiple_of_32 else 1)), \
            "transformer_patchsize_y or transformer_num_patches_xy height value error: " \
            "input_height should be equal to " \
            "(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \
            " * 32" if multiple_of_32 else ""
-        assert input_width == num_patches_x * transformer_patchsize_x * (32 if multiple_of_32 else 1), \
+        assert input_width == (num_patches_x *
+                               transformer_patchsize_x *
+                               (32 if multiple_of_32 else 1)), \
            "transformer_patchsize_x or transformer_num_patches_xy width value error: " \
            "input_width should be equal to " \
            "(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \
            " * 32" if multiple_of_32 else ""
-        assert 0 == transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x), \
+        assert 0 == (transformer_projection_dim %
+                     (transformer_patchsize_y *
+                      transformer_patchsize_x)), \
            "transformer_projection_dim error: " \
            "The remainder when parameter transformer_projection_dim is divided by " \
            "(transformer_patchsize_y*transformer_patchsize_x) should be zero"

-        model = model_builder(
-            n_classes,
-            transformer_patchsize_x,
-            transformer_patchsize_y,
-            num_patches,
-            transformer_mlp_head_units,
-            transformer_layers,
-            transformer_num_heads,
-            transformer_projection_dim,
-            input_height,
-            input_width,
-            task,
-            weight_decay,
-            pretraining)
+        model_builder = create_captured_function(model_builder)
+        model_builder.config = _config
+        model_builder.logger = _log
+        model = model_builder(num_patches)

    #if you want to see the model structure just uncomment model summary.
    #model.summary()
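The replacement above is the core of the commit: instead of threading a dozen config values through a positional call, the chosen builder is wrapped with sacred's create_captured_function, so its keyword parameters (renamed in models.py to match the config keys) are filled from _config by name; only num_patches, which is derived at runtime, is still passed explicitly. A minimal sketch of that pattern with a toy builder (illustrative, not repo code):

# Sketch of sacred's captured-function pattern, assuming a toy builder and config.
from sacred.config import create_captured_function

def build(num_patches, n_classes=2, transformer_projection_dim=64):
    # Any parameter not passed at the call site is looked up in build.config by name.
    return f"{num_patches} patches, {n_classes} classes, dim {transformer_projection_dim}"

build = create_captured_function(build)
build.config = {"n_classes": 4, "transformer_projection_dim": 128}  # stands in for _config
build.logger = None                                                  # stands in for _log

# Only the runtime-derived value is passed positionally; the rest comes from config.
print(build(49))   # -> "49 patches, 4 classes, dim 128"

The upshot is that the call site stays stable when builder parameters are added or renamed, as long as the parameter names continue to match the config keys.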