mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 11:50:04 +02:00
updating train.py nontransformer backend
This commit is contained in:
parent
815e5a1d35
commit
41a0e15e79
2 changed files with 18 additions and 7 deletions
13
models.py
13
models.py
|
@ -30,8 +30,8 @@ class Patches(layers.Layer):
|
|||
self.patch_size = patch_size
|
||||
|
||||
def call(self, images):
|
||||
print(tf.shape(images)[1],'images')
|
||||
print(self.patch_size,'self.patch_size')
|
||||
#print(tf.shape(images)[1],'images')
|
||||
#print(self.patch_size,'self.patch_size')
|
||||
batch_size = tf.shape(images)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=images,
|
||||
|
@ -41,7 +41,7 @@ class Patches(layers.Layer):
|
|||
padding="VALID",
|
||||
)
|
||||
patch_dims = patches.shape[-1]
|
||||
print(patches.shape,patch_dims,'patch_dims')
|
||||
#print(patches.shape,patch_dims,'patch_dims')
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
def get_config(self):
|
||||
|
@ -51,6 +51,7 @@ class Patches(layers.Layer):
|
|||
'patch_size': self.patch_size,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
class PatchEncoder(layers.Layer):
|
||||
def __init__(self, num_patches, projection_dim):
|
||||
|
@ -408,7 +409,11 @@ def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, inpu
|
|||
if pretraining:
|
||||
model = Model(inputs, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
num_patches = x.shape[1]*x.shape[2]
|
||||
#num_patches = x.shape[1]*x.shape[2]
|
||||
|
||||
#patch_size_y = input_height / x.shape[1]
|
||||
#patch_size_x = input_width / x.shape[2]
|
||||
#patch_size = patch_size_x * patch_size_y
|
||||
patches = Patches(patch_size)(x)
|
||||
# Encode patches.
|
||||
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue