diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py index f74e9e1..d1ba4ee 100644 --- a/src/eynollah/training/inference.py +++ b/src/eynollah/training/inference.py @@ -176,15 +176,14 @@ class sbb_predict: session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() tensorflow_backend.set_session(session) - - ##if self.weights_dir!=None: - ##self.model.load_weights(self.weights_dir) + self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) + + if self.task != 'classification' and self.task != 'reading_order': + self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1] + self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] + self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] assert isinstance(self.model, Model) - if self.task != 'classification' and self.task != 'reading_order': - self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1] - self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] - self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]: if task == "binarization":