mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-11-09 22:24:13 +01:00
wip
This commit is contained in:
parent
ec1fd93dad
commit
294b6356d3
6 changed files with 41 additions and 9 deletions
1
Makefile
1
Makefile
|
|
@ -73,7 +73,6 @@ install-dev:
|
|||
|
||||
deps-test: $(EYNOLLAH_MODELS_ZIP)
|
||||
$(PIP) install -r requirements-test.txt
|
||||
endif
|
||||
|
||||
smoke-test: TMPDIR != mktemp -d
|
||||
smoke-test: tests/resources/kant_aufklaerung_1784_0020.tif
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ import tensorflow as tf
|
|||
tf.get_logger().setLevel("ERROR")
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .model_zoo import (EynollahModelZoo, KerasModel, TrOCRProcessor)
|
||||
from .utils.contour import (
|
||||
filter_contours_area_of_image,
|
||||
filter_contours_area_of_image_tables,
|
||||
|
|
@ -178,6 +178,7 @@ class Eynollah:
|
|||
self.full_layout = full_layout
|
||||
self.tables = tables
|
||||
self.right2left = right2left
|
||||
# --input-binary sensible if image is very dark, if layout is not working.
|
||||
self.input_binary = input_binary
|
||||
self.allow_scaling = allow_scaling
|
||||
self.headers_off = headers_off
|
||||
|
|
@ -3651,7 +3652,15 @@ class Eynollah:
|
|||
pass
|
||||
|
||||
def return_ocr_of_textline_without_common_section(
|
||||
self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot):
|
||||
self,
|
||||
textline_image,
|
||||
model_ocr: KerasModel,
|
||||
processor: TrOCRProcessor,
|
||||
device,
|
||||
width_textline,
|
||||
h2w_ratio,
|
||||
ind_tot,
|
||||
):
|
||||
|
||||
if h2w_ratio > 0.05:
|
||||
pixel_values = processor(textline_image, return_tensors="pt").pixel_values
|
||||
|
|
|
|||
|
|
@ -63,8 +63,11 @@ class Eynollah_ocr:
|
|||
logger: Optional[Logger]=None,
|
||||
):
|
||||
self.tr_ocr = tr_ocr
|
||||
# For generating textline-image pairs for traning, move to generate_gt_for_training
|
||||
self.export_textline_images_and_text = export_textline_images_and_text
|
||||
# masking for OCR and GT generation, relevant for skewed lines and bounding boxes
|
||||
self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
|
||||
# prefix or dataset
|
||||
self.pref_of_dataset = pref_of_dataset
|
||||
self.logger = logger if logger else getLogger('eynollah')
|
||||
self.model_zoo = EynollahModelZoo(basedir=dir_models)
|
||||
|
|
@ -103,6 +106,7 @@ class Eynollah_ocr:
|
|||
|
||||
def run(self, overwrite: bool = False,
|
||||
dir_in: Optional[str] = None,
|
||||
# Prediction with RGB and binarized images for selected pages, should not be the default
|
||||
dir_in_bin: Optional[str] = None,
|
||||
image_filename: Optional[str] = None,
|
||||
dir_xmls: Optional[str] = None,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
__all__ = [
|
||||
'EynollahModelZoo',
|
||||
'KerasModel',
|
||||
'TrOCRProcessor',
|
||||
'VisionEncoderDecoderModel',
|
||||
]
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .model_zoo import EynollahModelZoo, KerasModel, TrOCRProcessor, VisionEncoderDecoderModel
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from keras.models import Model as KerasModel
|
|||
from keras.models import load_model
|
||||
from tabulate import tabulate
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
|
||||
from ..patch_encoder import PatchEncoder, Patches
|
||||
from .specs import EynollahModelSpecSet
|
||||
from .default_specs import DEFAULT_MODEL_SPECS
|
||||
|
|
@ -100,7 +99,7 @@ class EynollahModelZoo:
|
|||
elif model_category == 'characters':
|
||||
model = self._load_characters()
|
||||
elif model_category == 'trocr_processor':
|
||||
return TrOCRProcessor.from_pretrained(self.model_path(...))
|
||||
model = TrOCRProcessor.from_pretrained(model_path)
|
||||
else:
|
||||
try:
|
||||
model = load_model(model_path, compile=False)
|
||||
|
|
@ -184,6 +183,5 @@ class EynollahModelZoo:
|
|||
Ensure that a loaded models is not referenced by ``self._loaded`` anymore
|
||||
"""
|
||||
if hasattr(self, '_loaded') and getattr(self, '_loaded'):
|
||||
for needle in self._loaded:
|
||||
if self._loaded[needle]:
|
||||
del self._loaded[needle]
|
||||
for needle in self._loaded.keys():
|
||||
del self._loaded[needle]
|
||||
|
|
|
|||
19
tests/test_model_zoo.py
Normal file
19
tests/test_model_zoo.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from pathlib import Path
|
||||
|
||||
from eynollah.model_zoo import EynollahModelZoo, TrOCRProcessor, VisionEncoderDecoderModel
|
||||
|
||||
testdir = Path(__file__).parent.resolve()
|
||||
MODELS_DIR = testdir.parent
|
||||
|
||||
def test_trocr1():
|
||||
model_zoo = EynollahModelZoo(str(MODELS_DIR))
|
||||
model_zoo.load_model('trocr_processor')
|
||||
proc = model_zoo.get('trocr_processor', TrOCRProcessor)
|
||||
assert isinstance(proc, TrOCRProcessor)
|
||||
|
||||
model_zoo.load_model('ocr', 'tr')
|
||||
model = model_zoo.get('ocr')
|
||||
assert isinstance(model, VisionEncoderDecoderModel)
|
||||
print(proc)
|
||||
|
||||
test_trocr1()
|
||||
Loading…
Add table
Add a link
Reference in a new issue