diff --git a/src/eynollah/patch_encoder.py b/src/eynollah/patch_encoder.py index fda1d6d..d6a74ea 100644 --- a/src/eynollah/patch_encoder.py +++ b/src/eynollah/patch_encoder.py @@ -122,8 +122,10 @@ class wrap_layout_model_patched(models.Model): rates=[1, 1, 1, 1], padding='SAME') img_patches = tf.squeeze(img_patches) - new_shape = (-1, self.height, self.width, 3) - img_patches = tf.reshape(img_patches, shape=new_shape) + index_shape = (-1, self.height, self.width, 2) + input_shape = (-1, self.height, self.width, 3) + output_shape = (-1, self.height, self.width, self.classes) + img_patches = tf.reshape(img_patches, shape=input_shape) # may be too large: #pred_patches = self.model(img_patches) # so rebatch to fit in memory: @@ -144,7 +146,7 @@ class wrap_layout_model_patched(models.Model): rates=[1, 1, 1, 1], padding='SAME') indices_patches = tf.squeeze(indices_patches) - indices_patches = tf.reshape(indices_patches, shape=new_shape[:-1] + (2,)) + indices_patches = tf.reshape(indices_patches, shape=index_shape) # use margins for sliding window approach indices_patches = indices_patches * self.window