models: cosmetics

- using `Reshape`, do not pass `target_shape` as kwarg
- add a default `name` for `Patches` and `PatchEncoder`
This commit is contained in:
Robert Sachunsky 2026-05-27 01:58:21 +02:00
parent 9801129aa6
commit c4a7eec5b3
2 changed files with 14 additions and 14 deletions

View file

@ -6,8 +6,8 @@ from tensorflow.keras import layers, models
class PatchEncoder(layers.Layer):
# 441=21*21 # 14*14 # 28*28
def __init__(self, num_patches=441, projection_dim=64):
super().__init__()
def __init__(self, num_patches=441, projection_dim=64, name='encode_patches'):
super().__init__(name=name)
self.num_patches = num_patches
self.projection_dim = projection_dim
self.projection = layers.Dense(self.projection_dim)
@ -23,8 +23,8 @@ class PatchEncoder(layers.Layer):
**super().get_config())
class Patches(layers.Layer):
def __init__(self, patch_size_x=1, patch_size_y=1):
super().__init__()
def __init__(self, patch_size_x=1, patch_size_y=1, name='extract_patches'):
super().__init__(name=name)
self.patch_size_x = patch_size_x
self.patch_size_y = patch_size_y

View file

@ -309,9 +309,9 @@ def transformer_block(img,
# Skip connection 2.
encoded_patches = Add()([x3, x2])
encoded_patches = Reshape(target_shape=(img.shape[1],
img.shape[2],
projection_dim // (patchsize_x * patchsize_y)),
encoded_patches = Reshape((img.shape[1],
img.shape[2],
projection_dim // (patchsize_x * patchsize_y)),
name="reshape_patches")(encoded_patches)
return encoded_patches
@ -464,23 +464,23 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
x = Reshape(target_shape=new_shape, name="reshape")(x)
x2d = Reshape(target_shape=new_shape2, name="reshape2")(x2d)
x4d = Reshape(target_shape=new_shape4, name="reshape4")(x4d)
x = Reshape(new_shape, name="reshape")(x)
x2d = Reshape(new_shape2, name="reshape2")(x2d)
x4d = Reshape(new_shape4, name="reshape4")(x4d)
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
xrnn2d = Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
xrnn4d = Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
xrnn2dup = Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
xrnn4dup = Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
xrnn2dup = Reshape((xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
xrnn4dup = Reshape((xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])