simplify dir_in conditionals

pull/142/head
Robert Sachunsky 3 weeks ago
parent 7ae64f3717
commit 3b9a29bc5c

@ -274,7 +274,8 @@ class Eynollah:
self.models = {} self.models = {}
if dir_in and light_version: if dir_in:
# as in start_new_session:
config = tf.compat.v1.ConfigProto() config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config) session = tf.compat.v1.Session(config=config)
@ -283,62 +284,31 @@ class Eynollah:
self.model_page = self.our_load_model(self.model_page_dir) self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization) self.model_bin = self.our_load_model(self.model_dir_of_binarization)
self.model_textline = self.our_load_model(self.model_textline_dir) if self.extract_only_images:
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light) self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np) else:
###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new) self.model_textline = self.our_load_model(self.model_textline_dir)
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np) if self.light_version:
self.model_region_fl = self.our_load_model(self.model_region_dir_fully) self.model_region = self.our_load_model(self.model_region_dir_p_ens_light)
self.model_reading_order = self.our_load_model(self.model_reading_order_dir) self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np)
if self.ocr: else:
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) self.model_region = self.our_load_model(self.model_region_dir_p_ens)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten") self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
if self.tables: ###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new)
self.model_table = self.our_load_model(self.model_table_dir) self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
if self.reading_order_machine_based:
self.ls_imgs = os.listdir(self.dir_in) self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
if self.ocr:
if dir_in and self.extract_only_images: self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
config = tf.compat.v1.ConfigProto() self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config.gpu_options.allow_growth = True self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten")
session = tf.compat.v1.Session(config=config) if self.tables:
set_session(session) self.model_table = self.our_load_model(self.model_table_dir)
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
#self.model_textline = self.our_load_model(self.model_textline_dir)
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
#self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
#self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.ls_imgs = os.listdir(self.dir_in)
if dir_in and not (light_version or self.extract_only_images):
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
set_session(session)
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
self.model_textline = self.our_load_model(self.model_textline_dir)
self.model_region = self.our_load_model(self.model_region_dir_p_ens)
self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
if self.tables:
self.model_table = self.our_load_model(self.model_table_dir)
self.ls_imgs = os.listdir(self.dir_in) self.ls_imgs = os.listdir(self.dir_in)
def _cache_images(self, image_filename=None, image_pil=None): def _cache_images(self, image_filename=None, image_pil=None):
ret = {} ret = {}

Loading…
Cancel
Save