mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training.train: simplify config args for model builder
This commit is contained in:
parent
4a65ee0c67
commit
5c7801a1d6
2 changed files with 63 additions and 37 deletions
|
|
@ -400,9 +400,21 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
def vit_resnet50_unet(num_patches,
|
||||||
if mlp_head_units is None:
|
n_classes,
|
||||||
mlp_head_units = [128, 64]
|
transformer_patchsize_x,
|
||||||
|
transformer_patchsize_y,
|
||||||
|
transformer_mlp_head_units=None,
|
||||||
|
transformer_layers=8,
|
||||||
|
transformer_num_heads=4,
|
||||||
|
transformer_projection_dim=64,
|
||||||
|
input_height=224,
|
||||||
|
input_width=224,
|
||||||
|
task="segmentation",
|
||||||
|
weight_decay=1e-6,
|
||||||
|
pretraining=False):
|
||||||
|
if transformer_mlp_head_units is None:
|
||||||
|
transformer_mlp_head_units = [128, 64]
|
||||||
inputs = layers.Input(shape=(input_height, input_width, 3))
|
inputs = layers.Input(shape=(input_height, input_width, 3))
|
||||||
|
|
||||||
#transformer_units = [
|
#transformer_units = [
|
||||||
|
|
@ -449,30 +461,30 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
|
||||||
|
|
||||||
#num_patches = x.shape[1]*x.shape[2]
|
#num_patches = x.shape[1]*x.shape[2]
|
||||||
|
|
||||||
#patch_size_y = input_height / x.shape[1]
|
patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(x)
|
||||||
#patch_size_x = input_width / x.shape[2]
|
|
||||||
#patch_size = patch_size_x * patch_size_y
|
|
||||||
patches = Patches(patch_size_x, patch_size_y)(x)
|
|
||||||
# Encode patches.
|
# Encode patches.
|
||||||
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
|
||||||
|
|
||||||
for _ in range(transformer_layers):
|
for _ in range(transformer_layers):
|
||||||
# Layer normalization 1.
|
# Layer normalization 1.
|
||||||
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
|
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
|
||||||
# Create a multi-head attention layer.
|
# Create a multi-head attention layer.
|
||||||
attention_output = layers.MultiHeadAttention(
|
attention_output = layers.MultiHeadAttention(
|
||||||
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
|
num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
|
||||||
)(x1, x1)
|
)(x1, x1)
|
||||||
# Skip connection 1.
|
# Skip connection 1.
|
||||||
x2 = layers.Add()([attention_output, encoded_patches])
|
x2 = layers.Add()([attention_output, encoded_patches])
|
||||||
# Layer normalization 2.
|
# Layer normalization 2.
|
||||||
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
|
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
|
||||||
# MLP.
|
# MLP.
|
||||||
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
|
x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
|
||||||
# Skip connection 2.
|
# Skip connection 2.
|
||||||
encoded_patches = layers.Add()([x3, x2])
|
encoded_patches = layers.Add()([x3, x2])
|
||||||
|
|
||||||
encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
|
encoded_patches = tf.reshape(encoded_patches,
|
||||||
|
[-1, x.shape[1], x.shape[2],
|
||||||
|
transformer_projection_dim // (transformer_patchsize_x *
|
||||||
|
transformer_patchsize_y)])
|
||||||
|
|
||||||
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
|
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
|
||||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||||
|
|
@ -524,9 +536,21 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
|
def vit_resnet50_unet_transformer_before_cnn(num_patches,
|
||||||
if mlp_head_units is None:
|
n_classes,
|
||||||
mlp_head_units = [128, 64]
|
transformer_patchsize_x,
|
||||||
|
transformer_patchsize_y,
|
||||||
|
transformer_mlp_head_units=None,
|
||||||
|
transformer_layers=8,
|
||||||
|
transformer_num_heads=4,
|
||||||
|
transformer_projection_dim=64,
|
||||||
|
input_height=224,
|
||||||
|
input_width=224,
|
||||||
|
task="segmentation",
|
||||||
|
weight_decay=1e-6,
|
||||||
|
pretraining=False):
|
||||||
|
if transformer_mlp_head_units is None:
|
||||||
|
transformer_mlp_head_units = [128, 64]
|
||||||
inputs = layers.Input(shape=(input_height, input_width, 3))
|
inputs = layers.Input(shape=(input_height, input_width, 3))
|
||||||
|
|
||||||
##transformer_units = [
|
##transformer_units = [
|
||||||
|
|
@ -536,27 +560,32 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size
|
||||||
IMAGE_ORDERING = 'channels_last'
|
IMAGE_ORDERING = 'channels_last'
|
||||||
bn_axis=3
|
bn_axis=3
|
||||||
|
|
||||||
patches = Patches(patch_size_x, patch_size_y)(inputs)
|
patches = Patches(transformer_patchsize_x, transformer_patchsize_y)(inputs)
|
||||||
# Encode patches.
|
# Encode patches.
|
||||||
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
encoded_patches = PatchEncoder(num_patches, transformer_projection_dim)(patches)
|
||||||
|
|
||||||
for _ in range(transformer_layers):
|
for _ in range(transformer_layers):
|
||||||
# Layer normalization 1.
|
# Layer normalization 1.
|
||||||
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
|
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
|
||||||
# Create a multi-head attention layer.
|
# Create a multi-head attention layer.
|
||||||
attention_output = layers.MultiHeadAttention(
|
attention_output = layers.MultiHeadAttention(
|
||||||
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
|
num_heads=transformer_num_heads, key_dim=transformer_projection_dim, dropout=0.1
|
||||||
)(x1, x1)
|
)(x1, x1)
|
||||||
# Skip connection 1.
|
# Skip connection 1.
|
||||||
x2 = layers.Add()([attention_output, encoded_patches])
|
x2 = layers.Add()([attention_output, encoded_patches])
|
||||||
# Layer normalization 2.
|
# Layer normalization 2.
|
||||||
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
|
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
|
||||||
# MLP.
|
# MLP.
|
||||||
x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
|
x3 = mlp(x3, hidden_units=transformer_mlp_head_units, dropout_rate=0.1)
|
||||||
# Skip connection 2.
|
# Skip connection 2.
|
||||||
encoded_patches = layers.Add()([x3, x2])
|
encoded_patches = layers.Add()([x3, x2])
|
||||||
|
|
||||||
encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
|
encoded_patches = tf.reshape(encoded_patches,
|
||||||
|
[-1,
|
||||||
|
input_height,
|
||||||
|
input_width,
|
||||||
|
transformer_projection_dim // (transformer_patchsize_x *
|
||||||
|
transformer_patchsize_y)])
|
||||||
|
|
||||||
encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)
|
encoded_patches = Conv2D(3, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay), name='convinput')(encoded_patches)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ from tensorflow.keras.metrics import MeanIoU
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
||||||
from sacred import Experiment
|
from sacred import Experiment
|
||||||
|
from sacred.config import create_captured_function
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sklearn.metrics import f1_score
|
from sklearn.metrics import f1_score
|
||||||
|
|
||||||
|
|
@ -318,7 +319,7 @@ def run(_config,
|
||||||
task,
|
task,
|
||||||
weight_decay,
|
weight_decay,
|
||||||
pretraining)
|
pretraining)
|
||||||
elif backbone_type == 'transformer':
|
else:
|
||||||
num_patches_x = transformer_num_patches_xy[0]
|
num_patches_x = transformer_num_patches_xy[0]
|
||||||
num_patches_y = transformer_num_patches_xy[1]
|
num_patches_y = transformer_num_patches_xy[1]
|
||||||
num_patches = num_patches_x * num_patches_y
|
num_patches = num_patches_x * num_patches_y
|
||||||
|
|
@ -330,35 +331,31 @@ def run(_config,
|
||||||
model_builder = vit_resnet50_unet_transformer_before_cnn
|
model_builder = vit_resnet50_unet_transformer_before_cnn
|
||||||
multiple_of_32 = False
|
multiple_of_32 = False
|
||||||
|
|
||||||
assert input_height == num_patches_y * transformer_patchsize_y * (32 if multiple_of_32 else 1), \
|
assert input_height == (num_patches_y *
|
||||||
|
transformer_patchsize_y *
|
||||||
|
(32 if multiple_of_32 else 1)), \
|
||||||
"transformer_patchsize_y or transformer_num_patches_xy height value error: " \
|
"transformer_patchsize_y or transformer_num_patches_xy height value error: " \
|
||||||
"input_height should be equal to " \
|
"input_height should be equal to " \
|
||||||
"(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \
|
"(transformer_num_patches_xy height value * transformer_patchsize_y%s)" % \
|
||||||
" * 32" if multiple_of_32 else ""
|
" * 32" if multiple_of_32 else ""
|
||||||
assert input_width == num_patches_x * transformer_patchsize_x * (32 if multiple_of_32 else 1), \
|
assert input_width == (num_patches_x *
|
||||||
|
transformer_patchsize_x *
|
||||||
|
(32 if multiple_of_32 else 1)), \
|
||||||
"transformer_patchsize_x or transformer_num_patches_xy width value error: " \
|
"transformer_patchsize_x or transformer_num_patches_xy width value error: " \
|
||||||
"input_width should be equal to " \
|
"input_width should be equal to " \
|
||||||
"(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \
|
"(transformer_num_patches_xy width value * transformer_patchsize_x%s)" % \
|
||||||
" * 32" if multiple_of_32 else ""
|
" * 32" if multiple_of_32 else ""
|
||||||
assert 0 == transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x), \
|
assert 0 == (transformer_projection_dim %
|
||||||
|
(transformer_patchsize_y *
|
||||||
|
transformer_patchsize_x)), \
|
||||||
"transformer_projection_dim error: " \
|
"transformer_projection_dim error: " \
|
||||||
"The remainder when parameter transformer_projection_dim is divided by " \
|
"The remainder when parameter transformer_projection_dim is divided by " \
|
||||||
"(transformer_patchsize_y*transformer_patchsize_x) should be zero"
|
"(transformer_patchsize_y*transformer_patchsize_x) should be zero"
|
||||||
|
|
||||||
model = model_builder(
|
model_builder = create_captured_function(model_builder)
|
||||||
n_classes,
|
model_builder.config = _config
|
||||||
transformer_patchsize_x,
|
model_builder.logger = _log
|
||||||
transformer_patchsize_y,
|
model = model_builder(num_patches)
|
||||||
num_patches,
|
|
||||||
transformer_mlp_head_units,
|
|
||||||
transformer_layers,
|
|
||||||
transformer_num_heads,
|
|
||||||
transformer_projection_dim,
|
|
||||||
input_height,
|
|
||||||
input_width,
|
|
||||||
task,
|
|
||||||
weight_decay,
|
|
||||||
pretraining)
|
|
||||||
|
|
||||||
#if you want to see the model structure just uncomment model summary.
|
#if you want to see the model structure just uncomment model summary.
|
||||||
#model.summary()
|
#model.summary()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue