From 4e9a1618c355a7aeed471c9f63018440adf441cf Mon Sep 17 00:00:00 2001
From: Robert Sachunsky
Date: Fri, 10 Oct 2025 03:18:09 +0200
Subject: [PATCH] layout: refactor model setup, allow loading custom versions

- simplify definition of (defaults for) model versions
- unify loading of loadable models (depending on mode)
- use `self.models` dict instead of `self.model_*` attributes
- add `model_versions` kwarg / `--model_version` CLI option
---
 CHANGELOG.md             |   1 +
 src/eynollah/cli.py      |  10 +-
 src/eynollah/eynollah.py | 362 +++++++++++++++++++--------------------
 3 files changed, 191 insertions(+), 182 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6fd3b2e..df1e12e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,7 @@ f458e3e
   (so CUDA memory gets freed between tests if running on GPU)
 
 Added:
+ * :fire: `layout` CLI: new option `--model_version` to override default choices
 * test coverage for OCR options in `layout`
 * test coverage for table detection in `layout`
 * CI linting with ruff
diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py
index 93bb676..c9bad52 100644
--- a/src/eynollah/cli.py
+++ b/src/eynollah/cli.py
@@ -202,6 +202,13 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
     type=click.Path(exists=True, file_okay=False),
     required=True,
 )
+@click.option(
+    "--model_version",
+    "-mv",
+    help="override default model version for given category (repeatable; expects CATEGORY VERSION)",
+    type=(str, str),
+    multiple=True,
+)
 @click.option(
     "--save_images",
     "-si",
@@ -373,7 +380,7 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
     help="Setup a basic console logger",
 )
-def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
+def layout(image, out, overwrite, dir_in, model, model_version, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
     if setup_logging:
         console_handler = logging.StreamHandler(sys.stdout)
         console_handler.setLevel(logging.INFO)
@@ -404,6 +411,7 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
     assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
     eynollah = Eynollah(
         model,
+        model_versions=model_version,
         extract_only_images=extract_only_images,
         enable_plotting=enable_plotting,
         allow_enhancement=allow_enhancement,
diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py
index 3579078..0992c8c 100644
--- a/src/eynollah/eynollah.py
+++ b/src/eynollah/eynollah.py
@@ -19,7 +19,7 @@ import math
 import os
 import sys
 import time
-from typing import Optional
+from typing import Dict, List, Optional, Tuple
 import atexit
 import warnings
 from functools import partial
@@ -180,7 +180,6 @@ class Patches(layers.Layer):
         })
         return config
 
-
 class PatchEncoder(layers.Layer):
     def __init__(self, **kwargs):
         super(PatchEncoder, self).__init__()
@@ -208,6 +207,7 @@ class Eynollah:
     def __init__(
         self,
         dir_models : str,
+        model_versions: List[Tuple[str, str]] = [],
        extract_only_images : bool =False,
        enable_plotting : bool = False,
        allow_enhancement : bool = False,
@@ -254,6 +254,10 @@ class Eynollah:
         self.skip_layout_and_reading_order = skip_layout_and_reading_order
         self.ocr = do_ocr
         self.tr = transformer_ocr
+        if not batch_size_ocr:
+            self.b_s_ocr = 8
+        else:
+            self.b_s_ocr = int(batch_size_ocr)
         if num_col_upper:
             self.num_col_upper = int(num_col_upper)
         else:
@@ -275,69 +279,6 @@ class Eynollah:
             self.threshold_art_class_textline = float(threshold_art_class_textline)
         else:
             self.threshold_art_class_textline = 0.1
-
-        self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
-        self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425"
-        self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
-        self.model_region_dir_p = dir_models + "/eynollah-main-regions-aug-scaling_20210425"
-        self.model_region_dir_p2 = dir_models + "/eynollah-main-regions-aug-rotation_20210425"
-        #"/modelens_full_lay_1_3_031124"
-        #"/modelens_full_lay_13__3_19_241024"
-        #"/model_full_lay_13_241024"
-        #"/modelens_full_lay_13_17_231024"
-        #"/modelens_full_lay_1_2_221024"
-        #"/eynollah-full-regions-1column_20210425"
-        self.model_region_dir_fully_np = dir_models + "/modelens_full_lay_1__4_3_091124"
-        #self.model_region_dir_fully = dir_models + "/eynollah-full-regions-3+column_20210425"
-        self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915"
-        self.model_region_dir_p_ens = dir_models + "/eynollah-main-regions-ensembled_20210425"
-        self.model_region_dir_p_ens_light = dir_models + "/eynollah-main-regions_20220314"
-        self.model_region_dir_p_ens_light_only_images_extraction = (dir_models +
-            "/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18"
-            )
-        self.model_reading_order_dir = (dir_models +
-            "/model_eynollah_reading_order_20250824"
-            #"/model_mb_ro_aug_ens_11"
-            #"/model_step_3200000_mb_ro"
-            #"/model_ens_reading_order_machine_based"
-            #"/model_mb_ro_aug_ens_8"
-            #"/model_ens_reading_order_machine_based"
-            )
-        #"/modelens_12sp_elay_0_3_4__3_6_n"
-        #"/modelens_earlylayout_12spaltige_2_3_5_6_7_8"
-        #"/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
-        #"/modelens_1_2_4_5_early_lay_1_2_spaltige"
-        #"/model_3_eraly_layout_no_patches_1_2_spaltige"
-        self.model_region_dir_p_1_2_sp_np = dir_models + "/modelens_e_l_all_sp_0_1_2_3_4_171024"
-        ##self.model_region_dir_fully_new = dir_models + "/model_2_full_layout_new_trans"
-        #"/modelens_full_lay_1_3_031124"
-        #"/modelens_full_lay_13__3_19_241024"
-        #"/model_full_lay_13_241024"
-        #"/modelens_full_lay_13_17_231024"
-        #"/modelens_full_lay_1_2_221024"
-        #"/modelens_full_layout_24_till_28"
-        #"/model_2_full_layout_new_trans"
-        self.model_region_dir_fully = dir_models + "/modelens_full_lay_1__4_3_091124"
-        if self.textline_light:
-            #"/modelens_textline_1_4_16092024"
-            #"/model_textline_ens_3_4_5_6_artificial"
-            #"/modelens_textline_1_3_4_20240915"
-            #"/model_textline_ens_3_4_5_6_artificial"
-            #"/modelens_textline_9_12_13_14_15"
-            #"/eynollah-textline_light_20210425"
-            self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
-        else:
-            #"/eynollah-textline_20210425"
-            self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
-        if self.ocr and self.tr:
-            self.model_ocr_dir = dir_models + "/model_eynollah_ocr_trocr_20250919"
-        elif self.ocr and not self.tr:
-            self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250930"
-        if self.tables:
-            if self.light_version:
-                self.model_table_dir = dir_models + "/modelens_table_0t4_201124"
-            else:
-                self.model_table_dir = dir_models + "/eynollah-tables_20210319"
 
         t_start = time.time()
 
@@ -356,28 +297,124 @@ class Eynollah:
             self.logger.warning("no GPU device available")
 
         self.logger.info("Loading models...")
-
-        self.model_page = self.our_load_model(self.model_page_dir)
-        self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
-        self.model_bin = self.our_load_model(self.model_dir_of_binarization)
-        if self.extract_only_images:
-            self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
-        else:
-            self.model_textline = self.our_load_model(self.model_textline_dir)
+        self.setup_models(dir_models, model_versions)
+        self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
+
+    @staticmethod
+    def our_load_model(model_file, basedir=""):
+        if basedir:
+            model_file = os.path.join(basedir, model_file)
+        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
+            # prefer SavedModel over HDF5 format if it exists
+            model_file = model_file[:-3]
+        try:
+            model = load_model(model_file, compile=False)
+        except:
+            model = load_model(model_file, compile=False, custom_objects={
+                "PatchEncoder": PatchEncoder, "Patches": Patches})
+        return model
+
+    def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []):
+        self.model_versions = {
+            "enhancement": "eynollah-enhancement_20210425",
+            "binarization": "eynollah-binarization_20210425",
+            "col_classifier": "eynollah-column-classifier_20210425",
+            "page": "model_eynollah_page_extraction_20250915",
+            #?: "eynollah-main-regions-aug-scaling_20210425",
+            "region": ( # early layout
+                "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else
+                "eynollah-main-regions_20220314" if self.light_version else
+                "eynollah-main-regions-ensembled_20210425"),
+            "region_p2": ( # early layout, non-light, 2nd part
+                "eynollah-main-regions-aug-rotation_20210425"),
+            "region_1_2": ( # early layout, light, 1-or-2-column
+                #"modelens_12sp_elay_0_3_4__3_6_n"
+                #"modelens_earlylayout_12spaltige_2_3_5_6_7_8"
+                #"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
+                #"modelens_1_2_4_5_early_lay_1_2_spaltige"
+                #"model_3_eraly_layout_no_patches_1_2_spaltige"
+                "modelens_e_l_all_sp_0_1_2_3_4_171024"),
+            "region_fl_np": ( # full layout / no patches
+                #"modelens_full_lay_1_3_031124"
+                #"modelens_full_lay_13__3_19_241024"
+                #"model_full_lay_13_241024"
+                #"modelens_full_lay_13_17_231024"
+                #"modelens_full_lay_1_2_221024"
+                #"eynollah-full-regions-1column_20210425"
+                "modelens_full_lay_1__4_3_091124"),
+            "region_fl": ( # full layout / with patches
#"eynollah-full-regions-3+column_20210425" + ##"model_2_full_layout_new_trans" + #"modelens_full_lay_1_3_031124" + #"modelens_full_lay_13__3_19_241024" + #"model_full_lay_13_241024" + #"modelens_full_lay_13_17_231024" + #"modelens_full_lay_1_2_221024" + #"modelens_full_layout_24_till_28" + #"model_2_full_layout_new_trans" + "modelens_full_lay_1__4_3_091124"), + "reading_order": ( + #"model_mb_ro_aug_ens_11" + #"model_step_3200000_mb_ro" + #"model_ens_reading_order_machine_based" + #"model_mb_ro_aug_ens_8" + #"model_ens_reading_order_machine_based" + "model_eynollah_reading_order_20250824"), + "textline": ( + #"modelens_textline_1_4_16092024" + #"model_textline_ens_3_4_5_6_artificial" + #"modelens_textline_1_3_4_20240915" + #"model_textline_ens_3_4_5_6_artificial" + #"modelens_textline_9_12_13_14_15" + #"eynollah-textline_light_20210425" + "modelens_textline_0_1__2_4_16092024" if self.textline_light else + #"eynollah-textline_20210425" + "modelens_textline_0_1__2_4_16092024"), + "table": ( + None if not self.tables else + "modelens_table_0t4_201124" if self.light_version else + "eynollah-tables_20210319"), + "ocr": ( + None if not self.ocr else + "model_eynollah_ocr_trocr_20250919" if self.tr else + "model_eynollah_ocr_cnnrnn_20250930") + } + # override defaults from CLI + for key, val in model_versions: + assert key in self.model_versions, "unknown model category '%s'" % key + self.logger.warning("overriding default model %s version %s to %s", key, self.model_versions[key], val) + self.model_versions[key] = val + # load models, depending on modes + loadable = [ + "col_classifier", + "binarization", + "page", + "region" + ] + if not self.extract_only_images: + loadable.append("textline") if self.light_version: - self.model_region = self.our_load_model(self.model_region_dir_p_ens_light) - self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np) + loadable.append("region_1_2") else: - self.model_region = self.our_load_model(self.model_region_dir_p_ens) - self.model_region_p2 = self.our_load_model(self.model_region_dir_p2) - self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) - ###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new) - self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np) - self.model_region_fl = self.our_load_model(self.model_region_dir_fully) + loadable.append("region_p2") + # if self.allow_enhancement:? 
+ loadable.append("enhancement") + if self.full_layout: + loadable.extend(["region_fl_np", + "region_fl"]) if self.reading_order_machine_based: - self.model_reading_order = self.our_load_model(self.model_reading_order_dir) - if self.ocr and self.tr: - self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) + loadable.append("reading_order") + if self.tables: + loadable.append("table") + + self.models = {name: self.our_load_model(self.model_versions[name], basedir) + for name in loadable + } + + if self.ocr: + ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"]) + if self.tr: + self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) if torch.cuda.is_available(): self.logger.info("Using GPU acceleration") self.device = torch.device("cuda:0") @@ -386,54 +423,29 @@ class Eynollah: self.device = torch.device("cpu") #self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") - elif self.ocr and not self.tr: - model_ocr = load_model(self.model_ocr_dir , compile=False) - - self.prediction_model = tf.keras.models.Model( - model_ocr.get_layer(name = "image").input, - model_ocr.get_layer(name = "dense2").output) - if not batch_size_ocr: - self.b_s_ocr = 8 - else: - self.b_s_ocr = int(batch_size_ocr) + else: + ocr_model = load_model(ocr_model_dir, compile=False) + self.models["ocr"] = tf.keras.models.Model( + ocr_model.get_layer(name = "image").input, + ocr_model.get_layer(name = "dense2").output) - with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file: + with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file: characters = json.load(config_file) - - AUTOTUNE = tf.data.AUTOTUNE - # Mapping characters to integers. char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) - # Mapping integers back to original characters. 
                 self.num_to_char = StringLookup(
                     vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
                 )
-
-        if self.tables:
-            self.model_table = self.our_load_model(self.model_table_dir)
-
-        self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
 
     def __del__(self):
         if hasattr(self, 'executor') and getattr(self, 'executor'):
             self.executor.shutdown()
-        for model_name in ['model_page',
-                           'model_classifier',
-                           'model_bin',
-                           'model_enhancement',
-                           'model_region',
-                           'model_region_1_2',
-                           'model_region_p2',
-                           'model_region_fl_np',
-                           'model_region_fl',
-                           'model_textline',
-                           'model_reading_order',
-                           'model_table',
-                           'model_ocr',
-                           'processor']:
-            if hasattr(self, model_name) and getattr(self, model_name):
-                delattr(self, model_name)
+        self.executor = None
+        if hasattr(self, 'models') and getattr(self, 'models'):
+            for model_name in list(self.models):
+                if self.models[model_name]:
+                    del self.models[model_name]
 
     def cache_images(self, image_filename=None, image_pil=None, dpi=None):
         ret = {}
@@ -480,8 +492,8 @@ class Eynollah:
 
     def predict_enhancement(self, img):
         self.logger.debug("enter predict_enhancement")
-        img_height_model = self.model_enhancement.layers[-1].output_shape[1]
-        img_width_model = self.model_enhancement.layers[-1].output_shape[2]
+        img_height_model = self.models["enhancement"].layers[-1].output_shape[1]
+        img_width_model = self.models["enhancement"].layers[-1].output_shape[2]
         if img.shape[0] < img_height_model:
             img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
         if img.shape[1] < img_width_model:
@@ -522,7 +534,7 @@ class Eynollah:
                     index_y_d = img_h - img_height_model
                 img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
-                label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
+                label_p_pred = self.models["enhancement"].predict(img_patch, verbose=0)
                 seg = label_p_pred[0, :, :, :] * 255
 
                 if i == 0 and j == 0:
@@ -697,7 +709,7 @@ class Eynollah:
             img_in[0, :, :, 1] = img_1ch[:, :]
             img_in[0, :, :, 2] = img_1ch[:, :]
-            label_p_pred = self.model_classifier.predict(img_in, verbose=0)
+            label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
             num_col = np.argmax(label_p_pred[0]) + 1
 
         self.logger.info("Found %s columns (%s)", num_col, label_p_pred)
@@ -715,7 +727,7 @@ class Eynollah:
             self.logger.info("Detected %s DPI", dpi)
         if self.input_binary:
             img = self.imread()
-            prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
+            prediction_bin = self.do_prediction(True, img, self.models["binarization"], n_batch_inference=5)
            prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
            prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
            img= np.copy(prediction_bin)
@@ -755,7 +767,7 @@ class Eynollah:
                 img_in[0, :, :, 1] = img_1ch[:, :]
                 img_in[0, :, :, 2] = img_1ch[:, :]
-                label_p_pred = self.model_classifier.predict(img_in, verbose=0)
+                label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
                 num_col = np.argmax(label_p_pred[0]) + 1
             elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
@@ -776,7 +788,7 @@ class Eynollah:
                 img_in[0, :, :, 1] = img_1ch[:, :]
                 img_in[0, :, :, 2] = img_1ch[:, :]
-                label_p_pred = self.model_classifier.predict(img_in, verbose=0)
+                label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
                 num_col = np.argmax(label_p_pred[0]) + 1
                 if num_col > self.num_col_upper:
@@ -1628,7 +1640,7 @@ class Eynollah:
         cont_page = []
         if not self.ignore_page_extraction:
             img = np.copy(self.image)#cv2.GaussianBlur(self.image, (5, 5), 0)
-            img_page_prediction = self.do_prediction(False, img, self.model_page)
+            img_page_prediction = self.do_prediction(False, img, self.models["page"])
             imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
             _, thresh = cv2.threshold(imgray, 0, 255, 0)
             ##thresh = cv2.dilate(thresh, KERNEL, iterations=3)
@@ -1676,7 +1688,7 @@ class Eynollah:
         else:
             img = self.imread()
             img = cv2.GaussianBlur(img, (5, 5), 0)
-            img_page_prediction = self.do_prediction(False, img, self.model_page)
+            img_page_prediction = self.do_prediction(False, img, self.models["page"])
             imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
             _, thresh = cv2.threshold(imgray, 0, 255, 0)
@@ -1702,7 +1714,7 @@ class Eynollah:
         self.logger.debug("enter extract_text_regions")
         img_height_h = img.shape[0]
         img_width_h = img.shape[1]
-        model_region = self.model_region_fl if patches else self.model_region_fl_np
+        model_region = self.models["region_fl"] if patches else self.models["region_fl_np"]
 
         if self.light_version:
             thresholding_for_fl_light_version = True
@@ -1737,7 +1749,7 @@ class Eynollah:
         self.logger.debug("enter extract_text_regions")
         img_height_h = img.shape[0]
         img_width_h = img.shape[1]
-        model_region = self.model_region_fl if patches else self.model_region_fl_np
+        model_region = self.models["region_fl"] if patches else self.models["region_fl_np"]
 
         if not patches:
             img = otsu_copy_binary(img)
@@ -1958,14 +1970,14 @@ class Eynollah:
             img_w = img_org.shape[1]
         img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w))
-        prediction_textline = self.do_prediction(use_patches, img, self.model_textline,
+        prediction_textline = self.do_prediction(use_patches, img, self.models["textline"],
                                                  marginal_of_patch_percent=0.15,
                                                  n_batch_inference=3,
                                                  thresholding_for_artificial_class_in_light_version=self.textline_light,
                                                  threshold_art_class_textline=self.threshold_art_class_textline)
         #if not self.textline_light:
         #if num_col_classifier==1:
-        #prediction_textline_nopatch = self.do_prediction(False, img, self.model_textline)
+        #prediction_textline_nopatch = self.do_prediction(False, img, self.models["textline"])
         #prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0
         prediction_textline = resize_image(prediction_textline, img_h, img_w)
@@ -2036,7 +2048,7 @@ class Eynollah:
         #cv2.imwrite('prediction_textline2.png', prediction_textline[:,:,0])
-        prediction_textline_longshot = self.do_prediction(False, img, self.model_textline)
+        prediction_textline_longshot = self.do_prediction(False, img, self.models["textline"])
         prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w)
@@ -2069,7 +2081,7 @@ class Eynollah:
             img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new)
             img_resized = resize_image(img,img_h_new, img_w_new )
-        prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.model_region)
+        prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.models["region"])
         prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h )
         image_page, page_coord, cont_page = self.extract_page()
@@ -2185,7 +2197,7 @@ class Eynollah:
         #if self.input_binary:
         #img_bin = np.copy(img_resized)
         ###if (not self.input_binary and self.full_layout) or (not self.input_binary and num_col_classifier >= 30):
-        ###prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
+        ###prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5)
         ####print("inside bin ", time.time()-t_bin)
         ###prediction_bin=prediction_bin[:,:,0]
@@ -2200,7 +2212,7 @@ class Eynollah:
         ###else:
         ###img_bin = np.copy(img_resized)
         if (self.ocr and self.tr) and not self.input_binary:
-            prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
+            prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5)
             prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
             prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
             prediction_bin = prediction_bin.astype(np.uint16)
@@ -2232,14 +2244,14 @@ class Eynollah:
                 self.logger.debug("resized to %dx%d for %d cols",
                                   img_resized.shape[1], img_resized.shape[0], num_col_classifier)
                 prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
-                    True, img_resized, self.model_region_1_2, n_batch_inference=1,
+                    True, img_resized, self.models["region_1_2"], n_batch_inference=1,
                     thresholding_for_some_classes_in_light_version=True,
                     threshold_art_class_layout=self.threshold_art_class_layout)
             else:
                 prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3))
                 confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
                 prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept(
-                    False, self.image_page_org_size, self.model_region_1_2, n_batch_inference=1,
+                    False, self.image_page_org_size, self.models["region_1_2"], n_batch_inference=1,
                     thresholding_for_artificial_class_in_light_version=True,
                     threshold_art_class_layout=self.threshold_art_class_layout)
                 ys = slice(*self.page_coord[0:2])
@@ -2253,10 +2265,10 @@ class Eynollah:
                 self.logger.debug("resized to %dx%d (new_h=%d) for %d cols",
                                   img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
                 prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
-                    True, img_resized, self.model_region_1_2, n_batch_inference=2,
+                    True, img_resized, self.models["region_1_2"], n_batch_inference=2,
                     thresholding_for_some_classes_in_light_version=True,
                     threshold_art_class_layout=self.threshold_art_class_layout)
-            ###prediction_regions_org = self.do_prediction(True, img_bin, self.model_region,
+            ###prediction_regions_org = self.do_prediction(True, img_bin, self.models["region"],
             ###n_batch_inference=3,
             ###thresholding_for_some_classes_in_light_version=True)
         #print("inside 3 ", time.time()-t_in)
@@ -2336,7 +2348,7 @@ class Eynollah:
             ratio_x=1
         img = resize_image(img_org, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
-        prediction_regions_org_y = self.do_prediction(True, img, self.model_region)
+        prediction_regions_org_y = self.do_prediction(True, img, self.models["region"])
         prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h )
         #plt.imshow(prediction_regions_org_y[:,:,0])
@@ -2351,7 +2363,7 @@ class Eynollah:
             _, _ = find_num_col(img_only_regions, num_col_classifier, self.tables, multiplier=6.0)
             img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]*(1.2 if is_image_enhanced else 1)))
-            prediction_regions_org = self.do_prediction(True, img, self.model_region)
+            prediction_regions_org = self.do_prediction(True, img, self.models["region"])
             prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
             prediction_regions_org=prediction_regions_org[:,:,0]
@@ -2359,7 +2371,7 @@ class Eynollah:
             img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]))
-            prediction_regions_org2 = self.do_prediction(True, img, self.model_region_p2, marginal_of_patch_percent=0.2)
+            prediction_regions_org2 = self.do_prediction(True, img, self.models["region_p2"], marginal_of_patch_percent=0.2)
             prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h )
             mask_zeros2 = (prediction_regions_org2[:,:,0] == 0)
@@ -2383,7 +2395,7 @@ class Eynollah:
             if self.input_binary:
                 prediction_bin = np.copy(img_org)
             else:
-                prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
+                prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5)
                 prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
                 prediction_bin = 255 * (prediction_bin[:,:,0]==0)
                 prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
@@ -2393,7 +2405,7 @@ class Eynollah:
             img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
-            prediction_regions_org = self.do_prediction(True, img, self.model_region)
+            prediction_regions_org = self.do_prediction(True, img, self.models["region"])
             prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
             prediction_regions_org=prediction_regions_org[:,:,0]
@@ -2420,7 +2432,7 @@ class Eynollah:
         except:
             if self.input_binary:
                 prediction_bin = np.copy(img_org)
-                prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
+                prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5)
                 prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
                 prediction_bin = 255 * (prediction_bin[:,:,0]==0)
                 prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
@@ -2431,14 +2443,14 @@ class Eynollah:
             img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
-            prediction_regions_org = self.do_prediction(True, img, self.model_region)
+            prediction_regions_org = self.do_prediction(True, img, self.models["region"])
             prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
             prediction_regions_org=prediction_regions_org[:,:,0]
             #mask_lines_only=(prediction_regions_org[:,:]==3)*1
             #img = resize_image(img_org, int(img_org.shape[0]*1), int(img_org.shape[1]*1))
-            #prediction_regions_org = self.do_prediction(True, img, self.model_region)
+            #prediction_regions_org = self.do_prediction(True, img, self.models["region"])
             #prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
             #prediction_regions_org = prediction_regions_org[:,:,0]
             #prediction_regions_org[(prediction_regions_org[:,:] == 1) & (mask_zeros_y[:,:] == 1)]=0
@@ -2809,13 +2821,13 @@ class Eynollah:
         img_width_h = img_org.shape[1]
         patches = False
         if self.light_version:
-            prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_table)
+            prediction_table, _ = self.do_prediction_new_concept(patches, img, self.models["table"])
             prediction_table = prediction_table.astype(np.int16)
             return prediction_table[:,:,0]
         else:
             if num_col_classifier < 4 and num_col_classifier > 2:
-                prediction_table = self.do_prediction(patches, img, self.model_table)
-                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
+                prediction_table = self.do_prediction(patches, img, self.models["table"])
+                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"])
                 pre_updown = cv2.flip(pre_updown, -1)
                 prediction_table[:,:,0][pre_updown[:,:,0]==1]=1
@@ -2834,8 +2846,8 @@ class Eynollah:
                 xs = slice(w_start, w_start + img.shape[1])
                 img_new[ys, xs] = img
-                prediction_ext = self.do_prediction(patches, img_new, self.model_table)
-                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
+                prediction_ext = self.do_prediction(patches, img_new, self.models["table"])
+                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"])
                 pre_updown = cv2.flip(pre_updown, -1)
                 prediction_table = prediction_ext[ys, xs]
@@ -2856,8 +2868,8 @@ class Eynollah:
                 xs = slice(w_start, w_start + img.shape[1])
                 img_new[ys, xs] = img
-                prediction_ext = self.do_prediction(patches, img_new, self.model_table)
-                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
+                prediction_ext = self.do_prediction(patches, img_new, self.models["table"])
+                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"])
                 pre_updown = cv2.flip(pre_updown, -1)
                 prediction_table = prediction_ext[ys, xs]
@@ -2869,10 +2881,10 @@ class Eynollah:
             prediction_table = np.zeros(img.shape)
             img_w_half = img.shape[1] // 2
-            pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.model_table)
-            pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.model_table)
-            pre_full = self.do_prediction(patches, img[:,:,:], self.model_table)
-            pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
+            pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.models["table"])
+            pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.models["table"])
+            pre_full = self.do_prediction(patches, img[:,:,:], self.models["table"])
+            pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"])
             pre_updown = cv2.flip(pre_updown, -1)
             prediction_table_full_erode = cv2.erode(pre_full[:,:,0], KERNEL, iterations=4)
@@ -3474,18 +3486,6 @@ class Eynollah:
             regions_without_separators_d, regions_fully, regions_without_separators,
             polygons_of_marginals, contours_tables)
 
-    @staticmethod
-    def our_load_model(model_file):
-        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
-            # prefer SavedModel over HDF5 format if it exists
-            model_file = model_file[:-3]
-        try:
-            model = load_model(model_file, compile=False)
-        except:
-            model = load_model(model_file, compile=False, custom_objects={
-                "PatchEncoder": PatchEncoder, "Patches": Patches})
-        return model
-
     def do_order_of_regions_with_model(self, contours_only_text_parent, contours_only_text_parent_h, text_regions_p):
         height1 =672#448
@@ -3676,7 +3676,7 @@ class Eynollah:
                     tot_counter += 1
                     batch.append(j)
                     if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
-                        y_pr = self.model_reading_order.predict(input_1 , verbose=0)
+                        y_pr = self.models["reading_order"].predict(input_1 , verbose=0)
                         for jb, j in enumerate(batch):
                             if y_pr[jb][0]>=0.5:
                                 post_list.append(j)
@@ -4259,7 +4259,7 @@ class Eynollah:
                 gc.collect()
                 ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                     image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)),
-                    self.prediction_model, self.b_s_ocr, self.num_to_char, textline_light=True)
+                    self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True)
             else:
                 ocr_all_textlines = None
@@ -4768,27 +4768,27 @@ class Eynollah:
                 if len(all_found_textline_polygons):
                     ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons, all_box_coord,
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
                 if len(all_found_textline_polygons_marginals_left):
                     ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left,
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
                 if len(all_found_textline_polygons_marginals_right):
                     ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right,
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
                 if self.full_layout and len(all_found_textline_polygons):
                     ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons_h, all_box_coord_h,
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
                 if self.full_layout and len(polygons_of_drop_capitals):
                     ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)),
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
             else:
                 if self.light_version:
@@ -4800,7 +4800,7 @@ class Eynollah:
                     gc.collect()
                     torch.cuda.empty_cache()
-                    self.model_ocr.to(self.device)
+                    self.models["ocr"].to(self.device)
                     ind_tot = 0
                     #cv2.imwrite('./img_out.png', image_page)
@@ -4837,7 +4837,7 @@ class Eynollah:
                             img_croped = img_poly_on_img[y:y+h, x:x+w, :]
                             #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
                             text_ocr = self.return_ocr_of_textline_without_common_section(
-                                img_croped, self.model_ocr, self.processor, self.device, w, h2w_ratio, ind_tot)
+                                img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot)
                             ocr_textline_in_textregion.append(text_ocr)
                             ind_tot = ind_tot +1
                     ocr_all_textlines.append(ocr_textline_in_textregion)
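
Usage (a sketch for reviewers: input/output paths are placeholders; category names and version strings are exactly those defined in `setup_models` above, and `-m`/`-i`/`-o` are the pre-existing `layout` options):

    # override one or more model categories on the CLI; -mv is repeatable and
    # takes CATEGORY VERSION pairs, per type=(str, str) and multiple=True
    eynollah layout -m models_eynollah -i page.tif -o out \
        -mv reading_order model_eynollah_reading_order_20250824 \
        -mv textline modelens_textline_0_1__2_4_16092024

The same override through the Python API: unknown categories are rejected by the assertion in `setup_models`, an override is logged as a warning, and each version string must name a model directory under the given `dir_models`:

    from eynollah.eynollah import Eynollah

    eynollah = Eynollah(
        "models_eynollah",  # dir_models: base directory holding the model version directories
        model_versions=[
            # (category, version) pairs, as produced by the repeatable CLI option
            ("reading_order", "model_eynollah_reading_order_20250824"),
            ("textline", "modelens_textline_0_1__2_4_16092024"),
        ],
        reading_order_machine_based=True,  # kwarg assumed to mirror the CLI flag of the same name
    )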