Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git, synced 2025-06-09 11:50:04 +02:00
updating train.py nontransformer backend
commit 41a0e15e79 (parent 815e5a1d35)
2 changed files with 18 additions and 7 deletions
models.py (13 changes)
@@ -30,8 +30,8 @@ class Patches(layers.Layer):
         self.patch_size = patch_size
 
     def call(self, images):
-        print(tf.shape(images)[1],'images')
-        print(self.patch_size,'self.patch_size')
+        #print(tf.shape(images)[1],'images')
+        #print(self.patch_size,'self.patch_size')
         batch_size = tf.shape(images)[0]
         patches = tf.image.extract_patches(
             images=images,
@@ -41,7 +41,7 @@ class Patches(layers.Layer):
             padding="VALID",
         )
         patch_dims = patches.shape[-1]
-        print(patches.shape,patch_dims,'patch_dims')
+        #print(patches.shape,patch_dims,'patch_dims')
         patches = tf.reshape(patches, [batch_size, -1, patch_dims])
         return patches
     def get_config(self):
@@ -52,6 +52,7 @@ class Patches(layers.Layer):
         })
         return config
 
+
 class PatchEncoder(layers.Layer):
     def __init__(self, num_patches, projection_dim):
         super(PatchEncoder, self).__init__()
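For orientation, here is a minimal, self-contained sketch of the Patches layer these hunks touch. The sizes/strides/rates arguments of tf.image.extract_patches fall between the two hunks and are not shown in the diff, so the patch_size-based values below are assumptions, as is the example feature-map shape (1, 14, 14, 64):

import tensorflow as tf
from tensorflow.keras import layers

class Patches(layers.Layer):
    # Sketch: splits a feature map into flattened patches; arguments not visible
    # in the diff above are filled in with the usual patch_size-based values.
    def __init__(self, patch_size, **kwargs):
        super(Patches, self).__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],    # assumed, not shown in the hunks
            strides=[1, self.patch_size, self.patch_size, 1],  # assumed, not shown in the hunks
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        # Flatten the spatial grid into a sequence: (batch, num_patches, patch_dims)
        return tf.reshape(patches, [batch_size, -1, patch_dims])

# Illustrative shapes only: a 14x14x64 feature map with patch_size=1 yields 196 patches of dim 64.
x = tf.random.uniform((1, 14, 14, 64))
print(Patches(patch_size=1)(x).shape)  # (1, 196, 64)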
@@ -408,7 +409,11 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu
     if pretraining:
         model = Model(inputs, x).load_weights(resnet50_Weights_path)
 
-    num_patches = x.shape[1]*x.shape[2]
+    #num_patches = x.shape[1]*x.shape[2]
+
+    #patch_size_y = input_height / x.shape[1]
+    #patch_size_x = input_width / x.shape[2]
+    #patch_size = patch_size_x * patch_size_y
     patches = Patches(patch_size)(x)
     # Encode patches.
     encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
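The commented-out lines used to derive num_patches from the backbone's output shape; after this commit the caller supplies it as an argument instead. A rough sketch of the relationship behind those lines, assuming the ResNet50 encoder downsamples by a factor of 32 (the same equality the new check in train.py below enforces); the input sizes are illustrative:

# Sketch only: with a stride-32 backbone the feature map is (H/32, W/32), so the
# num_patches argument handed to vit_resnet50_unet must equal that product.
def expected_num_patches(input_height, input_width, downsample=32):
    return (input_height // downsample) * (input_width // downsample)

print(expected_num_patches(448, 448))  # 14 * 14 = 196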
train.py (12 changes)
@@ -97,8 +97,6 @@ def run(_config, n_classes, n_epochs, input_height,
         pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
 
     if task == "segmentation" or task == "enhancement":
-
-        num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
         if data_is_provided:
             dir_train_flowing = os.path.join(dir_output, 'train')
             dir_eval_flowing = os.path.join(dir_output, 'eval')
@@ -213,7 +211,15 @@ def run(_config, n_classes, n_epochs, input_height,
         index_start = 0
         if backbone_type=='nontransformer':
             model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
-        elif backbone_type=='nontransformer':
+        elif backbone_type=='transformer':
+            num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
+
+            if not (num_patches == (input_width / 32) * (input_height / 32)):
+                print("Error: transformer num patches error. Parameter transformer_num_patches_xy should be set to (input_width/32) = {} and (input_height/32) = {}".format(int(input_width / 32), int(input_height / 32)))
+                sys.exit(1)
+            if not (transformer_patchsize == 1):
+                print("Error: transformer patchsize error. Parameter transformer_patchsize should be set to 1")
+                sys.exit(1)
             model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
 
         #if you want to see the model structure just uncomment model summary.
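As a usage illustration, hypothetical parameter values that would pass the two new checks for a 448 x 448 input; only the variable names come from the diff above, while the concrete numbers and the surrounding configuration format are assumptions:

# Hypothetical settings that satisfy the new transformer-backbone checks.
input_height = 448
input_width = 448
backbone_type = 'transformer'
transformer_patchsize = 1                            # the new check requires exactly 1
transformer_num_patches_xy = [448 // 32, 448 // 32]  # [14, 14]

num_patches = transformer_num_patches_xy[0] * transformer_num_patches_xy[1]
assert num_patches == (input_width / 32) * (input_height / 32)  # 196 patches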