From 6f4ec53f7e90f65b8435949d89ca6ce74b711f2f Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Sat, 7 Mar 2026 03:52:14 +0100 Subject: [PATCH] wrap_layout_model_resized/patched: compile `call` instead of `predict` (so `predict()` can directly convert back to Numpy) --- src/eynollah/eynollah.py | 7 +------ src/eynollah/patch_encoder.py | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index d536cc8..1f7d585 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -104,11 +104,6 @@ DPI_THRESHOLD = 298 MAX_SLOPE = 999 KERNEL = np.ones((5, 5), np.uint8) -projection_dim = 64 -patch_size = 1 -num_patches =21*21#14*14#28*28#14*14#28*28 - - class Eynollah: def __init__( @@ -925,7 +920,7 @@ class Eynollah: self.logger.debug("enter do_prediction_new_concept (%s)", model.name) img = img / 255.0 - prediction = model.predict(img[np.newaxis]).numpy()[0] + prediction = model.predict(img[np.newaxis])[0] confidence = prediction[:, :, 1] segmentation = np.argmax(prediction, axis=2).astype(np.uint8) diff --git a/src/eynollah/patch_encoder.py b/src/eynollah/patch_encoder.py index 20b71d6..fda1d6d 100644 --- a/src/eynollah/patch_encoder.py +++ b/src/eynollah/patch_encoder.py @@ -57,6 +57,10 @@ class wrap_layout_model_resized(models.Model): self.height = model.layers[-1].output_shape[1] self.width = model.layers[-1].output_shape[2] + @tf.function(reduce_retracing=True, + #jit_compile=True, (ScaleAndTranslate is not supported by XLA) + input_signature=[tf.TensorSpec([1, None, None, 3], + dtype=tf.float32)]) def call(self, img, training=False): height = tf.shape(img)[1] width = tf.shape(img)[2] @@ -68,12 +72,8 @@ class wrap_layout_model_resized(models.Model): (height, width)) return pred - @tf.function(reduce_retracing=True, - #jit_compile=True, (ScaleAndTranslate is not supported by XLA) - input_signature=[tf.TensorSpec([1, None, None, 3], - dtype=tf.float32)]) - def predict(self, x): - return self(x) + def predict(self, x, verbose=0): + return self(x).numpy() class wrap_layout_model_patched(models.Model): """ @@ -98,6 +98,10 @@ class wrap_layout_model_patched(models.Model): self.height, self.width) self.window = tf.expand_dims(window, axis=0) + @tf.function(reduce_retracing=True, + #jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA) + input_signature=[tf.TensorSpec([1, None, None, 3], + dtype=tf.float32)]) def call(self, img, training=False): height = tf.shape(img)[1] width = tf.shape(img)[2] @@ -152,9 +156,5 @@ class wrap_layout_model_patched(models.Model): pred = tf.expand_dims(pred, axis=0) return pred - @tf.function(reduce_retracing=True, - #jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA) - input_signature=[tf.TensorSpec([1, None, None, 3], - dtype=tf.float32)]) - def predict(self, x): - return self(x) + def predict(self, x, verbose=0): + return self(x).numpy()