diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 76e54a7..9ed02d4 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -166,7 +166,8 @@ class Eynollah_ocr(Eynollah): page_tree: ET.ElementTree, page_ns, ) -> EynollahOcrResult: - _, image_height, image_width, _ = self.model_zoo.get('ocr').input_shape + input_shape, _ = self.model_zoo.get('ocr').input_shape + _, image_height, image_width, _ = input_shape total_bb_coordinates = [] cropped_lines_rgb = [] diff --git a/tests/test_model_zoo.py b/tests/test_model_zoo.py index 2902bfe..082f2a7 100644 --- a/tests/test_model_zoo.py +++ b/tests/test_model_zoo.py @@ -23,6 +23,7 @@ def test_cnnrnnocr1( model = model_zoo.get('ocr') assert isinstance(model, Predictor) shape = model.input_shape - assert len(shape) == 4 + assert len(shape) == 2 + assert len(shape[0]) == 4 except ImportError: pass