From 3b9a29bc5c187fe6ae4c41450a0095c3271ec703 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 4 Dec 2024 18:19:54 +0000 Subject: [PATCH] simplify dir_in conditionals --- src/eynollah/eynollah.py | 78 +++++++++++++--------------------------- 1 file changed, 24 insertions(+), 54 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 2dd5505..c1e0f4d 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -274,7 +274,8 @@ class Eynollah: self.models = {} - if dir_in and light_version: + if dir_in: + # as in start_new_session: config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True session = tf.compat.v1.Session(config=config) @@ -283,62 +284,31 @@ class Eynollah: self.model_page = self.our_load_model(self.model_page_dir) self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) self.model_bin = self.our_load_model(self.model_dir_of_binarization) - self.model_textline = self.our_load_model(self.model_textline_dir) - self.model_region = self.our_load_model(self.model_region_dir_p_ens_light) - self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np) - ###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new) - self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np) - self.model_region_fl = self.our_load_model(self.model_region_dir_fully) - self.model_reading_order = self.our_load_model(self.model_reading_order_dir) - if self.ocr: - self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten") - if self.tables: - self.model_table = self.our_load_model(self.model_table_dir) - - - self.ls_imgs = os.listdir(self.dir_in) - - if dir_in and self.extract_only_images: - config = tf.compat.v1.ConfigProto() - config.gpu_options.allow_growth = True - session = tf.compat.v1.Session(config=config) - set_session(session) - - self.model_page = self.our_load_model(self.model_page_dir) - self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) - self.model_bin = self.our_load_model(self.model_dir_of_binarization) - #self.model_textline = self.our_load_model(self.model_textline_dir) - self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction) - #self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np) - #self.model_region_fl = self.our_load_model(self.model_region_dir_fully) - - self.ls_imgs = os.listdir(self.dir_in) - - if dir_in and not (light_version or self.extract_only_images): - config = tf.compat.v1.ConfigProto() - config.gpu_options.allow_growth = True - session = tf.compat.v1.Session(config=config) - set_session(session) + if self.extract_only_images: + self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction) + else: + self.model_textline = self.our_load_model(self.model_textline_dir) + if self.light_version: + self.model_region = self.our_load_model(self.model_region_dir_p_ens_light) + self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np) + else: + self.model_region = self.our_load_model(self.model_region_dir_p_ens) + self.model_region_p2 = self.our_load_model(self.model_region_dir_p2) + self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) + ###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new) + self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np) + self.model_region_fl = self.our_load_model(self.model_region_dir_fully) + if self.reading_order_machine_based: + self.model_reading_order = self.our_load_model(self.model_reading_order_dir) + if self.ocr: + self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten") + if self.tables: + self.model_table = self.our_load_model(self.model_table_dir) - self.model_page = self.our_load_model(self.model_page_dir) - self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) - self.model_bin = self.our_load_model(self.model_dir_of_binarization) - self.model_textline = self.our_load_model(self.model_textline_dir) - self.model_region = self.our_load_model(self.model_region_dir_p_ens) - self.model_region_p2 = self.our_load_model(self.model_region_dir_p2) - self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np) - self.model_region_fl = self.our_load_model(self.model_region_dir_fully) - self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) - self.model_reading_order = self.our_load_model(self.model_reading_order_dir) - if self.tables: - self.model_table = self.our_load_model(self.model_table_dir) - self.ls_imgs = os.listdir(self.dir_in) - - def _cache_images(self, image_filename=None, image_pil=None): ret = {}