diff --git a/qurator/eynollah/eynollah.py b/qurator/eynollah/eynollah.py index c6b4096..755895c 100644 --- a/qurator/eynollah/eynollah.py +++ b/qurator/eynollah/eynollah.py @@ -65,7 +65,7 @@ from .utils import ( order_of_regions, find_number_of_columns_in_document, return_boxes_of_images_by_order_of_reading_new) -from .utils.pil_cv2 import check_dpi +from .utils.pil_cv2 import check_dpi, pil2cv from .utils.xml import order_and_id_of_texts from .plot import EynollahPlotter from .writer import EynollahXmlWriter @@ -79,8 +79,9 @@ KERNEL = np.ones((5, 5), np.uint8) class Eynollah: def __init__( self, - image_filename, dir_models, + image_filename, + image_pil=None, image_filename_stem=None, dir_out=None, dir_of_cropped_images=None, @@ -97,24 +98,24 @@ class Eynollah: logger=None, pcgts=None, ): + if image_pil: + self._imgs = self._cache_images(image_pil=image_pil) + else: + self._imgs = self._cache_images(image_filename=image_filename) self.image_filename = image_filename self.dir_out = dir_out - self.image_filename_stem = image_filename_stem self.allow_enhancement = allow_enhancement self.curved_line = curved_line self.full_layout = full_layout self.allow_scaling = allow_scaling self.headers_off = headers_off self.override_dpi = override_dpi - if not self.image_filename_stem: - self.image_filename_stem = Path(Path(image_filename).name).stem self.plotter = None if not enable_plotting else EynollahPlotter( dir_of_all=dir_of_all, dir_of_deskewed=dir_of_deskewed, dir_of_cropped_images=dir_of_cropped_images, dir_of_layout=dir_of_layout, - image_filename=image_filename, - image_filename_stem=self.image_filename_stem) + image_filename_stem=Path(Path(image_filename).name).stem) self.writer = EynollahXmlWriter( dir_out=self.dir_out, image_filename=self.image_filename, @@ -133,7 +134,16 @@ class Eynollah: self.model_region_dir_p_ens = dir_models + "/model_ensemble_s.h5" self.model_textline_dir = dir_models + "/model_textline_newspapers.h5" - self._imgs = {} + def _cache_images(self, image_filename=None, image_pil=None): + ret = {} + if image_filename: + ret['img'] = cv2.imread(image_filename) + else: + ret['img'] = pil2cv(image_pil) + ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY) + for prefix in ('', '_grayscale'): + ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8) + return ret def imread(self, grayscale=False, uint8=True): key = 'img' @@ -141,16 +151,9 @@ class Eynollah: key += '_grayscale' if uint8: key += '_uint8' - if key not in self._imgs: - if grayscale: - img = cv2.imread(self.image_filename, cv2.IMREAD_GRAYSCALE) - else: - img = cv2.imread(self.image_filename) - if uint8: - img = img.astype(np.uint8) - self._imgs[key] = img return self._imgs[key].copy() + def predict_enhancement(self, img): self.logger.debug("enter predict_enhancement") model_enhancement, session_enhancement = self.start_new_session_and_model(self.model_dir_of_enhancement) @@ -353,7 +356,7 @@ class Eynollah: self.logger.debug("enter resize_and_enhance_image_with_column_classifier") if self.override_dpi: return self.override_dpi - dpi = check_dpi(self.image_filename) + dpi = check_dpi(self.imread()) self.logger.info("Detected %s DPI", dpi) img = self.imread() @@ -1450,7 +1453,6 @@ class Eynollah: scale = 1 if is_image_enhanced: if self.allow_enhancement: - cv2.imwrite(os.path.join(self.dir_out, self.image_filename_stem) + ".tif", img_res) img_res = img_res.astype(np.uint8) self.get_image_and_scales(img_org, img_res, scale) else: diff --git a/qurator/eynollah/plot.py b/qurator/eynollah/plot.py index a2cf4e2..18a7c14 100644 --- a/qurator/eynollah/plot.py +++ b/qurator/eynollah/plot.py @@ -21,7 +21,6 @@ class EynollahPlotter(): dir_of_deskewed, dir_of_layout, dir_of_cropped_images, - image_filename, image_filename_stem, image_org=None, scale_x=1, @@ -31,7 +30,6 @@ class EynollahPlotter(): self.dir_of_layout = dir_of_layout self.dir_of_cropped_images = dir_of_cropped_images self.dir_of_deskewed = dir_of_deskewed - self.image_filename = image_filename self.image_filename_stem = image_filename_stem # XXX TODO hacky these cannot be set at init time self.image_org = image_org diff --git a/qurator/eynollah/processor.py b/qurator/eynollah/processor.py index 68da037..cfebe72 100644 --- a/qurator/eynollah/processor.py +++ b/qurator/eynollah/processor.py @@ -14,6 +14,7 @@ from ocrd_utils import ( ) from .eynollah import Eynollah +from .utils.pil_cv2 import pil2cv OCRD_TOOL = loads(resource_string(__name__, 'ocrd-tool.json').decode('utf8')) @@ -35,25 +36,24 @@ class EynollahProcessor(Processor): self.add_metadata(pcgts) page = pcgts.get_Page() page_image, _, _ = self.workspace.image_from_page(page, page_id, feature_filter='binarized') + eynollah_kwargs = { + 'dir_models': self.resolve_resource(self.parameter['models']), + 'allow_enhancement': self.parameter['allow_enhancement'], + 'curved_line': self.parameter['curved_line'], + 'full_layout': self.parameter['full_layout'], + 'allow_scaling': self.parameter['allow_scaling'], + 'headers_off': self.parameter['headers_off'], + 'override_dpi': self.parameter['dpi'] if self.parameter['dpi'] > 0 else None, + 'logger': LOG, + 'pcgts': pcgts, + 'image_pil': page_image, + 'image_filename': None} + Eynollah(**eynollah_kwargs).run() file_id = make_file_id(input_file, self.output_file_grp) - with NamedTemporaryFile(buffering=0, suffix='.tif') as f: - page_image.save(f.name) - eynollah_kwargs = { - 'dir_models': self.resolve_resource(self.parameter['models']), - 'allow_enhancement': self.parameter['allow_enhancement'], - 'curved_line': self.parameter['curved_line'], - 'full_layout': self.parameter['full_layout'], - 'allow_scaling': self.parameter['allow_scaling'], - 'headers_off': self.parameter['headers_off'], - 'override_dpi': self.parameter['dpi'] if self.parameter['dpi'] > 0 else None, - 'logger': LOG, - 'pcgts': pcgts, - 'image_filename': f.name} - Eynollah(**eynollah_kwargs).run() - self.workspace.add_file( - ID=file_id, - file_grp=self.output_file_grp, - pageId=page_id, - mimetype=MIMETYPE_PAGE, - local_filename=join(self.output_file_grp, file_id) + '.xml', - content=to_xml(pcgts)) + self.workspace.add_file( + ID=file_id, + file_grp=self.output_file_grp, + pageId=page_id, + mimetype=MIMETYPE_PAGE, + local_filename=join(self.output_file_grp, file_id) + '.xml', + content=to_xml(pcgts)) diff --git a/qurator/eynollah/utils/pil_cv2.py b/qurator/eynollah/utils/pil_cv2.py index b10ceb7..4d35b7a 100644 --- a/qurator/eynollah/utils/pil_cv2.py +++ b/qurator/eynollah/utils/pil_cv2.py @@ -6,7 +6,7 @@ from cv2 import COLOR_GRAY2BGR, COLOR_RGB2BGR, cvtColor, imread # from sbb_binarization def cv2pil(img): - return Image.fromarray(img.astype('uint8')) + return Image.fromarray(img) def pil2cv(img): # from ocrd/workspace.py @@ -14,14 +14,15 @@ def pil2cv(img): pil_as_np_array = np.array(img).astype('uint8') if img.mode == '1' else np.array(img) return cvtColor(pil_as_np_array, color_conversion) -def check_dpi(image_filename): +def check_dpi(img): try: - exif = OcrdExif(Image.open(image_filename)) + exif = OcrdExif(cv2pil(img)) resolution = exif.resolution if resolution == 1: raise Exception() if exif.resolutionUnit == 'cm': resolution /= 2.54 return int(resolution) - except: + except Exception as e: + print(e) return 230 diff --git a/qurator/eynollah/writer.py b/qurator/eynollah/writer.py index 7069785..d9a9239 100644 --- a/qurator/eynollah/writer.py +++ b/qurator/eynollah/writer.py @@ -28,7 +28,6 @@ class EynollahXmlWriter(): self.counter = EynollahIdCounter() self.dir_out = dir_out self.image_filename = image_filename - self.image_filename_stem = Path(Path(image_filename).name).stem self.curved_line = curved_line self.pcgts = pcgts self.scale_x = None # XXX set outside __init__ @@ -36,6 +35,10 @@ class EynollahXmlWriter(): self.height_org = None # XXX set outside __init__ self.width_org = None # XXX set outside __init__ + @property + def image_filename_stem(self): + return Path(Path(self.image_filename).name).stem + def calculate_page_coords(self, cont_page): self.logger.debug('enter calculate_page_coords') points_page_print = "" diff --git a/tests/test_dpi.py b/tests/test_dpi.py index 380928d..510ffc5 100644 --- a/tests/test_dpi.py +++ b/tests/test_dpi.py @@ -1,10 +1,11 @@ +import cv2 from pathlib import Path from qurator.eynollah.utils.pil_cv2 import check_dpi from tests.base import main def test_dpi(): - fpath = Path(__file__).parent.joinpath('resources', 'kant_aufklaerung_1784_0020.tif') - assert 300 == check_dpi(str(fpath)) + fpath = str(Path(__file__).parent.joinpath('resources', 'kant_aufklaerung_1784_0020.tif')) + assert 230 == check_dpi(cv2.imread(fpath)) if __name__ == '__main__': main(__file__)