mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-26 23:34:13 +01:00
typing and asserts
This commit is contained in:
parent
44b75eb36f
commit
c6b863b13f
4 changed files with 30 additions and 23 deletions
|
|
@ -9,11 +9,10 @@ document layout analysis (segmentation) with output in PAGE-XML
|
|||
"""
|
||||
|
||||
from difflib import SequenceMatcher as sq
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Union,List, Optional, Tuple
|
||||
from typing import Dict, Type, Union,List, Optional, Tuple
|
||||
import warnings
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
|
@ -32,7 +31,7 @@ from ocrd_utils import getLogger, tf_disable_interactive_logs
|
|||
import statistics
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch # type: ignore
|
||||
except ImportError:
|
||||
torch = None
|
||||
try:
|
||||
|
|
@ -43,13 +42,8 @@ except ImportError:
|
|||
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||
tf_disable_interactive_logs()
|
||||
import tensorflow as tf
|
||||
from keras.models import load_model
|
||||
tf.get_logger().setLevel("ERROR")
|
||||
warnings.filterwarnings("ignore")
|
||||
# use tf1 compatibility for keras backend
|
||||
from tensorflow.compat.v1.keras.backend import set_session
|
||||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils.contour import (
|
||||
|
|
@ -280,6 +274,7 @@ class Eynollah:
|
|||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'executor') and getattr(self, 'executor'):
|
||||
assert self.executor
|
||||
self.executor.shutdown()
|
||||
self.executor = None
|
||||
self.model_zoo.shutdown()
|
||||
|
|
@ -287,6 +282,7 @@ class Eynollah:
|
|||
@property
|
||||
def device(self):
|
||||
# TODO why here and why only for tr?
|
||||
assert torch
|
||||
if torch.cuda.is_available():
|
||||
self.logger.info("Using GPU acceleration")
|
||||
return torch.device("cuda:0")
|
||||
|
|
@ -689,8 +685,8 @@ class Eynollah:
|
|||
|
||||
self.img_hight_int = int(self.image.shape[0] * scale)
|
||||
self.img_width_int = int(self.image.shape[1] * scale)
|
||||
self.scale_y = self.img_hight_int / float(self.image.shape[0])
|
||||
self.scale_x = self.img_width_int / float(self.image.shape[1])
|
||||
self.scale_y: float = self.img_hight_int / float(self.image.shape[0])
|
||||
self.scale_x: float = self.img_width_int / float(self.image.shape[1])
|
||||
|
||||
self.image = resize_image(self.image, self.img_hight_int, self.img_width_int)
|
||||
|
||||
|
|
@ -1755,6 +1751,7 @@ class Eynollah:
|
|||
return [], [], []
|
||||
self.logger.debug("enter get_slopes_and_deskew_new_light")
|
||||
with share_ndarray(textline_mask_tot) as textline_mask_tot_shared:
|
||||
assert self.executor
|
||||
results = self.executor.map(partial(do_work_of_slopes_new_light,
|
||||
textline_mask_tot_ea=textline_mask_tot_shared,
|
||||
slope_deskew=slope_deskew,
|
||||
|
|
@ -1771,6 +1768,7 @@ class Eynollah:
|
|||
return [], [], []
|
||||
self.logger.debug("enter get_slopes_and_deskew_new")
|
||||
with share_ndarray(textline_mask_tot) as textline_mask_tot_shared:
|
||||
assert self.executor
|
||||
results = self.executor.map(partial(do_work_of_slopes_new,
|
||||
textline_mask_tot_ea=textline_mask_tot_shared,
|
||||
slope_deskew=slope_deskew,
|
||||
|
|
@ -1791,6 +1789,7 @@ class Eynollah:
|
|||
self.logger.debug("enter get_slopes_and_deskew_new_curved")
|
||||
with share_ndarray(textline_mask_tot) as textline_mask_tot_shared:
|
||||
with share_ndarray(mask_texts_only) as mask_texts_only_shared:
|
||||
assert self.executor
|
||||
results = self.executor.map(partial(do_work_of_slopes_new_curved,
|
||||
textline_mask_tot_ea=textline_mask_tot_shared,
|
||||
mask_texts_only=mask_texts_only_shared,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Literal, Optional, Tuple, List, Union
|
||||
from typing import Dict, Literal, Optional, Tuple, List, Type, TypeVar, Union
|
||||
from copy import deepcopy
|
||||
|
||||
from keras.layers import StringLookup
|
||||
|
|
@ -12,7 +12,7 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|||
from eynollah.patch_encoder import PatchEncoder, Patches
|
||||
|
||||
SomeEynollahModel = Union[VisionEncoderDecoderModel, TrOCRProcessor, Model, List]
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# Dict mapping model_category to dict mapping variant (default is '') to Path
|
||||
DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
|
||||
|
|
@ -149,7 +149,10 @@ class EynollahModelZoo():
|
|||
self.override_models(*model_overrides)
|
||||
self._loaded: Dict[str, SomeEynollahModel] = {}
|
||||
|
||||
def override_models(self, *model_overrides: Tuple[str, str, str]):
|
||||
def override_models(
|
||||
self,
|
||||
*model_overrides: Tuple[str, str, str],
|
||||
):
|
||||
"""
|
||||
Override the default model versions
|
||||
"""
|
||||
|
|
@ -235,10 +238,13 @@ class EynollahModelZoo():
|
|||
self._loaded[model_category] = model
|
||||
return model # type: ignore
|
||||
|
||||
def get(self, model_category) -> SomeEynollahModel:
|
||||
def get(self, model_category: str, model_type: Optional[Type[T]]=None) -> T:
|
||||
if model_category not in self._loaded:
|
||||
raise ValueError(f'Model "{model_category} not previously loaded with "load_model(..)"')
|
||||
return self._loaded[model_category]
|
||||
ret = self._loaded[model_category]
|
||||
if model_type:
|
||||
assert isinstance(ret, model_type)
|
||||
return ret # type: ignore # FIXME: convince typing that we're returning generic type
|
||||
|
||||
def _load_ocr_model(self, variant: str) -> SomeEynollahModel:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -40,8 +40,8 @@ class EynollahPlotter:
|
|||
self.image_filename_stem = image_filename_stem
|
||||
# XXX TODO hacky these cannot be set at init time
|
||||
self.image_org = image_org
|
||||
self.scale_x = scale_x
|
||||
self.scale_y = scale_y
|
||||
self.scale_x : float = scale_x
|
||||
self.scale_y : float = scale_y
|
||||
|
||||
def save_plot_of_layout_main(self, text_regions_p, image_page):
|
||||
if self.dir_of_layout is not None:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# pylint: disable=import-error
|
||||
from pathlib import Path
|
||||
import os.path
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Optional
|
||||
from .utils.xml import create_page_xml, xml_reading_order
|
||||
from .utils.counter import EynollahIdCounter
|
||||
|
||||
|
|
@ -10,7 +10,6 @@ from ocrd_utils import getLogger
|
|||
from ocrd_models.ocrd_page import (
|
||||
BorderType,
|
||||
CoordsType,
|
||||
PcGtsType,
|
||||
TextLineType,
|
||||
TextEquivType,
|
||||
TextRegionType,
|
||||
|
|
@ -32,10 +31,10 @@ class EynollahXmlWriter:
|
|||
self.curved_line = curved_line
|
||||
self.textline_light = textline_light
|
||||
self.pcgts = pcgts
|
||||
self.scale_x = None # XXX set outside __init__
|
||||
self.scale_y = None # XXX set outside __init__
|
||||
self.height_org = None # XXX set outside __init__
|
||||
self.width_org = None # XXX set outside __init__
|
||||
self.scale_x: Optional[float] = None # XXX set outside __init__
|
||||
self.scale_y: Optional[float] = None # XXX set outside __init__
|
||||
self.height_org: Optional[int] = None # XXX set outside __init__
|
||||
self.width_org: Optional[int] = None # XXX set outside __init__
|
||||
|
||||
@property
|
||||
def image_filename_stem(self):
|
||||
|
|
@ -135,6 +134,7 @@ class EynollahXmlWriter:
|
|||
# create the file structure
|
||||
pcgts = self.pcgts if self.pcgts else create_page_xml(self.image_filename, self.height_org, self.width_org)
|
||||
page = pcgts.get_Page()
|
||||
assert page
|
||||
page.set_Border(BorderType(Coords=CoordsType(points=self.calculate_page_coords(cont_page))))
|
||||
|
||||
counter = EynollahIdCounter()
|
||||
|
|
@ -152,6 +152,7 @@ class EynollahXmlWriter:
|
|||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord,
|
||||
skip_layout_reading_order))
|
||||
)
|
||||
assert textregion.Coords
|
||||
if conf_contours_textregions:
|
||||
textregion.Coords.set_conf(conf_contours_textregions[mm])
|
||||
page.add_TextRegion(textregion)
|
||||
|
|
@ -168,6 +169,7 @@ class EynollahXmlWriter:
|
|||
id=counter.next_region_id, type_='heading',
|
||||
Coords=CoordsType(points=self.calculate_polygon_coords(region_contour, page_coord))
|
||||
)
|
||||
assert textregion.Coords
|
||||
if conf_contours_textregions_h:
|
||||
textregion.Coords.set_conf(conf_contours_textregions_h[mm])
|
||||
page.add_TextRegion(textregion)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue