mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 03:40:24 +02:00

transformer patch size is dynamic now.

parent 2aa216e388, commit f1fd74c7eb
3 changed files with 75 additions and 30 deletions
@@ -1,42 +1,44 @@
 {
-    "backbone_type" : "nontransformer",
+    "backbone_type" : "transformer",
-    "task": "classification",
+    "task": "binarization",
     "n_classes" : 2,
-    "n_epochs" : 20,
+    "n_epochs" : 1,
-    "input_height" : 448,
+    "input_height" : 224,
-    "input_width" : 448,
+    "input_width" : 672,
     "weight_decay" : 1e-6,
-    "n_batch" : 6,
+    "n_batch" : 1,
     "learning_rate": 1e-4,
-    "f1_threshold_classification": 0.8,
     "patches" : true,
     "pretraining" : true,
     "augmentation" : false,
     "flip_aug" : false,
     "blur_aug" : false,
     "scaling" : true,
+    "degrading": false,
+    "brightening": false,
     "binarization" : false,
     "scaling_bluring" : false,
     "scaling_binarization" : false,
     "scaling_flip" : false,
     "rotation": false,
     "rotation_not_90": false,
-    "transformer_num_patches_xy": [28, 28],
+    "transformer_num_patches_xy": [7, 7],
-    "transformer_patchsize": 1,
+    "transformer_patchsize_x": 3,
+    "transformer_patchsize_y": 1,
+    "transformer_projection_dim": 192,
     "blur_k" : ["blur","guass","median"],
     "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
     "brightness" : [1.3, 1.5, 1.7, 2],
     "degrade_scales" : [0.2, 0.4],
     "flip_index" : [0, 1, -1],
     "thetha" : [10, -10],
-    "classification_classes_name" : {"0":"apple", "1":"orange"},
     "continue_training": false,
     "index_start" : 0,
     "dir_of_start_model" : " ",
     "weighted_loss": false,
     "is_loss_soft_dice": false,
     "data_is_provided": false,
-    "dir_train": "./train",
+    "dir_train": "/home/vahid/Documents/test/training_data_sample_binarization",
-    "dir_eval": "./eval",
+    "dir_eval": "/home/vahid/Documents/test/eval",
-    "dir_output": "./output"
+    "dir_output": "/home/vahid/Documents/test/out"
 }
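A quick consistency check of the new example values in this JSON config (a sketch, not part of the commit; it mirrors the validation that train.py enforces below):

num_patches_x, num_patches_y = 7, 7   # transformer_num_patches_xy
patchsize_x, patchsize_y = 3, 1       # transformer_patchsize_x / _y
projection_dim = 192                  # transformer_projection_dim

assert 224 == num_patches_y * patchsize_y * 32            # input_height = 7 * 1 * 32
assert 672 == num_patches_x * patchsize_x * 32            # input_width  = 7 * 3 * 32
assert projection_dim % (patchsize_x * patchsize_y) == 0  # 192 % 3 == 0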
47  models.py

@@ -6,25 +6,49 @@ from tensorflow.keras import layers
 from tensorflow.keras.regularizers import l2
 
 mlp_head_units = [2048, 1024]
-projection_dim = 64
+#projection_dim = 64
 transformer_layers = 8
 num_heads = 4
 resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
 IMAGE_ORDERING = 'channels_last'
 MERGE_AXIS = -1
 
-transformer_units = [
-    projection_dim * 2,
-    projection_dim,
-]  # Size of the transformer layers
 def mlp(x, hidden_units, dropout_rate):
     for units in hidden_units:
         x = layers.Dense(units, activation=tf.nn.gelu)(x)
         x = layers.Dropout(dropout_rate)(x)
     return x
 
 
 class Patches(layers.Layer):
+    def __init__(self, patch_size_x, patch_size_y):
+        super(Patches, self).__init__()
+        self.patch_size_x = patch_size_x
+        self.patch_size_y = patch_size_y
+
+    def call(self, images):
+        batch_size = tf.shape(images)[0]
+        # Rectangular patches: an independent size in each spatial direction.
+        patches = tf.image.extract_patches(
+            images=images,
+            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
+            strides=[1, self.patch_size_y, self.patch_size_x, 1],
+            rates=[1, 1, 1, 1],
+            padding="VALID",
+        )
+        patch_dims = patches.shape[-1]
+        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
+        return patches
+
+    def get_config(self):
+        # Serialize both patch sizes so the layer survives model save/load.
+        config = super().get_config().copy()
+        config.update({
+            'patch_size_x': self.patch_size_x,
+            'patch_size_y': self.patch_size_y,
+        })
+        return config
+
+
+class Patches_old(layers.Layer):
     def __init__(self, patch_size):
         super(Patches, self).__init__()
         self.patch_size = patch_size
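The new Patches layer replaces the single square patch_size with independent x and y sizes, so a non-square feature map can be cut into a square token grid; get_config() makes the layer serializable. Note that the legacy class kept as Patches_old still calls super(Patches, self).__init__(), which would raise a TypeError if it were ever instantiated. A minimal usage sketch (not in the commit; the 7x21x64 tensor stands in for the ResNet50 /32 feature map of a 224x672 input, and the 64-channel depth is an assumption for illustration):

import tensorflow as tf
from models import Patches

feature_map = tf.zeros((1, 7, 21, 64))   # (batch, height/32, width/32, channels), assumed shape
patches = Patches(patch_size_x=3, patch_size_y=1)(feature_map)
print(patches.shape)                     # (1, 49, 192): a 7x7 grid of tokens, 1*3*64 values each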
@@ -369,8 +393,13 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     return model
 
 
-def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
+def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, projection_dim=64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     inputs = layers.Input(shape=(input_height, input_width, 3))
+
+    transformer_units = [
+        projection_dim * 2,
+        projection_dim,
+    ]  # Size of the transformer layers
     IMAGE_ORDERING = 'channels_last'
     bn_axis = 3
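Moving transformer_units inside vit_resnet50_unet ties the per-block MLP widths to the projection_dim argument instead of the old module-level constant. A worked check with the example config value (not part of the commit):

projection_dim = 192                                       # transformer_projection_dim
transformer_units = [projection_dim * 2, projection_dim]   # [384, 192]: expand, then project back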
@@ -414,7 +443,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     #patch_size_y = input_height / x.shape[1]
     #patch_size_x = input_width / x.shape[2]
     #patch_size = patch_size_x * patch_size_y
-    patches = Patches(patch_size)(x)
+    patches = Patches(patch_size_x, patch_size_y)(x)
     # Encode patches.
     encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
@@ -434,7 +463,7 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     # Skip connection 2.
     encoded_patches = layers.Add()([x3, x2])
 
-    encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], 64])
+    encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], int(projection_dim / (patch_size_x * patch_size_y))])
 
     v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(encoded_patches)
     v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
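With the example config, the generalized reshape works out as follows (a worked check, not part of the commit): 49 tokens of dimension 192 hold 49 * 192 = 9408 values, and the target grid (7, 21, 192 // (3 * 1)) = (7, 21, 64) holds 7 * 21 * 64 = 9408 values, so each projected token folds back onto its 1x3 spatial footprint. This is also why projection_dim must be divisible by patch_size_x * patch_size_y, which train.py now checks.

num_patches, projection_dim = 49, 192                        # 7 * 7 tokens from PatchEncoder
h, w = 7, 21                                                 # x.shape[1], x.shape[2]
patch_size_x, patch_size_y = 3, 1
channels = projection_dim // (patch_size_x * patch_size_y)   # 192 // 3 = 64
assert num_patches * projection_dim == h * w * channels      # 9408 == 9408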
30  train.py

@@ -70,8 +70,10 @@ def config_params():
     brightness = None  # Brighten image for augmentation.
     flip_index = None  # Flip image for augmentation.
     continue_training = False  # Set to true if you would like to continue training an already trained model.
-    transformer_patchsize = None  # Patch size of vision transformer patches.
+    transformer_patchsize_x = None  # Patch size of vision transformer patches in the x direction.
+    transformer_patchsize_y = None  # Patch size of vision transformer patches in the y direction.
     transformer_num_patches_xy = None  # Number of patches for vision transformer.
+    transformer_projection_dim = 64  # Transformer projection dimension.
     index_start = 0  # Index of model to continue training from. E.g. if you trained for 3 epochs and the last index is 2, to continue from model_1.h5 set "index_start" to 3 to start naming models with index 3.
     dir_of_start_model = ''  # Directory containing pretrained encoder to continue training the model.
     is_loss_soft_dice = False  # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
@@ -92,7 +94,7 @@ def run(_config, n_classes, n_epochs, input_height,
         brightening, binarization, blur_k, scales, degrade_scales,
         brightness, dir_train, data_is_provided, scaling_bluring,
         scaling_brightness, scaling_binarization, rotation, rotation_not_90,
-        thetha, scaling_flip, continue_training, transformer_patchsize,
+        thetha, scaling_flip, continue_training, transformer_projection_dim, transformer_patchsize_x, transformer_patchsize_y,
         transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
         pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
@@ -212,15 +214,27 @@ def run(_config, n_classes, n_epochs, input_height,
     if backbone_type == 'nontransformer':
         model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
     elif backbone_type == 'transformer':
-        num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
+        num_patches_x = transformer_num_patches_xy[0]
+        num_patches_y = transformer_num_patches_xy[1]
+        num_patches = num_patches_x * num_patches_y
 
-        if not (num_patches == (input_width / 32) * (input_height / 32)):
-            print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)))
+        if input_height != (num_patches_y * transformer_patchsize_y * 32):
+            print("Error: transformer_patchsize_y or transformer_num_patches_xy height value error. input_height should be equal to (transformer_num_patches_xy height value * transformer_patchsize_y * 32)")
             sys.exit(1)
-        if not (transformer_patchsize == 1):
-            print("Error: transformer patchsize error. Parameter transformer_patchsize should be set to 1")
+        if input_width != (num_patches_x * transformer_patchsize_x * 32):
+            print("Error: transformer_patchsize_x or transformer_num_patches_xy width value error. input_width should be equal to (transformer_num_patches_xy width value * transformer_patchsize_x * 32)")
             sys.exit(1)
-        model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
+        if (transformer_projection_dim % (transformer_patchsize_y * transformer_patchsize_x)) != 0:
+            print("Error: transformer_projection_dim error. The remainder when transformer_projection_dim is divided by (transformer_patchsize_y * transformer_patchsize_x) should be zero")
+            sys.exit(1)
+
+        model = vit_resnet50_unet(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
 
     # If you want to see the model structure, just uncomment model.summary().
     # model.summary()
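Taken together, the three checks let arbitrary rectangular patch sizes through instead of the old hard-coded patch size of 1. A construction sketch with the example config values (not part of the commit, assuming models.py is importable):

from models import vit_resnet50_unet

model = vit_resnet50_unet(
    n_classes=2,
    patch_size_x=3, patch_size_y=1,   # transformer_patchsize_x / _y
    num_patches=49,                   # 7 * 7 from transformer_num_patches_xy
    projection_dim=192,               # transformer_projection_dim
    input_height=224, input_width=672,
    task="binarization",
    weight_decay=1e-6,
    pretraining=True,                 # expects the ResNet50 weights file to be present
)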