eynollah_ocr: make work again, re-use Eynollah base class…

- re-use Eynollah base class
- use `ModelZoo.load_models()` instead of `load_model()`
- pass in `device` init kwarg, delegate to `ModelZoo.load_models()`
- `device`: return Torch device at loaded model tensors
  instead of ad-hoc selection
- make numeric init kwargs non-optional (only numeric)
This commit is contained in:
Robert Sachunsky 2026-05-12 18:31:18 +02:00
parent ded668a256
commit cd62f13872
2 changed files with 32 additions and 26 deletions

View file

@ -66,6 +66,10 @@ import click
"--min_conf_value_of_textline_text",
"-min_conf",
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
@click.option(
"--device",
"-D",
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
)
@click.pass_context
def ocr_cli(
@ -81,18 +85,20 @@ def ocr_cli(
do_not_mask_with_textline_contour,
batch_size,
min_conf_value_of_textline_text,
device,
):
"""
Recognize text with a CNN/RNN or transformer ML model.
"""
assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
from ..eynollah_ocr import Eynollah_ocr
eynollah_ocr = Eynollah_ocr(
model_zoo=ctx.obj.model_zoo,
tr_ocr=tr_ocr,
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
batch_size=batch_size,
min_conf_value_of_textline_text=min_conf_value_of_textline_text)
min_conf_value_of_textline_text=min_conf_value_of_textline_text,
device=device)
eynollah_ocr.run(overwrite=overwrite,
dir_in=dir_in,
dir_in_bin=dir_in_bin,

View file

@ -14,16 +14,17 @@ from cv2.typing import MatLike
from xml.etree import ElementTree as ET
from PIL import Image, ImageDraw
import numpy as np
from eynollah.model_zoo import EynollahModelZoo
from eynollah.utils.font import get_font
from eynollah.utils.xml import etree_namespace_for_element_tag
try:
import torch
except ImportError:
torch = None
from .eynollah import Eynollah
from .model_zoo import EynollahModelZoo
from .utils import is_image_filename
from .utils.font import get_font
from .utils.xml import etree_namespace_for_element_tag
from .utils.resize import resize_image
from .utils.utils_ocr import (
break_curved_line_into_small_pieces_and_then_merge,
@ -44,45 +45,44 @@ class EynollahOcrResult:
cropped_lines_region_indexer: List
total_bb_coordinates:List
class Eynollah_ocr:
class Eynollah_ocr(Eynollah):
def __init__(
self,
*,
model_zoo: EynollahModelZoo,
tr_ocr=False,
batch_size: Optional[int]=None,
batch_size: int=0,
do_not_mask_with_textline_contour: bool=False,
min_conf_value_of_textline_text : Optional[float]=None,
min_conf_value_of_textline_text : float=0.3,
logger: Optional[Logger]=None,
device: str = '',
):
self.tr_ocr = tr_ocr
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
self.logger = logger if logger else getLogger('eynollah.ocr')
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text
self.b_s = batch_size or 2 if tr_ocr else 8
self.model_zoo = model_zoo
self.setup_models(device=device)
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
if tr_ocr:
self.model_zoo.load_model('trocr_processor')
self.model_zoo.load_model('ocr', 'tr')
self.model_zoo.get('ocr').to(self.device)
def setup_models(self, device=''):
if self.tr_ocr:
self.model_zoo.load_models('trocr_processor',
('ocr', 'tr'),
device=device)
else:
self.model_zoo.load_model('ocr', '')
self.model_zoo.load_model('num_to_char')
self.model_zoo.load_model('characters')
self.end_character = len(self.model_zoo.get('characters', list)) + 2
self.model_zoo.load_models('ocr',
'num_to_char',
'characters',
device=device)
self.end_character = len(self.model_zoo.get('characters')) + 2
@property
def device(self):
assert torch
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
return torch.device("cuda:0")
else:
self.logger.info("Using CPU processing")
return torch.device("cpu")
return self.model_zoo.get('ocr').device
def run_trocr(
self,