mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
eynollah_ocr: make work again, re-use Eynollah base class…
- re-use Eynollah base class - use `ModelZoo.load_models()` instead of `load_model()` - pass in `device` init kwarg, delegate to `ModelZoo.load_models()` - `device`: return Torch device at loaded model tensors instead of ad-hoc selection - make numeric init kwargs non-optional (only numeric)
This commit is contained in:
parent
ded668a256
commit
cd62f13872
2 changed files with 32 additions and 26 deletions
|
|
@ -66,6 +66,10 @@ import click
|
||||||
"--min_conf_value_of_textline_text",
|
"--min_conf_value_of_textline_text",
|
||||||
"-min_conf",
|
"-min_conf",
|
||||||
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
|
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
|
||||||
|
@click.option(
|
||||||
|
"--device",
|
||||||
|
"-D",
|
||||||
|
help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')",
|
||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def ocr_cli(
|
def ocr_cli(
|
||||||
|
|
@ -81,18 +85,20 @@ def ocr_cli(
|
||||||
do_not_mask_with_textline_contour,
|
do_not_mask_with_textline_contour,
|
||||||
batch_size,
|
batch_size,
|
||||||
min_conf_value_of_textline_text,
|
min_conf_value_of_textline_text,
|
||||||
|
device,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Recognize text with a CNN/RNN or transformer ML model.
|
Recognize text with a CNN/RNN or transformer ML model.
|
||||||
"""
|
"""
|
||||||
assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
|
assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both."
|
||||||
from ..eynollah_ocr import Eynollah_ocr
|
from ..eynollah_ocr import Eynollah_ocr
|
||||||
eynollah_ocr = Eynollah_ocr(
|
eynollah_ocr = Eynollah_ocr(
|
||||||
model_zoo=ctx.obj.model_zoo,
|
model_zoo=ctx.obj.model_zoo,
|
||||||
tr_ocr=tr_ocr,
|
tr_ocr=tr_ocr,
|
||||||
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
|
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
min_conf_value_of_textline_text=min_conf_value_of_textline_text)
|
min_conf_value_of_textline_text=min_conf_value_of_textline_text,
|
||||||
|
device=device)
|
||||||
eynollah_ocr.run(overwrite=overwrite,
|
eynollah_ocr.run(overwrite=overwrite,
|
||||||
dir_in=dir_in,
|
dir_in=dir_in,
|
||||||
dir_in_bin=dir_in_bin,
|
dir_in_bin=dir_in_bin,
|
||||||
|
|
|
||||||
|
|
@ -14,16 +14,17 @@ from cv2.typing import MatLike
|
||||||
from xml.etree import ElementTree as ET
|
from xml.etree import ElementTree as ET
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from eynollah.model_zoo import EynollahModelZoo
|
|
||||||
from eynollah.utils.font import get_font
|
|
||||||
from eynollah.utils.xml import etree_namespace_for_element_tag
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
torch = None
|
torch = None
|
||||||
|
|
||||||
|
|
||||||
|
from .eynollah import Eynollah
|
||||||
|
from .model_zoo import EynollahModelZoo
|
||||||
from .utils import is_image_filename
|
from .utils import is_image_filename
|
||||||
|
from .utils.font import get_font
|
||||||
|
from .utils.xml import etree_namespace_for_element_tag
|
||||||
from .utils.resize import resize_image
|
from .utils.resize import resize_image
|
||||||
from .utils.utils_ocr import (
|
from .utils.utils_ocr import (
|
||||||
break_curved_line_into_small_pieces_and_then_merge,
|
break_curved_line_into_small_pieces_and_then_merge,
|
||||||
|
|
@ -44,45 +45,44 @@ class EynollahOcrResult:
|
||||||
cropped_lines_region_indexer: List
|
cropped_lines_region_indexer: List
|
||||||
total_bb_coordinates:List
|
total_bb_coordinates:List
|
||||||
|
|
||||||
class Eynollah_ocr:
|
class Eynollah_ocr(Eynollah):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
model_zoo: EynollahModelZoo,
|
model_zoo: EynollahModelZoo,
|
||||||
tr_ocr=False,
|
tr_ocr=False,
|
||||||
batch_size: Optional[int]=None,
|
batch_size: int=0,
|
||||||
do_not_mask_with_textline_contour: bool=False,
|
do_not_mask_with_textline_contour: bool=False,
|
||||||
min_conf_value_of_textline_text : Optional[float]=None,
|
min_conf_value_of_textline_text : float=0.3,
|
||||||
logger: Optional[Logger]=None,
|
logger: Optional[Logger]=None,
|
||||||
|
device: str = '',
|
||||||
):
|
):
|
||||||
self.tr_ocr = tr_ocr
|
self.tr_ocr = tr_ocr
|
||||||
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
|
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
|
||||||
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
|
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
|
||||||
self.logger = logger if logger else getLogger('eynollah.ocr')
|
self.logger = logger if logger else getLogger('eynollah.ocr')
|
||||||
|
|
||||||
|
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text
|
||||||
|
self.b_s = batch_size or 2 if tr_ocr else 8
|
||||||
|
|
||||||
self.model_zoo = model_zoo
|
self.model_zoo = model_zoo
|
||||||
|
self.setup_models(device=device)
|
||||||
|
|
||||||
self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3
|
def setup_models(self, device=''):
|
||||||
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
|
if self.tr_ocr:
|
||||||
|
self.model_zoo.load_models('trocr_processor',
|
||||||
if tr_ocr:
|
('ocr', 'tr'),
|
||||||
self.model_zoo.load_model('trocr_processor')
|
device=device)
|
||||||
self.model_zoo.load_model('ocr', 'tr')
|
|
||||||
self.model_zoo.get('ocr').to(self.device)
|
|
||||||
else:
|
else:
|
||||||
self.model_zoo.load_model('ocr', '')
|
self.model_zoo.load_models('ocr',
|
||||||
self.model_zoo.load_model('num_to_char')
|
'num_to_char',
|
||||||
self.model_zoo.load_model('characters')
|
'characters',
|
||||||
self.end_character = len(self.model_zoo.get('characters', list)) + 2
|
device=device)
|
||||||
|
self.end_character = len(self.model_zoo.get('characters')) + 2
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
assert torch
|
return self.model_zoo.get('ocr').device
|
||||||
if torch.cuda.is_available():
|
|
||||||
self.logger.info("Using GPU acceleration")
|
|
||||||
return torch.device("cuda:0")
|
|
||||||
else:
|
|
||||||
self.logger.info("Using CPU processing")
|
|
||||||
return torch.device("cpu")
|
|
||||||
|
|
||||||
def run_trocr(
|
def run_trocr(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue