mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
Update inference.py
Fix broken inference model loading introduced during refactoring or merge
This commit is contained in:
parent
47fa22112c
commit
ed034aa8ce
1 changed files with 6 additions and 7 deletions
|
|
@ -176,15 +176,14 @@ class sbb_predict:
|
|||
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||
tensorflow_backend.set_session(session)
|
||||
|
||||
|
||||
##if self.weights_dir!=None:
|
||||
##self.model.load_weights(self.weights_dir)
|
||||
self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
|
||||
if self.task != 'classification' and self.task != 'reading_order':
|
||||
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
|
||||
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
|
||||
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
|
||||
|
||||
assert isinstance(self.model, Model)
|
||||
if self.task != 'classification' and self.task != 'reading_order':
|
||||
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
|
||||
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
|
||||
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
|
||||
|
||||
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
|
||||
if task == "binarization":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue