wrap_layout_model_resized/patched: compile call instead of predict

(so `predict()` can directly convert back to Numpy)
This commit is contained in:
Robert Sachunsky 2026-03-07 03:52:14 +01:00
parent 338c4a0edf
commit 6f4ec53f7e
2 changed files with 13 additions and 18 deletions

View file

@@ -104,11 +104,6 @@ DPI_THRESHOLD = 298
MAX_SLOPE = 999
KERNEL = np.ones((5, 5), np.uint8)
projection_dim = 64
patch_size = 1
num_patches =21*21#14*14#28*28#14*14#28*28
class Eynollah:
def __init__(
@@ -925,7 +920,7 @@ class Eynollah:
self.logger.debug("enter do_prediction_new_concept (%s)", model.name)
img = img / 255.0
prediction = model.predict(img[np.newaxis]).numpy()[0]
prediction = model.predict(img[np.newaxis])[0]
confidence = prediction[:, :, 1]
segmentation = np.argmax(prediction, axis=2).astype(np.uint8)

View file

@@ -57,6 +57,10 @@ class wrap_layout_model_resized(models.Model):
self.height = model.layers[-1].output_shape[1]
self.width = model.layers[-1].output_shape[2]
@tf.function(reduce_retracing=True,
#jit_compile=True, (ScaleAndTranslate is not supported by XLA)
input_signature=[tf.TensorSpec([1, None, None, 3],
dtype=tf.float32)])
def call(self, img, training=False):
height = tf.shape(img)[1]
width = tf.shape(img)[2]
@@ -68,12 +72,8 @@ class wrap_layout_model_resized(models.Model):
(height, width))
return pred
@tf.function(reduce_retracing=True,
#jit_compile=True, (ScaleAndTranslate is not supported by XLA)
input_signature=[tf.TensorSpec([1, None, None, 3],
dtype=tf.float32)])
def predict(self, x):
return self(x)
def predict(self, x, verbose=0):
return self(x).numpy()
class wrap_layout_model_patched(models.Model):
"""
@@ -98,6 +98,10 @@ class wrap_layout_model_patched(models.Model):
self.height, self.width)
self.window = tf.expand_dims(window, axis=0)
@tf.function(reduce_retracing=True,
#jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA)
input_signature=[tf.TensorSpec([1, None, None, 3],
dtype=tf.float32)])
def call(self, img, training=False):
height = tf.shape(img)[1]
width = tf.shape(img)[2]
@@ -152,9 +156,5 @@ class wrap_layout_model_patched(models.Model):
pred = tf.expand_dims(pred, axis=0)
return pred
@tf.function(reduce_retracing=True,
#jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA)
input_signature=[tf.TensorSpec([1, None, None, 3],
dtype=tf.float32)])
def predict(self, x):
return self(x)
def predict(self, x, verbose=0):
return self(x).numpy()