From c4a7eec5b3195cd3114a0d2b54de723c806e5bab Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 27 May 2026 01:58:21 +0200 Subject: [PATCH] models: cosmetics - using `Reshape`, do not pass `target_shape` as kwarg - add a default `name` for `Patches` and `PatchEncoder` --- src/eynollah/patch_encoder.py | 8 ++++---- src/eynollah/training/models.py | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/eynollah/patch_encoder.py b/src/eynollah/patch_encoder.py index f163132..610f0b4 100644 --- a/src/eynollah/patch_encoder.py +++ b/src/eynollah/patch_encoder.py @@ -6,8 +6,8 @@ from tensorflow.keras import layers, models class PatchEncoder(layers.Layer): # 441=21*21 # 14*14 # 28*28 - def __init__(self, num_patches=441, projection_dim=64): - super().__init__() + def __init__(self, num_patches=441, projection_dim=64, name='encode_patches'): + super().__init__(name=name) self.num_patches = num_patches self.projection_dim = projection_dim self.projection = layers.Dense(self.projection_dim) @@ -23,8 +23,8 @@ class PatchEncoder(layers.Layer): **super().get_config()) class Patches(layers.Layer): - def __init__(self, patch_size_x=1, patch_size_y=1): - super().__init__() + def __init__(self, patch_size_x=1, patch_size_y=1, name='extract_patches'): + super().__init__(name=name) self.patch_size_x = patch_size_x self.patch_size_y = patch_size_y diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index f700d14..c5510f8 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -309,9 +309,9 @@ def transformer_block(img, # Skip connection 2. encoded_patches = Add()([x3, x2]) - encoded_patches = Reshape(target_shape=(img.shape[1], - img.shape[2], - projection_dim // (patchsize_x * patchsize_y)), + encoded_patches = Reshape((img.shape[1], + img.shape[2], + projection_dim // (patchsize_x * patchsize_y)), name="reshape_patches")(encoded_patches) return encoded_patches @@ -464,23 +464,23 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3]) new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3]) - x = Reshape(target_shape=new_shape, name="reshape")(x) - x2d = Reshape(target_shape=new_shape2, name="reshape2")(x2d) - x4d = Reshape(target_shape=new_shape4, name="reshape4")(x4d) + x = Reshape(new_shape, name="reshape")(x) + x2d = Reshape(new_shape2, name="reshape2")(x2d) + x4d = Reshape(new_shape4, name="reshape4")(x4d) xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x) xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d) xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d) - xrnn2d = Reshape(target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d) - xrnn4d = Reshape(target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d) + xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d) + xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d) xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d) xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d) - xrnn2dup = Reshape(target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup) - xrnn4dup = Reshape(target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup) + xrnn2dup = Reshape((xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup) + xrnn4dup = Reshape((xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup) addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])