From cd62f13872419deb3c8740e8d5ded6a21cdec3c9 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Tue, 12 May 2026 18:31:18 +0200 Subject: [PATCH] =?UTF-8?q?eynollah=5Focr:=20make=20work=20again,=20re-use?= =?UTF-8?q?=20Eynollah=20base=20class=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - re-use Eynollah base class - use `ModelZoo.load_models()` instead of `load_model()` - pass in `device` init kwarg, delegate to `ModelZoo.load_models()` - `device`: return Torch device at loaded model tensors instead of ad-hoc selection - make numeric init kwargs non-optional (only numeric) --- src/eynollah/cli/cli_ocr.py | 10 ++++++-- src/eynollah/eynollah_ocr.py | 48 ++++++++++++++++++------------------ 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/eynollah/cli/cli_ocr.py b/src/eynollah/cli/cli_ocr.py index 406af61..f9b74c8 100644 --- a/src/eynollah/cli/cli_ocr.py +++ b/src/eynollah/cli/cli_ocr.py @@ -66,6 +66,10 @@ import click "--min_conf_value_of_textline_text", "-min_conf", help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.", +@click.option( + "--device", + "-D", + help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')", ) @click.pass_context def ocr_cli( @@ -81,18 +85,20 @@ def ocr_cli( do_not_mask_with_textline_contour, batch_size, min_conf_value_of_textline_text, + device, ): """ Recognize text with a CNN/RNN or transformer ML model. """ - assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both." + assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both." from ..eynollah_ocr import Eynollah_ocr eynollah_ocr = Eynollah_ocr( model_zoo=ctx.obj.model_zoo, tr_ocr=tr_ocr, do_not_mask_with_textline_contour=do_not_mask_with_textline_contour, batch_size=batch_size, - min_conf_value_of_textline_text=min_conf_value_of_textline_text) + min_conf_value_of_textline_text=min_conf_value_of_textline_text, + device=device) eynollah_ocr.run(overwrite=overwrite, dir_in=dir_in, dir_in_bin=dir_in_bin, diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 3c918e5..4470671 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -14,16 +14,17 @@ from cv2.typing import MatLike from xml.etree import ElementTree as ET from PIL import Image, ImageDraw import numpy as np -from eynollah.model_zoo import EynollahModelZoo -from eynollah.utils.font import get_font -from eynollah.utils.xml import etree_namespace_for_element_tag try: import torch except ImportError: torch = None +from .eynollah import Eynollah +from .model_zoo import EynollahModelZoo from .utils import is_image_filename +from .utils.font import get_font +from .utils.xml import etree_namespace_for_element_tag from .utils.resize import resize_image from .utils.utils_ocr import ( break_curved_line_into_small_pieces_and_then_merge, @@ -44,45 +45,44 @@ class EynollahOcrResult: cropped_lines_region_indexer: List total_bb_coordinates:List -class Eynollah_ocr: +class Eynollah_ocr(Eynollah): def __init__( self, *, model_zoo: EynollahModelZoo, tr_ocr=False, - batch_size: Optional[int]=None, + batch_size: int=0, do_not_mask_with_textline_contour: bool=False, - min_conf_value_of_textline_text : Optional[float]=None, + min_conf_value_of_textline_text : float=0.3, logger: Optional[Logger]=None, + device: str = '', ): self.tr_ocr = tr_ocr # masking for OCR and GT generation, relevant for skewed lines and bounding boxes self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour self.logger = logger if logger else getLogger('eynollah.ocr') - self.model_zoo = model_zoo - self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3 - self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size + self.min_conf_value_of_textline_text = min_conf_value_of_textline_text + self.b_s = batch_size or 2 if tr_ocr else 8 - if tr_ocr: - self.model_zoo.load_model('trocr_processor') - self.model_zoo.load_model('ocr', 'tr') - self.model_zoo.get('ocr').to(self.device) + self.model_zoo = model_zoo + self.setup_models(device=device) + + def setup_models(self, device=''): + if self.tr_ocr: + self.model_zoo.load_models('trocr_processor', + ('ocr', 'tr'), + device=device) else: - self.model_zoo.load_model('ocr', '') - self.model_zoo.load_model('num_to_char') - self.model_zoo.load_model('characters') - self.end_character = len(self.model_zoo.get('characters', list)) + 2 + self.model_zoo.load_models('ocr', + 'num_to_char', + 'characters', + device=device) + self.end_character = len(self.model_zoo.get('characters')) + 2 @property def device(self): - assert torch - if torch.cuda.is_available(): - self.logger.info("Using GPU acceleration") - return torch.device("cuda:0") - else: - self.logger.info("Using CPU processing") - return torch.device("cpu") + return self.model_zoo.get('ocr').device def run_trocr( self,