layout: refactor model setup, allow loading custom versions

- simplify how default model versions are defined
- unify loading of models (which models are loaded depends on the mode)
- use a single `self.models` dict instead of individual `self.model_*` attributes
- add a `model_versions` kwarg / `--model_version` CLI option to override the defaults (usage sketched below)
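
For illustration, a minimal sketch of the new override mechanism. The category name and version string are taken from the `setup_models()` defaults in the diff below; the import path, the extra keyword argument, and the file path are assumptions for the example, not part of this commit:

```python
# Sketch only: override the default "reading_order" model version with one of
# the alternatives listed in the setup_models() comments. Assumes Eynollah is
# importable as below and accepts a reading_order_machine_based keyword.
from eynollah.eynollah import Eynollah

eynollah = Eynollah(
    "/path/to/models",  # dir_models (illustrative path)
    model_versions=[("reading_order", "model_mb_ro_aug_ens_11")],
    reading_order_machine_based=True,  # ensures the "reading_order" model is actually loaded
)

# Equivalent CLI override (the option may be repeated, one category/version pair per use):
#   eynollah layout ... -mv reading_order model_mb_ro_aug_ens_11
```
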
Robert Sachunsky 2025-10-10 03:18:09 +02:00
parent 374818de11
commit 4e9a1618c3
3 changed files with 191 additions and 182 deletions

@ -25,6 +25,7 @@ f458e3e
(so CUDA memory gets freed between tests if running on GPU)
Added:
* :fire: `layout` CLI: new option `--model_version` to override default choices
* test coverage for OCR options in `layout`
* test coverage for table detection in `layout`
* CI linting with ruff

@ -202,6 +202,13 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--model_version",
"-mv",
help="override default versions of model categories",
type=(str, str),
multiple=True,
)
@click.option(
"--save_images",
"-si",
@ -373,7 +380,7 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
help="Setup a basic console logger",
)
def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
def layout(image, out, overwrite, dir_in, model, model_version, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
if setup_logging:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
@ -404,6 +411,7 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
eynollah = Eynollah(
model,
model_versions=model_version,
extract_only_images=extract_only_images,
enable_plotting=enable_plotting,
allow_enhancement=allow_enhancement,

@ -19,7 +19,7 @@ import math
import os
import sys
import time
from typing import Optional
from typing import Dict, List, Optional, Tuple
import atexit
import warnings
from functools import partial
@ -180,7 +180,6 @@ class Patches(layers.Layer):
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, **kwargs):
super(PatchEncoder, self).__init__()
@ -208,6 +207,7 @@ class Eynollah:
def __init__(
self,
dir_models : str,
model_versions: List[Tuple[str, str]] = [],
extract_only_images : bool =False,
enable_plotting : bool = False,
allow_enhancement : bool = False,
@ -254,6 +254,10 @@ class Eynollah:
self.skip_layout_and_reading_order = skip_layout_and_reading_order
self.ocr = do_ocr
self.tr = transformer_ocr
if not batch_size_ocr:
self.b_s_ocr = 8
else:
self.b_s_ocr = int(batch_size_ocr)
if num_col_upper:
self.num_col_upper = int(num_col_upper)
else:
@ -276,69 +280,6 @@ class Eynollah:
else:
self.threshold_art_class_textline = 0.1
self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425"
self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
self.model_region_dir_p = dir_models + "/eynollah-main-regions-aug-scaling_20210425"
self.model_region_dir_p2 = dir_models + "/eynollah-main-regions-aug-rotation_20210425"
#"/modelens_full_lay_1_3_031124"
#"/modelens_full_lay_13__3_19_241024"
#"/model_full_lay_13_241024"
#"/modelens_full_lay_13_17_231024"
#"/modelens_full_lay_1_2_221024"
#"/eynollah-full-regions-1column_20210425"
self.model_region_dir_fully_np = dir_models + "/modelens_full_lay_1__4_3_091124"
#self.model_region_dir_fully = dir_models + "/eynollah-full-regions-3+column_20210425"
self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915"
self.model_region_dir_p_ens = dir_models + "/eynollah-main-regions-ensembled_20210425"
self.model_region_dir_p_ens_light = dir_models + "/eynollah-main-regions_20220314"
self.model_region_dir_p_ens_light_only_images_extraction = (dir_models +
"/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18"
)
self.model_reading_order_dir = (dir_models +
"/model_eynollah_reading_order_20250824"
#"/model_mb_ro_aug_ens_11"
#"/model_step_3200000_mb_ro"
#"/model_ens_reading_order_machine_based"
#"/model_mb_ro_aug_ens_8"
#"/model_ens_reading_order_machine_based"
)
#"/modelens_12sp_elay_0_3_4__3_6_n"
#"/modelens_earlylayout_12spaltige_2_3_5_6_7_8"
#"/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
#"/modelens_1_2_4_5_early_lay_1_2_spaltige"
#"/model_3_eraly_layout_no_patches_1_2_spaltige"
self.model_region_dir_p_1_2_sp_np = dir_models + "/modelens_e_l_all_sp_0_1_2_3_4_171024"
##self.model_region_dir_fully_new = dir_models + "/model_2_full_layout_new_trans"
#"/modelens_full_lay_1_3_031124"
#"/modelens_full_lay_13__3_19_241024"
#"/model_full_lay_13_241024"
#"/modelens_full_lay_13_17_231024"
#"/modelens_full_lay_1_2_221024"
#"/modelens_full_layout_24_till_28"
#"/model_2_full_layout_new_trans"
self.model_region_dir_fully = dir_models + "/modelens_full_lay_1__4_3_091124"
if self.textline_light:
#"/modelens_textline_1_4_16092024"
#"/model_textline_ens_3_4_5_6_artificial"
#"/modelens_textline_1_3_4_20240915"
#"/model_textline_ens_3_4_5_6_artificial"
#"/modelens_textline_9_12_13_14_15"
#"/eynollah-textline_light_20210425"
self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
else:
#"/eynollah-textline_20210425"
self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
if self.ocr and self.tr:
self.model_ocr_dir = dir_models + "/model_eynollah_ocr_trocr_20250919"
elif self.ocr and not self.tr:
self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250930"
if self.tables:
if self.light_version:
self.model_table_dir = dir_models + "/modelens_table_0t4_201124"
else:
self.model_table_dir = dir_models + "/eynollah-tables_20210319"
t_start = time.time()
# #gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
@ -356,28 +297,124 @@ class Eynollah:
self.logger.warning("no GPU device available")
self.logger.info("Loading models...")
self.setup_models(dir_models, model_versions)
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
if self.extract_only_images:
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
else:
self.model_textline = self.our_load_model(self.model_textline_dir)
@staticmethod
def our_load_model(model_file, basedir=""):
if basedir:
model_file = os.path.join(basedir, model_file)
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except:
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []):
self.model_versions = {
"enhancement": "eynollah-enhancement_20210425",
"binarization": "eynollah-binarization_20210425",
"col_classifier": "eynollah-column-classifier_20210425",
"page": "model_eynollah_page_extraction_20250915",
#?: "eynollah-main-regions-aug-scaling_20210425",
"region": ( # early layout
"eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else
"eynollah-main-regions_20220314" if self.light_version else
"eynollah-main-regions-ensembled_20210425"),
"region_p2": ( # early layout, non-light, 2nd part
"eynollah-main-regions-aug-rotation_20210425"),
"region_1_2": ( # early layout, light, 1-or-2-column
#"modelens_12sp_elay_0_3_4__3_6_n"
#"modelens_earlylayout_12spaltige_2_3_5_6_7_8"
#"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
#"modelens_1_2_4_5_early_lay_1_2_spaltige"
#"model_3_eraly_layout_no_patches_1_2_spaltige"
"modelens_e_l_all_sp_0_1_2_3_4_171024"),
"region_fl_np": ( # full layout / no patches
#"modelens_full_lay_1_3_031124"
#"modelens_full_lay_13__3_19_241024"
#"model_full_lay_13_241024"
#"modelens_full_lay_13_17_231024"
#"modelens_full_lay_1_2_221024"
#"eynollah-full-regions-1column_20210425"
"modelens_full_lay_1__4_3_091124"),
"region_fl": ( # full layout / with patches
#"eynollah-full-regions-3+column_20210425"
##"model_2_full_layout_new_trans"
#"modelens_full_lay_1_3_031124"
#"modelens_full_lay_13__3_19_241024"
#"model_full_lay_13_241024"
#"modelens_full_lay_13_17_231024"
#"modelens_full_lay_1_2_221024"
#"modelens_full_layout_24_till_28"
#"model_2_full_layout_new_trans"
"modelens_full_lay_1__4_3_091124"),
"reading_order": (
#"model_mb_ro_aug_ens_11"
#"model_step_3200000_mb_ro"
#"model_ens_reading_order_machine_based"
#"model_mb_ro_aug_ens_8"
#"model_ens_reading_order_machine_based"
"model_eynollah_reading_order_20250824"),
"textline": (
#"modelens_textline_1_4_16092024"
#"model_textline_ens_3_4_5_6_artificial"
#"modelens_textline_1_3_4_20240915"
#"model_textline_ens_3_4_5_6_artificial"
#"modelens_textline_9_12_13_14_15"
#"eynollah-textline_light_20210425"
"modelens_textline_0_1__2_4_16092024" if self.textline_light else
#"eynollah-textline_20210425"
"modelens_textline_0_1__2_4_16092024"),
"table": (
None if not self.tables else
"modelens_table_0t4_201124" if self.light_version else
"eynollah-tables_20210319"),
"ocr": (
None if not self.ocr else
"model_eynollah_ocr_trocr_20250919" if self.tr else
"model_eynollah_ocr_cnnrnn_20250930")
}
# override defaults from CLI
for key, val in model_versions:
assert key in self.model_versions, "unknown model category '%s'" % key
self.logger.warning("overriding default model %s version %s to %s", key, self.model_versions[key], val)
self.model_versions[key] = val
# load models, depending on modes
loadable = [
"col_classifier",
"binarization",
"page",
"region"
]
if not self.extract_only_images:
loadable.append("textline")
if self.light_version:
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light)
self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np)
loadable.append("region_1_2")
else:
self.model_region = self.our_load_model(self.model_region_dir_p_ens)
self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new)
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
loadable.append("region_p2")
# if self.allow_enhancement:?
loadable.append("enhancement")
if self.full_layout:
loadable.extend(["region_fl_np",
"region_fl"])
if self.reading_order_machine_based:
self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
if self.ocr and self.tr:
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
loadable.append("reading_order")
if self.tables:
loadable.append("table")
self.models = {name: self.our_load_model(self.model_versions[name], basedir)
for name in loadable
}
if self.ocr:
ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"])
if self.tr:
self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
self.device = torch.device("cuda:0")
@ -386,54 +423,29 @@ class Eynollah:
self.device = torch.device("cpu")
#self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
elif self.ocr and not self.tr:
model_ocr = load_model(self.model_ocr_dir , compile=False)
else:
ocr_model = load_model(ocr_model_dir, compile=False)
self.models["ocr"] = tf.keras.models.Model(
ocr_model.get_layer(name = "image").input,
ocr_model.get_layer(name = "dense2").output)
self.prediction_model = tf.keras.models.Model(
model_ocr.get_layer(name = "image").input,
model_ocr.get_layer(name = "dense2").output)
if not batch_size_ocr:
self.b_s_ocr = 8
else:
self.b_s_ocr = int(batch_size_ocr)
with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file:
with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file:
characters = json.load(config_file)
AUTOTUNE = tf.data.AUTOTUNE
# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
# Mapping integers back to original characters.
self.num_to_char = StringLookup(
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
if self.tables:
self.model_table = self.our_load_model(self.model_table_dir)
self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
def __del__(self):
if hasattr(self, 'executor') and getattr(self, 'executor'):
self.executor.shutdown()
for model_name in ['model_page',
'model_classifier',
'model_bin',
'model_enhancement',
'model_region',
'model_region_1_2',
'model_region_p2',
'model_region_fl_np',
'model_region_fl',
'model_textline',
'model_reading_order',
'model_table',
'model_ocr',
'processor']:
if hasattr(self, model_name) and getattr(self, model_name):
delattr(self, model_name)
self.executor = None
if hasattr(self, 'models') and getattr(self, 'models'):
for model_name in list(self.models):
if self.models[model_name]:
del self.models[model_name]
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
ret = {}
@ -480,8 +492,8 @@ class Eynollah:
def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement")
img_height_model = self.model_enhancement.layers[-1].output_shape[1]
img_width_model = self.model_enhancement.layers[-1].output_shape[2]
img_height_model = self.models["enhancement"].layers[-1].output_shape[1]
img_width_model = self.models["enhancement"].layers[-1].output_shape[2]
if img.shape[0] < img_height_model:
img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
if img.shape[1] < img_width_model:
@ -522,7 +534,7 @@ class Eynollah:
index_y_d = img_h - img_height_model
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
label_p_pred = self.models["enhancement"].predict(img_patch, verbose=0)
seg = label_p_pred[0, :, :, :] * 255
if i == 0 and j == 0:
@ -697,7 +709,7 @@ class Eynollah:
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
self.logger.info("Found %s columns (%s)", num_col, label_p_pred)
@ -715,7 +727,7 @@ class Eynollah:
self.logger.info("Detected %s DPI", dpi)
if self.input_binary:
img = self.imread()
prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
prediction_bin = self.do_prediction(True, img, self.models["binarization"], n_batch_inference=5)
prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
img= np.copy(prediction_bin)
@ -755,7 +767,7 @@ class Eynollah:
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
@ -776,7 +788,7 @@ class Eynollah:
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
if num_col > self.num_col_upper:
@ -1628,7 +1640,7 @@ class Eynollah:
cont_page = []
if not self.ignore_page_extraction:
img = np.copy(self.image)#cv2.GaussianBlur(self.image, (5, 5), 0)
img_page_prediction = self.do_prediction(False, img, self.model_page)
img_page_prediction = self.do_prediction(False, img, self.models["page"])
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
##thresh = cv2.dilate(thresh, KERNEL, iterations=3)
@ -1676,7 +1688,7 @@ class Eynollah:
else:
img = self.imread()
img = cv2.GaussianBlur(img, (5, 5), 0)
img_page_prediction = self.do_prediction(False, img, self.model_page)
img_page_prediction = self.do_prediction(False, img, self.models["page"])
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
@ -1702,7 +1714,7 @@ class Eynollah:
self.logger.debug("enter extract_text_regions")
img_height_h = img.shape[0]
img_width_h = img.shape[1]
model_region = self.model_region_fl if patches else self.model_region_fl_np
model_region = self.models["region_fl"] if patches else self.models["region_fl_np"]
if self.light_version:
thresholding_for_fl_light_version = True
@ -1737,7 +1749,7 @@ class Eynollah:
self.logger.debug("enter extract_text_regions")
img_height_h = img.shape[0]
img_width_h = img.shape[1]
model_region = self.model_region_fl if patches else self.model_region_fl_np
model_region = self.models["region_fl"] if patches else self.models["region_fl_np"]
if not patches:
img = otsu_copy_binary(img)
@ -1958,14 +1970,14 @@ class Eynollah:
img_w = img_org.shape[1]
img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w))
prediction_textline = self.do_prediction(use_patches, img, self.model_textline,
prediction_textline = self.do_prediction(use_patches, img, self.models["textline"],
marginal_of_patch_percent=0.15,
n_batch_inference=3,
thresholding_for_artificial_class_in_light_version=self.textline_light,
threshold_art_class_textline=self.threshold_art_class_textline)
#if not self.textline_light:
#if num_col_classifier==1:
#prediction_textline_nopatch = self.do_prediction(False, img, self.model_textline)
#prediction_textline_nopatch = self.do_prediction(False, img, self.models["textline"])
#prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0
prediction_textline = resize_image(prediction_textline, img_h, img_w)
@ -2036,7 +2048,7 @@ class Eynollah:
#cv2.imwrite('prediction_textline2.png', prediction_textline[:,:,0])
prediction_textline_longshot = self.do_prediction(False, img, self.model_textline)
prediction_textline_longshot = self.do_prediction(False, img, self.models["textline"])
prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w)
@ -2069,7 +2081,7 @@ class Eynollah:
img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new)
img_resized = resize_image(img,img_h_new, img_w_new )
prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.model_region)
prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.models["region"])
prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h )
image_page, page_coord, cont_page = self.extract_page()
@ -2185,7 +2197,7 @@ class Eynollah:
#if self.input_binary:
#img_bin = np.copy(img_resized)
###if (not self.input_binary and self.full_layout) or (not self.input_binary and num_col_classifier >= 30):
###prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
###prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5)
####print("inside bin ", time.time()-t_bin)
###prediction_bin=prediction_bin[:,:,0]
@ -2200,7 +2212,7 @@ class Eynollah:
###else:
###img_bin = np.copy(img_resized)
if (self.ocr and self.tr) and not self.input_binary:
prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5)
prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
prediction_bin = prediction_bin.astype(np.uint16)
@ -2232,14 +2244,14 @@ class Eynollah:
self.logger.debug("resized to %dx%d for %d cols",
img_resized.shape[1], img_resized.shape[0], num_col_classifier)
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
True, img_resized, self.model_region_1_2, n_batch_inference=1,
True, img_resized, self.models["region_1_2"], n_batch_inference=1,
thresholding_for_some_classes_in_light_version=True,
threshold_art_class_layout=self.threshold_art_class_layout)
else:
prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3))
confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept(
False, self.image_page_org_size, self.model_region_1_2, n_batch_inference=1,
False, self.image_page_org_size, self.models["region_1_2"], n_batch_inference=1,
thresholding_for_artificial_class_in_light_version=True,
threshold_art_class_layout=self.threshold_art_class_layout)
ys = slice(*self.page_coord[0:2])
@ -2253,10 +2265,10 @@ class Eynollah:
self.logger.debug("resized to %dx%d (new_h=%d) for %d cols",
img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
True, img_resized, self.model_region_1_2, n_batch_inference=2,
True, img_resized, self.models["region_1_2"], n_batch_inference=2,
thresholding_for_some_classes_in_light_version=True,
threshold_art_class_layout=self.threshold_art_class_layout)
###prediction_regions_org = self.do_prediction(True, img_bin, self.model_region,
###prediction_regions_org = self.do_prediction(True, img_bin, self.models["region"],
###n_batch_inference=3,
###thresholding_for_some_classes_in_light_version=True)
#print("inside 3 ", time.time()-t_in)
@ -2336,7 +2348,7 @@ class Eynollah:
ratio_x=1
img = resize_image(img_org, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
prediction_regions_org_y = self.do_prediction(True, img, self.model_region)
prediction_regions_org_y = self.do_prediction(True, img, self.models["region"])
prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h )
#plt.imshow(prediction_regions_org_y[:,:,0])
@ -2351,7 +2363,7 @@ class Eynollah:
_, _ = find_num_col(img_only_regions, num_col_classifier, self.tables, multiplier=6.0)
img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]*(1.2 if is_image_enhanced else 1)))
prediction_regions_org = self.do_prediction(True, img, self.model_region)
prediction_regions_org = self.do_prediction(True, img, self.models["region"])
prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
prediction_regions_org=prediction_regions_org[:,:,0]
@ -2359,7 +2371,7 @@ class Eynollah:
img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]))
prediction_regions_org2 = self.do_prediction(True, img, self.model_region_p2, marginal_of_patch_percent=0.2)
prediction_regions_org2 = self.do_prediction(True, img, self.models["region_p2"], marginal_of_patch_percent=0.2)
prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h )
mask_zeros2 = (prediction_regions_org2[:,:,0] == 0)
@ -2383,7 +2395,7 @@ class Eynollah:
if self.input_binary:
prediction_bin = np.copy(img_org)
else:
prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5)
prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
@ -2393,7 +2405,7 @@ class Eynollah:
img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
prediction_regions_org = self.do_prediction(True, img, self.model_region)
prediction_regions_org = self.do_prediction(True, img, self.models["region"])
prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
prediction_regions_org=prediction_regions_org[:,:,0]
@ -2420,7 +2432,7 @@ class Eynollah:
except:
if self.input_binary:
prediction_bin = np.copy(img_org)
prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5)
prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
@ -2431,14 +2443,14 @@ class Eynollah:
img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
prediction_regions_org = self.do_prediction(True, img, self.model_region)
prediction_regions_org = self.do_prediction(True, img, self.models["region"])
prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
prediction_regions_org=prediction_regions_org[:,:,0]
#mask_lines_only=(prediction_regions_org[:,:]==3)*1
#img = resize_image(img_org, int(img_org.shape[0]*1), int(img_org.shape[1]*1))
#prediction_regions_org = self.do_prediction(True, img, self.model_region)
#prediction_regions_org = self.do_prediction(True, img, self.models["region"])
#prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
#prediction_regions_org = prediction_regions_org[:,:,0]
#prediction_regions_org[(prediction_regions_org[:,:] == 1) & (mask_zeros_y[:,:] == 1)]=0
@ -2809,13 +2821,13 @@ class Eynollah:
img_width_h = img_org.shape[1]
patches = False
if self.light_version:
prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_table)
prediction_table, _ = self.do_prediction_new_concept(patches, img, self.models["table"])
prediction_table = prediction_table.astype(np.int16)
return prediction_table[:,:,0]
else:
if num_col_classifier < 4 and num_col_classifier > 2:
prediction_table = self.do_prediction(patches, img, self.model_table)
pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
prediction_table = self.do_prediction(patches, img, self.models["table"])
pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"])
pre_updown = cv2.flip(pre_updown, -1)
prediction_table[:,:,0][pre_updown[:,:,0]==1]=1
@ -2834,8 +2846,8 @@ class Eynollah:
xs = slice(w_start, w_start + img.shape[1])
img_new[ys, xs] = img
prediction_ext = self.do_prediction(patches, img_new, self.model_table)
pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
prediction_ext = self.do_prediction(patches, img_new, self.models["table"])
pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"])
pre_updown = cv2.flip(pre_updown, -1)
prediction_table = prediction_ext[ys, xs]
@ -2856,8 +2868,8 @@ class Eynollah:
xs = slice(w_start, w_start + img.shape[1])
img_new[ys, xs] = img
prediction_ext = self.do_prediction(patches, img_new, self.model_table)
pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
prediction_ext = self.do_prediction(patches, img_new, self.models["table"])
pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"])
pre_updown = cv2.flip(pre_updown, -1)
prediction_table = prediction_ext[ys, xs]
@ -2869,10 +2881,10 @@ class Eynollah:
prediction_table = np.zeros(img.shape)
img_w_half = img.shape[1] // 2
pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.model_table)
pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.model_table)
pre_full = self.do_prediction(patches, img[:,:,:], self.model_table)
pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.models["table"])
pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.models["table"])
pre_full = self.do_prediction(patches, img[:,:,:], self.models["table"])
pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"])
pre_updown = cv2.flip(pre_updown, -1)
prediction_table_full_erode = cv2.erode(pre_full[:,:,0], KERNEL, iterations=4)
@ -3474,18 +3486,6 @@ class Eynollah:
regions_without_separators_d, regions_fully, regions_without_separators,
polygons_of_marginals, contours_tables)
@staticmethod
def our_load_model(model_file):
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except:
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
def do_order_of_regions_with_model(self, contours_only_text_parent, contours_only_text_parent_h, text_regions_p):
height1 =672#448
@ -3676,7 +3676,7 @@ class Eynollah:
tot_counter += 1
batch.append(j)
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
y_pr = self.model_reading_order.predict(input_1 , verbose=0)
y_pr = self.models["reading_order"].predict(input_1 , verbose=0)
for jb, j in enumerate(batch):
if y_pr[jb][0]>=0.5:
post_list.append(j)
@ -4259,7 +4259,7 @@ class Eynollah:
gc.collect()
ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)),
self.prediction_model, self.b_s_ocr, self.num_to_char, textline_light=True)
self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True)
else:
ocr_all_textlines = None
@ -4768,27 +4768,27 @@ class Eynollah:
if len(all_found_textline_polygons):
ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons, all_box_coord,
self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
if len(all_found_textline_polygons_marginals_left):
ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left,
self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
if len(all_found_textline_polygons_marginals_right):
ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right,
self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
if self.full_layout and len(all_found_textline_polygons):
ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_h, all_box_coord_h,
self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
if self.full_layout and len(polygons_of_drop_capitals):
ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines(
image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)),
self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
else:
if self.light_version:
@ -4800,7 +4800,7 @@ class Eynollah:
gc.collect()
torch.cuda.empty_cache()
self.model_ocr.to(self.device)
self.models["ocr"].to(self.device)
ind_tot = 0
#cv2.imwrite('./img_out.png', image_page)
@ -4837,7 +4837,7 @@ class Eynollah:
img_croped = img_poly_on_img[y:y+h, x:x+w, :]
#cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
text_ocr = self.return_ocr_of_textline_without_common_section(
img_croped, self.model_ocr, self.processor, self.device, w, h2w_ratio, ind_tot)
img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot)
ocr_textline_in_textregion.append(text_ocr)
ind_tot = ind_tot +1
ocr_all_textlines.append(ocr_textline_in_textregion)