Update inference.py

Fix broken inference model loading introduced during a refactoring or merge
This commit is contained in:
vahidrezanezhad 2026-02-12 15:28:15 +01:00 committed by GitHub
parent 47fa22112c
commit ed034aa8ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -176,15 +176,14 @@ class sbb_predict:
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
tensorflow_backend.set_session(session)
##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir)
self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
if self.task != 'classification' and self.task != 'reading_order':
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
assert isinstance(self.model, Model)
if self.task != 'classification' and self.task != 'reading_order':
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization":