eynollah_ocr: make work again, re-use Eynollah base class…

- re-use Eynollah base class
- use `ModelZoo.load_models()` instead of `load_model()`
- pass in `device` init kwarg, delegate to `ModelZoo.load_models()`
- `device`: return the Torch device of the loaded model's tensors
  instead of ad-hoc selection
- make numeric init kwargs non-optional (plain numeric types with numeric defaults instead of `Optional[...] = None`)
This commit is contained in:
Robert Sachunsky 2026-05-12 18:31:18 +02:00
parent ded668a256
commit cd62f13872
2 changed files with 32 additions and 26 deletions

View file

@ -66,6 +66,10 @@ import click
"--min_conf_value_of_textline_text", "--min_conf_value_of_textline_text",
"-min_conf", "-min_conf",
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.", help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
@click.option(
"--device",
"-D",
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
) )
@click.pass_context @click.pass_context
def ocr_cli( def ocr_cli(
@ -81,18 +85,20 @@ def ocr_cli(
do_not_mask_with_textline_contour, do_not_mask_with_textline_contour,
batch_size, batch_size,
min_conf_value_of_textline_text, min_conf_value_of_textline_text,
device,
): ):
""" """
Recognize text with a CNN/RNN or transformer ML model. Recognize text with a CNN/RNN or transformer ML model.
""" """
assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both." assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
from ..eynollah_ocr import Eynollah_ocr from ..eynollah_ocr import Eynollah_ocr
eynollah_ocr = Eynollah_ocr( eynollah_ocr = Eynollah_ocr(
model_zoo=ctx.obj.model_zoo, model_zoo=ctx.obj.model_zoo,
tr_ocr=tr_ocr, tr_ocr=tr_ocr,
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour, do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
batch_size=batch_size, batch_size=batch_size,
min_conf_value_of_textline_text=min_conf_value_of_textline_text) min_conf_value_of_textline_text=min_conf_value_of_textline_text,
device=device)
eynollah_ocr.run(overwrite=overwrite, eynollah_ocr.run(overwrite=overwrite,
dir_in=dir_in, dir_in=dir_in,
dir_in_bin=dir_in_bin, dir_in_bin=dir_in_bin,

View file

@ -14,16 +14,17 @@ from cv2.typing import MatLike
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
import numpy as np import numpy as np
from eynollah.model_zoo import EynollahModelZoo
from eynollah.utils.font import get_font
from eynollah.utils.xml import etree_namespace_for_element_tag
try: try:
import torch import torch
except ImportError: except ImportError:
torch = None torch = None
from .eynollah import Eynollah
from .model_zoo import EynollahModelZoo
from .utils import is_image_filename from .utils import is_image_filename
from .utils.font import get_font
from .utils.xml import etree_namespace_for_element_tag
from .utils.resize import resize_image from .utils.resize import resize_image
from .utils.utils_ocr import ( from .utils.utils_ocr import (
break_curved_line_into_small_pieces_and_then_merge, break_curved_line_into_small_pieces_and_then_merge,
@ -44,45 +45,44 @@ class EynollahOcrResult:
cropped_lines_region_indexer: List cropped_lines_region_indexer: List
total_bb_coordinates:List total_bb_coordinates:List
class Eynollah_ocr: class Eynollah_ocr(Eynollah):
def __init__( def __init__(
self, self,
*, *,
model_zoo: EynollahModelZoo, model_zoo: EynollahModelZoo,
tr_ocr=False, tr_ocr=False,
batch_size: Optional[int]=None, batch_size: int=0,
do_not_mask_with_textline_contour: bool=False, do_not_mask_with_textline_contour: bool=False,
min_conf_value_of_textline_text : Optional[float]=None, min_conf_value_of_textline_text : float=0.3,
logger: Optional[Logger]=None, logger: Optional[Logger]=None,
device: str = '',
): ):
self.tr_ocr = tr_ocr self.tr_ocr = tr_ocr
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes # masking for OCR and GT generation, relevant for skewed lines and bounding boxes
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
self.logger = logger if logger else getLogger('eynollah.ocr') self.logger = logger if logger else getLogger('eynollah.ocr')
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text
self.b_s = batch_size or 2 if tr_ocr else 8
self.model_zoo = model_zoo self.model_zoo = model_zoo
self.setup_models(device=device)
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3 def setup_models(self, device=''):
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size if self.tr_ocr:
self.model_zoo.load_models('trocr_processor',
if tr_ocr: ('ocr', 'tr'),
self.model_zoo.load_model('trocr_processor') device=device)
self.model_zoo.load_model('ocr', 'tr')
self.model_zoo.get('ocr').to(self.device)
else: else:
self.model_zoo.load_model('ocr', '') self.model_zoo.load_models('ocr',
self.model_zoo.load_model('num_to_char') 'num_to_char',
self.model_zoo.load_model('characters') 'characters',
self.end_character = len(self.model_zoo.get('characters', list)) + 2 device=device)
self.end_character = len(self.model_zoo.get('characters')) + 2
@property @property
def device(self): def device(self):
assert torch return self.model_zoo.get('ocr').device
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
return torch.device("cuda:0")
else:
self.logger.info("Using CPU processing")
return torch.device("cpu")
def run_trocr( def run_trocr(
self, self,