From 294b6356d3b233fa80322e5b259c31ad7038cd6d Mon Sep 17 00:00:00 2001 From: kba Date: Mon, 27 Oct 2025 11:45:16 +0100 Subject: [PATCH] wip --- Makefile | 1 - src/eynollah/eynollah.py | 13 +++++++++++-- src/eynollah/eynollah_ocr.py | 4 ++++ src/eynollah/model_zoo/__init__.py | 5 ++++- src/eynollah/model_zoo/model_zoo.py | 8 +++----- tests/test_model_zoo.py | 19 +++++++++++++++++++ 6 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 tests/test_model_zoo.py diff --git a/Makefile b/Makefile index 1e7f2dd..4fcd9fb 100644 --- a/Makefile +++ b/Makefile @@ -73,7 +73,6 @@ install-dev: deps-test: $(EYNOLLAH_MODELS_ZIP) $(PIP) install -r requirements-test.txt -endif smoke-test: TMPDIR != mktemp -d smoke-test: tests/resources/kant_aufklaerung_1784_0020.tif diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 232631a..98e894c 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -45,7 +45,7 @@ import tensorflow as tf tf.get_logger().setLevel("ERROR") warnings.filterwarnings("ignore") -from .model_zoo import EynollahModelZoo +from .model_zoo import (EynollahModelZoo, KerasModel, TrOCRProcessor) from .utils.contour import ( filter_contours_area_of_image, filter_contours_area_of_image_tables, @@ -178,6 +178,7 @@ class Eynollah: self.full_layout = full_layout self.tables = tables self.right2left = right2left + # --input-binary sensible if image is very dark, if layout is not working. self.input_binary = input_binary self.allow_scaling = allow_scaling self.headers_off = headers_off @@ -3651,7 +3652,15 @@ class Eynollah: pass def return_ocr_of_textline_without_common_section( - self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot): + self, + textline_image, + model_ocr: KerasModel, + processor: TrOCRProcessor, + device, + width_textline, + h2w_ratio, + ind_tot, + ): if h2w_ratio > 0.05: pixel_values = processor(textline_image, return_tensors="pt").pixel_values diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index cfd410c..41643de 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -63,8 +63,11 @@ class Eynollah_ocr: logger: Optional[Logger]=None, ): self.tr_ocr = tr_ocr + # For generating textline-image pairs for traning, move to generate_gt_for_training self.export_textline_images_and_text = export_textline_images_and_text + # masking for OCR and GT generation, relevant for skewed lines and bounding boxes self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour + # prefix or dataset self.pref_of_dataset = pref_of_dataset self.logger = logger if logger else getLogger('eynollah') self.model_zoo = EynollahModelZoo(basedir=dir_models) @@ -103,6 +106,7 @@ class Eynollah_ocr: def run(self, overwrite: bool = False, dir_in: Optional[str] = None, + # Prediction with RGB and binarized images for selected pages, should not be the default dir_in_bin: Optional[str] = None, image_filename: Optional[str] = None, dir_xmls: Optional[str] = None, diff --git a/src/eynollah/model_zoo/__init__.py b/src/eynollah/model_zoo/__init__.py index e1dc985..dda52c2 100644 --- a/src/eynollah/model_zoo/__init__.py +++ b/src/eynollah/model_zoo/__init__.py @@ -1,4 +1,7 @@ __all__ = [ 'EynollahModelZoo', + 'KerasModel', + 'TrOCRProcessor', + 'VisionEncoderDecoderModel', ] -from .model_zoo import EynollahModelZoo +from .model_zoo import EynollahModelZoo, KerasModel, TrOCRProcessor, VisionEncoderDecoderModel diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 8948a1f..dada98f 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -9,7 +9,6 @@ from keras.models import Model as KerasModel from keras.models import load_model from tabulate import tabulate from transformers import TrOCRProcessor, VisionEncoderDecoderModel - from ..patch_encoder import PatchEncoder, Patches from .specs import EynollahModelSpecSet from .default_specs import DEFAULT_MODEL_SPECS @@ -100,7 +99,7 @@ class EynollahModelZoo: elif model_category == 'characters': model = self._load_characters() elif model_category == 'trocr_processor': - return TrOCRProcessor.from_pretrained(self.model_path(...)) + model = TrOCRProcessor.from_pretrained(model_path) else: try: model = load_model(model_path, compile=False) @@ -184,6 +183,5 @@ class EynollahModelZoo: Ensure that a loaded models is not referenced by ``self._loaded`` anymore """ if hasattr(self, '_loaded') and getattr(self, '_loaded'): - for needle in self._loaded: - if self._loaded[needle]: - del self._loaded[needle] + for needle in self._loaded.keys(): + del self._loaded[needle] diff --git a/tests/test_model_zoo.py b/tests/test_model_zoo.py new file mode 100644 index 0000000..81e84f6 --- /dev/null +++ b/tests/test_model_zoo.py @@ -0,0 +1,19 @@ +from pathlib import Path + +from eynollah.model_zoo import EynollahModelZoo, TrOCRProcessor, VisionEncoderDecoderModel + +testdir = Path(__file__).parent.resolve() +MODELS_DIR = testdir.parent + +def test_trocr1(): + model_zoo = EynollahModelZoo(str(MODELS_DIR)) + model_zoo.load_model('trocr_processor') + proc = model_zoo.get('trocr_processor', TrOCRProcessor) + assert isinstance(proc, TrOCRProcessor) + + model_zoo.load_model('ocr', 'tr') + model = model_zoo.get('ocr') + assert isinstance(model, VisionEncoderDecoderModel) + print(proc) + +test_trocr1()