mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-31 01:59:27 +02:00
models: cosmetics
- using `Reshape`, do not pass `target_shape` as kwarg - add a default `name` for `Patches` and `PatchEncoder`
This commit is contained in:
parent
9801129aa6
commit
c4a7eec5b3
2 changed files with 14 additions and 14 deletions
|
|
@ -6,8 +6,8 @@ from tensorflow.keras import layers, models
|
||||||
class PatchEncoder(layers.Layer):
|
class PatchEncoder(layers.Layer):
|
||||||
|
|
||||||
# 441=21*21 # 14*14 # 28*28
|
# 441=21*21 # 14*14 # 28*28
|
||||||
def __init__(self, num_patches=441, projection_dim=64):
|
def __init__(self, num_patches=441, projection_dim=64, name='encode_patches'):
|
||||||
super().__init__()
|
super().__init__(name=name)
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.projection_dim = projection_dim
|
self.projection_dim = projection_dim
|
||||||
self.projection = layers.Dense(self.projection_dim)
|
self.projection = layers.Dense(self.projection_dim)
|
||||||
|
|
@ -23,8 +23,8 @@ class PatchEncoder(layers.Layer):
|
||||||
**super().get_config())
|
**super().get_config())
|
||||||
|
|
||||||
class Patches(layers.Layer):
|
class Patches(layers.Layer):
|
||||||
def __init__(self, patch_size_x=1, patch_size_y=1):
|
def __init__(self, patch_size_x=1, patch_size_y=1, name='extract_patches'):
|
||||||
super().__init__()
|
super().__init__(name=name)
|
||||||
self.patch_size_x = patch_size_x
|
self.patch_size_x = patch_size_x
|
||||||
self.patch_size_y = patch_size_y
|
self.patch_size_y = patch_size_y
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -309,9 +309,9 @@ def transformer_block(img,
|
||||||
# Skip connection 2.
|
# Skip connection 2.
|
||||||
encoded_patches = Add()([x3, x2])
|
encoded_patches = Add()([x3, x2])
|
||||||
|
|
||||||
encoded_patches = Reshape(target_shape=(img.shape[1],
|
encoded_patches = Reshape((img.shape[1],
|
||||||
img.shape[2],
|
img.shape[2],
|
||||||
projection_dim // (patchsize_x * patchsize_y)),
|
projection_dim // (patchsize_x * patchsize_y)),
|
||||||
name="reshape_patches")(encoded_patches)
|
name="reshape_patches")(encoded_patches)
|
||||||
return encoded_patches
|
return encoded_patches
|
||||||
|
|
||||||
|
|
@ -464,23 +464,23 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
|
||||||
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
|
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
|
||||||
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
|
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
|
||||||
|
|
||||||
x = Reshape(target_shape=new_shape, name="reshape")(x)
|
x = Reshape(new_shape, name="reshape")(x)
|
||||||
x2d = Reshape(target_shape=new_shape2, name="reshape2")(x2d)
|
x2d = Reshape(new_shape2, name="reshape2")(x2d)
|
||||||
x4d = Reshape(target_shape=new_shape4, name="reshape4")(x4d)
|
x4d = Reshape(new_shape4, name="reshape4")(x4d)
|
||||||
|
|
||||||
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
|
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
|
||||||
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
|
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
|
||||||
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
|
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
|
||||||
|
|
||||||
xrnn2d = Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
||||||
xrnn4d = Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
||||||
|
|
||||||
|
|
||||||
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
|
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
|
||||||
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
|
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
|
||||||
|
|
||||||
xrnn2dup = Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
|
xrnn2dup = Reshape((xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
|
||||||
xrnn4dup = Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
|
xrnn4dup = Reshape((xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
|
||||||
|
|
||||||
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
|
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue