mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-24 08:02:45 +01:00
wrap_layout_model_resized/patched: compile call instead of predict
(so `predict()` can directly convert back to Numpy)
This commit is contained in:
parent
338c4a0edf
commit
6f4ec53f7e
2 changed files with 13 additions and 18 deletions
|
|
@ -104,11 +104,6 @@ DPI_THRESHOLD = 298
|
|||
MAX_SLOPE = 999
|
||||
KERNEL = np.ones((5, 5), np.uint8)
|
||||
|
||||
projection_dim = 64
|
||||
patch_size = 1
|
||||
num_patches =21*21#14*14#28*28#14*14#28*28
|
||||
|
||||
|
||||
|
||||
class Eynollah:
|
||||
def __init__(
|
||||
|
|
@ -925,7 +920,7 @@ class Eynollah:
|
|||
self.logger.debug("enter do_prediction_new_concept (%s)", model.name)
|
||||
img = img / 255.0
|
||||
|
||||
prediction = model.predict(img[np.newaxis]).numpy()[0]
|
||||
prediction = model.predict(img[np.newaxis])[0]
|
||||
confidence = prediction[:, :, 1]
|
||||
segmentation = np.argmax(prediction, axis=2).astype(np.uint8)
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ class wrap_layout_model_resized(models.Model):
|
|||
self.height = model.layers[-1].output_shape[1]
|
||||
self.width = model.layers[-1].output_shape[2]
|
||||
|
||||
@tf.function(reduce_retracing=True,
|
||||
#jit_compile=True, (ScaleAndTranslate is not supported by XLA)
|
||||
input_signature=[tf.TensorSpec([1, None, None, 3],
|
||||
dtype=tf.float32)])
|
||||
def call(self, img, training=False):
|
||||
height = tf.shape(img)[1]
|
||||
width = tf.shape(img)[2]
|
||||
|
|
@ -68,12 +72,8 @@ class wrap_layout_model_resized(models.Model):
|
|||
(height, width))
|
||||
return pred
|
||||
|
||||
@tf.function(reduce_retracing=True,
|
||||
#jit_compile=True, (ScaleAndTranslate is not supported by XLA)
|
||||
input_signature=[tf.TensorSpec([1, None, None, 3],
|
||||
dtype=tf.float32)])
|
||||
def predict(self, x):
|
||||
return self(x)
|
||||
def predict(self, x, verbose=0):
|
||||
return self(x).numpy()
|
||||
|
||||
class wrap_layout_model_patched(models.Model):
|
||||
"""
|
||||
|
|
@ -98,6 +98,10 @@ class wrap_layout_model_patched(models.Model):
|
|||
self.height, self.width)
|
||||
self.window = tf.expand_dims(window, axis=0)
|
||||
|
||||
@tf.function(reduce_retracing=True,
|
||||
#jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA)
|
||||
input_signature=[tf.TensorSpec([1, None, None, 3],
|
||||
dtype=tf.float32)])
|
||||
def call(self, img, training=False):
|
||||
height = tf.shape(img)[1]
|
||||
width = tf.shape(img)[2]
|
||||
|
|
@ -152,9 +156,5 @@ class wrap_layout_model_patched(models.Model):
|
|||
pred = tf.expand_dims(pred, axis=0)
|
||||
return pred
|
||||
|
||||
@tf.function(reduce_retracing=True,
|
||||
#jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA)
|
||||
input_signature=[tf.TensorSpec([1, None, None, 3],
|
||||
dtype=tf.float32)])
|
||||
def predict(self, x):
|
||||
return self(x)
|
||||
def predict(self, x, verbose=0):
|
||||
return self(x).numpy()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue