mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
eynollah_ocr: make work again, re-use Eynollah base class…
- re-use Eynollah base class - use `ModelZoo.load_models()` instead of `load_model()` - pass in `device` init kwarg, delegate to `ModelZoo.load_models()` - `device`: return Torch device at loaded model tensors instead of ad-hoc selection - make numeric init kwargs non-optional (only numeric)
This commit is contained in:
parent
ded668a256
commit
cd62f13872
2 changed files with 32 additions and 26 deletions
|
|
@ -66,6 +66,10 @@ import click
|
|||
"--min_conf_value_of_textline_text",
|
||||
"-min_conf",
|
||||
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
|
||||
@click.option(
|
||||
"--device",
|
||||
"-D",
|
||||
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
||||
)
|
||||
@click.pass_context
|
||||
def ocr_cli(
|
||||
|
|
@ -81,18 +85,20 @@ def ocr_cli(
|
|||
do_not_mask_with_textline_contour,
|
||||
batch_size,
|
||||
min_conf_value_of_textline_text,
|
||||
device,
|
||||
):
|
||||
"""
|
||||
Recognize text with a CNN/RNN or transformer ML model.
|
||||
"""
|
||||
assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
|
||||
assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
|
||||
from ..eynollah_ocr import Eynollah_ocr
|
||||
eynollah_ocr = Eynollah_ocr(
|
||||
model_zoo=ctx.obj.model_zoo,
|
||||
tr_ocr=tr_ocr,
|
||||
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
|
||||
batch_size=batch_size,
|
||||
min_conf_value_of_textline_text=min_conf_value_of_textline_text)
|
||||
min_conf_value_of_textline_text=min_conf_value_of_textline_text,
|
||||
device=device)
|
||||
eynollah_ocr.run(overwrite=overwrite,
|
||||
dir_in=dir_in,
|
||||
dir_in_bin=dir_in_bin,
|
||||
|
|
|
|||
|
|
@ -14,16 +14,17 @@ from cv2.typing import MatLike
|
|||
from xml.etree import ElementTree as ET
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
from eynollah.model_zoo import EynollahModelZoo
|
||||
from eynollah.utils.font import get_font
|
||||
from eynollah.utils.xml import etree_namespace_for_element_tag
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
torch = None
|
||||
|
||||
|
||||
from .eynollah import Eynollah
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils import is_image_filename
|
||||
from .utils.font import get_font
|
||||
from .utils.xml import etree_namespace_for_element_tag
|
||||
from .utils.resize import resize_image
|
||||
from .utils.utils_ocr import (
|
||||
break_curved_line_into_small_pieces_and_then_merge,
|
||||
|
|
@ -44,45 +45,44 @@ class EynollahOcrResult:
|
|||
cropped_lines_region_indexer: List
|
||||
total_bb_coordinates:List
|
||||
|
||||
class Eynollah_ocr:
|
||||
class Eynollah_ocr(Eynollah):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_zoo: EynollahModelZoo,
|
||||
tr_ocr=False,
|
||||
batch_size: Optional[int]=None,
|
||||
batch_size: int=0,
|
||||
do_not_mask_with_textline_contour: bool=False,
|
||||
min_conf_value_of_textline_text : Optional[float]=None,
|
||||
min_conf_value_of_textline_text : float=0.3,
|
||||
logger: Optional[Logger]=None,
|
||||
device: str = '',
|
||||
):
|
||||
self.tr_ocr = tr_ocr
|
||||
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
|
||||
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
|
||||
self.logger = logger if logger else getLogger('eynollah.ocr')
|
||||
self.model_zoo = model_zoo
|
||||
|
||||
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3
|
||||
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
|
||||
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text
|
||||
self.b_s = batch_size or 2 if tr_ocr else 8
|
||||
|
||||
if tr_ocr:
|
||||
self.model_zoo.load_model('trocr_processor')
|
||||
self.model_zoo.load_model('ocr', 'tr')
|
||||
self.model_zoo.get('ocr').to(self.device)
|
||||
self.model_zoo = model_zoo
|
||||
self.setup_models(device=device)
|
||||
|
||||
def setup_models(self, device=''):
|
||||
if self.tr_ocr:
|
||||
self.model_zoo.load_models('trocr_processor',
|
||||
('ocr', 'tr'),
|
||||
device=device)
|
||||
else:
|
||||
self.model_zoo.load_model('ocr', '')
|
||||
self.model_zoo.load_model('num_to_char')
|
||||
self.model_zoo.load_model('characters')
|
||||
self.end_character = len(self.model_zoo.get('characters', list)) + 2
|
||||
self.model_zoo.load_models('ocr',
|
||||
'num_to_char',
|
||||
'characters',
|
||||
device=device)
|
||||
self.end_character = len(self.model_zoo.get('characters')) + 2
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
assert torch
|
||||
if torch.cuda.is_available():
|
||||
self.logger.info("Using GPU acceleration")
|
||||
return torch.device("cuda:0")
|
||||
else:
|
||||
self.logger.info("Using CPU processing")
|
||||
return torch.device("cpu")
|
||||
return self.model_zoo.get('ocr').device
|
||||
|
||||
def run_trocr(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue