diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 79d3573..fffd389 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -191,6 +191,9 @@ class EynollahModelZoo: elif patched: model = wrap_layout_model_patched(model) model._name = model_category + '_patched' + else: + model.jit_compile = True + model.make_predict_function() return model def get(self, model_category: str) -> Predictor: diff --git a/src/eynollah/patch_encoder.py b/src/eynollah/patch_encoder.py index d6a74ea..f163132 100644 --- a/src/eynollah/patch_encoder.py +++ b/src/eynollah/patch_encoder.py @@ -72,9 +72,6 @@ class wrap_layout_model_resized(models.Model): (height, width)) return pred - def predict(self, x, verbose=0): - return self(x).numpy() - class wrap_layout_model_patched(models.Model): """ replacement for layout model using sliding window for patches @@ -157,6 +154,3 @@ class wrap_layout_model_patched(models.Model): (height, width, self.classes)) pred = tf.expand_dims(pred, axis=0) return pred - - def predict(self, x, verbose=0): - return self(x).numpy()