simplify dir_in conditionals

pull/142/merge^2
Robert Sachunsky 1 month ago
parent 7ae64f3717
commit 3b9a29bc5c

@ -274,7 +274,8 @@ class Eynollah:
self.models = {} self.models = {}
if dir_in and light_version: if dir_in:
# as in start_new_session:
config = tf.compat.v1.ConfigProto() config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config) session = tf.compat.v1.Session(config=config)
@ -283,12 +284,21 @@ class Eynollah:
self.model_page = self.our_load_model(self.model_page_dir) self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization) self.model_bin = self.our_load_model(self.model_dir_of_binarization)
if self.extract_only_images:
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
else:
self.model_textline = self.our_load_model(self.model_textline_dir) self.model_textline = self.our_load_model(self.model_textline_dir)
if self.light_version:
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light) self.model_region = self.our_load_model(self.model_region_dir_p_ens_light)
self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np) self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np)
else:
self.model_region = self.our_load_model(self.model_region_dir_p_ens)
self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new) ###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new)
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np) self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully) self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
if self.reading_order_machine_based:
self.model_reading_order = self.our_load_model(self.model_reading_order_dir) self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
if self.ocr: if self.ocr:
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
@ -297,49 +307,9 @@ class Eynollah:
if self.tables: if self.tables:
self.model_table = self.our_load_model(self.model_table_dir) self.model_table = self.our_load_model(self.model_table_dir)
self.ls_imgs = os.listdir(self.dir_in)
if dir_in and self.extract_only_images:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
set_session(session)
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
#self.model_textline = self.our_load_model(self.model_textline_dir)
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
#self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
#self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.ls_imgs = os.listdir(self.dir_in)
if dir_in and not (light_version or self.extract_only_images):
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
set_session(session)
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
self.model_textline = self.our_load_model(self.model_textline_dir)
self.model_region = self.our_load_model(self.model_region_dir_p_ens)
self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
if self.tables:
self.model_table = self.our_load_model(self.model_table_dir)
self.ls_imgs = os.listdir(self.dir_in) self.ls_imgs = os.listdir(self.dir_in)
def _cache_images(self, image_filename=None, image_pil=None): def _cache_images(self, image_filename=None, image_pil=None):
ret = {} ret = {}
t_c0 = time.time() t_c0 = time.time()

Loading…
Cancel
Save