diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py index 3436250..93bb676 100644 --- a/src/eynollah/cli.py +++ b/src/eynollah/cli.py @@ -47,14 +47,14 @@ def main(): def machine_based_reading_order(input, dir_in, out, model, log_level): assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - orderer = machine_based_reading_order_on_layout(model, dir_out=out) + orderer = machine_based_reading_order_on_layout(model) if log_level: orderer.logger.setLevel(getLevelName(log_level)) - if dir_in: - orderer.run(dir_in=dir_in) - else: - orderer.run(xml_filename=input) + orderer.run(xml_filename=input, + dir_in=dir_in, + dir_out=out, + ) @main.command() @@ -156,17 +156,17 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low initLogging() enhancer = Enhancer( model, - dir_out=out, num_col_upper=num_col_upper, num_col_lower=num_col_lower, save_org_scale=save_org_scale, ) if log_level: enhancer.logger.setLevel(getLevelName(log_level)) - if dir_in: - enhancer.run(dir_in=dir_in, overwrite=overwrite) - else: - enhancer.run(image_filename=image, overwrite=overwrite) + enhancer.run(overwrite=overwrite, + dir_in=dir_in, + image_filename=image, + dir_out=out, + ) @main.command() @click.option( diff --git a/src/eynollah/image_enhancer.py b/src/eynollah/image_enhancer.py index 5a06d59..89dde16 100644 --- a/src/eynollah/image_enhancer.py +++ b/src/eynollah/image_enhancer.py @@ -11,7 +11,6 @@ from functools import partial from pathlib import Path from multiprocessing import cpu_count import gc -from loky import ProcessPoolExecutor import cv2 import numpy as np from ocrd_utils import getLogger, tf_disable_interactive_logs @@ -33,13 +32,11 @@ class Enhancer: def __init__( self, dir_models : str, - dir_out : Optional[str] = None, num_col_upper : Optional[int] = None, num_col_lower : Optional[int] = None, save_org_scale : bool = False, logger : Optional[Logger] = None, ): - self.dir_out = dir_out self.input_binary = False self.light_version = False self.save_org_scale = save_org_scale @@ -53,9 +50,6 @@ class Enhancer: self.num_col_lower = num_col_lower self.logger = logger if logger else getLogger('enhancement') - # for parallelization of CPU-intensive tasks: - self.executor = ProcessPoolExecutor(max_workers=cpu_count(), timeout=1200) - atexit.register(self.executor.shutdown) self.dir_models = dir_models self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425" self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425" @@ -94,9 +88,9 @@ class Enhancer: if dpi is not None: self.dpi = dpi - def reset_file_name_dir(self, image_filename): + def reset_file_name_dir(self, image_filename, dir_out): self.cache_images(image_filename=image_filename) - self.output_filename = os.path.join(self.dir_out, Path(image_filename).stem +'.png') + self.output_filename = os.path.join(dir_out, Path(image_filename).stem +'.png') def imread(self, grayscale=False, uint8=True): key = 'img' @@ -694,7 +688,12 @@ class Enhancer: return img_res - def run(self, image_filename : Optional[str] = None, dir_in : Optional[str] = None, overwrite : bool = False): + def run(self, + overwrite: bool = False, + image_filename: Optional[str] = None, + dir_in: Optional[str] = None, + dir_out: Optional[str] = None, + ): """ Get image and scales, then extract the page of scanned image """ @@ -702,7 +701,9 @@ class Enhancer: t0_tot = time.time() if dir_in: - ls_imgs = list(filter(is_image_filename, os.listdir(dir_in))) + ls_imgs = [os.path.join(dir_in, image_filename) + for image_filename in filter(is_image_filename, + os.listdir(dir_in))] elif image_filename: ls_imgs = [image_filename] else: @@ -712,7 +713,7 @@ class Enhancer: self.logger.info(img_filename) t0 = time.time() - self.reset_file_name_dir(os.path.join(dir_in or "", img_filename)) + self.reset_file_name_dir(img_filename, dir_out) #print("text region early -11 in %.1fs", time.time() - t0) if os.path.exists(self.output_filename): diff --git a/src/eynollah/mb_ro_on_layout.py b/src/eynollah/mb_ro_on_layout.py index 6d72614..45db8e4 100644 --- a/src/eynollah/mb_ro_on_layout.py +++ b/src/eynollah/mb_ro_on_layout.py @@ -10,7 +10,6 @@ import atexit from functools import partial from pathlib import Path from multiprocessing import cpu_count -from loky import ProcessPoolExecutor import xml.etree.ElementTree as ET import cv2 import numpy as np @@ -35,15 +34,9 @@ class machine_based_reading_order_on_layout: def __init__( self, dir_models : str, - dir_out : Optional[str] = None, logger : Optional[Logger] = None, ): - self.dir_out = dir_out - self.logger = logger if logger else getLogger('mbreorder') - # for parallelization of CPU-intensive tasks: - self.executor = ProcessPoolExecutor(max_workers=cpu_count(), timeout=1200) - atexit.register(self.executor.shutdown) self.dir_models = dir_models self.model_reading_order_dir = dir_models + "/model_eynollah_reading_order_20250824" @@ -56,9 +49,6 @@ class machine_based_reading_order_on_layout: self.model_reading_order = self.our_load_model(self.model_reading_order_dir) self.light_version = True - - - @staticmethod def our_load_model(model_file): if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): @@ -70,10 +60,8 @@ class machine_based_reading_order_on_layout: model = load_model(model_file, compile=False, custom_objects={ "PatchEncoder": PatchEncoder, "Patches": Patches}) return model - - + def read_xml(self, xml_file): - file_name = Path(xml_file).stem tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) root1=tree1.getroot() alltags=[elem.tag for elem in root1.iter()] @@ -495,7 +483,7 @@ class machine_based_reading_order_on_layout: img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4)) img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5)) - return tree1, root1, bb_coord_printspace, file_name, id_paragraph, id_header+id_heading, co_text_paragraph, co_text_header+co_text_heading,\ + return tree1, root1, bb_coord_printspace, id_paragraph, id_header+id_heading, co_text_paragraph, co_text_header+co_text_heading,\ tot_region_ref,x_len, y_len,index_tot_regions, img_poly def return_indexes_of_contours_loctaed_inside_another_list_of_contours(self, contours, contours_loc, cx_main_loc, cy_main_loc, indexes_loc): @@ -744,7 +732,12 @@ class machine_based_reading_order_on_layout: - def run(self, xml_filename : Optional[str] = None, dir_in : Optional[str] = None, overwrite : bool = False): + def run(self, + overwrite: bool = False, + xml_filename: Optional[str] = None, + dir_in: Optional[str] = None, + dir_out: Optional[str] = None, + ): """ Get image and scales, then extract the page of scanned image """ @@ -752,7 +745,9 @@ class machine_based_reading_order_on_layout: t0_tot = time.time() if dir_in: - ls_xmls = list(filter(is_xml_filename, os.listdir(dir_in))) + ls_xmls = [os.path.join(dir_in, xml_filename) + for xml_filename in filter(is_xml_filename, + os.listdir(dir_in))] elif xml_filename: ls_xmls = [xml_filename] else: @@ -761,13 +756,11 @@ class machine_based_reading_order_on_layout: for xml_filename in ls_xmls: self.logger.info(xml_filename) t0 = time.time() - - if dir_in: - xml_file = os.path.join(dir_in, xml_filename) - else: - xml_file = xml_filename - - tree_xml, root_xml, bb_coord_printspace, file_name, id_paragraph, id_header, co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len, index_tot_regions, img_poly = self.read_xml(xml_file) + + file_name = Path(xml_filename).stem + (tree_xml, root_xml, bb_coord_printspace, id_paragraph, id_header, + co_text_paragraph, co_text_header, tot_region_ref, + x_len, y_len, index_tot_regions, img_poly) = self.read_xml(xml_filename) id_all_text = id_paragraph + id_header @@ -810,7 +803,11 @@ class machine_based_reading_order_on_layout: alltags=[elem.tag for elem in root_xml.iter()] ET.register_namespace("",name_space) - tree_xml.write(os.path.join(self.dir_out, file_name+'.xml'),xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) + tree_xml.write(os.path.join(dir_out, file_name+'.xml'), + xml_declaration=True, + method='xml', + encoding="utf8", + default_namespace=None) #sys.exit()