wrap_layout_model_patched: simplify shape calculation

This commit is contained in:
Robert Sachunsky 2026-03-14 00:51:22 +01:00
parent d6404dbbc2
commit b550725cc5

View file

@ -122,8 +122,10 @@ class wrap_layout_model_patched(models.Model):
rates=[1, 1, 1, 1],
padding='SAME')
img_patches = tf.squeeze(img_patches)
new_shape = (-1, self.height, self.width, 3)
img_patches = tf.reshape(img_patches, shape=new_shape)
index_shape = (-1, self.height, self.width, 2)
input_shape = (-1, self.height, self.width, 3)
output_shape = (-1, self.height, self.width, self.classes)
img_patches = tf.reshape(img_patches, shape=input_shape)
# may be too large:
#pred_patches = self.model(img_patches)
# so rebatch to fit in memory:
@ -144,7 +146,7 @@ class wrap_layout_model_patched(models.Model):
rates=[1, 1, 1, 1],
padding='SAME')
indices_patches = tf.squeeze(indices_patches)
indices_patches = tf.reshape(indices_patches, shape=new_shape[:-1] + (2,))
indices_patches = tf.reshape(indices_patches, shape=index_shape)
# use margins for sliding window approach
indices_patches = indices_patches * self.window