layout: refactor model setup, allow loading custom versions

- simplify definition of (defaults for) model versions
- unify loading of loadable models (depending on mode)
- use `self.models` dict instead of `self.model_*` attributes
- add `model_versions` kwarg / `--model_version` CLI option
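
For example, to load a custom OCR model instead of the default (a minimal
sketch, not taken from this commit: the import path, model directory and
version string are placeholders; "ocr" is one of the category keys defined
in setup_models()):

    # Override one model category when constructing Eynollah.
    # model_versions takes (category, version) pairs; any category not
    # overridden keeps its built-in default version.
    from eynollah.eynollah import Eynollah  # import path assumed

    eynollah = Eynollah(
        "/path/to/models",                                # dir_models
        model_versions=[("ocr", "my_custom_ocr_version")],
        do_ocr=True,
    )

On the CLI, the same override uses the repeatable pair-valued option, roughly:
`eynollah layout -m /path/to/models -mv ocr my_custom_ocr_version ...`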
Robert Sachunsky 2025-10-10 03:18:09 +02:00
parent 374818de11
commit 4e9a1618c3
3 changed files with 191 additions and 182 deletions

View file

@@ -25,6 +25,7 @@ f458e3e
   (so CUDA memory gets freed between tests if running on GPU)
 Added:
+* :fire: `layout` CLI: new option `--model_version` to override default choices
 * test coverage for OCR options in `layout`
 * test coverage for table detection in `layout`
 * CI linting with ruff

View file

@@ -202,6 +202,13 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
     type=click.Path(exists=True, file_okay=False),
     required=True,
 )
+@click.option(
+    "--model_version",
+    "-mv",
+    help="override default versions of model categories",
+    type=(str, str),
+    multiple=True,
+)
 @click.option(
     "--save_images",
     "-si",
@@ -373,7 +380,7 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
     help="Setup a basic console logger",
 )
-def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
+def layout(image, out, overwrite, dir_in, model, model_version, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
     if setup_logging:
         console_handler = logging.StreamHandler(sys.stdout)
         console_handler.setLevel(logging.INFO)
@@ -404,6 +411,7 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
     assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
     eynollah = Eynollah(
         model,
+        model_versions=model_version,
         extract_only_images=extract_only_images,
         enable_plotting=enable_plotting,
         allow_enhancement=allow_enhancement,

View file

@@ -19,7 +19,7 @@ import math
 import os
 import sys
 import time
-from typing import Optional
+from typing import Dict, List, Optional, Tuple
 import atexit
 import warnings
 from functools import partial
@@ -180,7 +180,6 @@ class Patches(layers.Layer):
         })
         return config
-
 class PatchEncoder(layers.Layer):
     def __init__(self, **kwargs):
         super(PatchEncoder, self).__init__()
@@ -208,6 +207,7 @@ class Eynollah:
     def __init__(
         self,
         dir_models : str,
+        model_versions: List[Tuple[str, str]] = [],
         extract_only_images : bool =False,
         enable_plotting : bool = False,
         allow_enhancement : bool = False,
@@ -254,6 +254,10 @@ class Eynollah:
         self.skip_layout_and_reading_order = skip_layout_and_reading_order
         self.ocr = do_ocr
         self.tr = transformer_ocr
+        if not batch_size_ocr:
+            self.b_s_ocr = 8
+        else:
+            self.b_s_ocr = int(batch_size_ocr)
         if num_col_upper:
             self.num_col_upper = int(num_col_upper)
         else:
@@ -276,69 +280,6 @@
         else:
             self.threshold_art_class_textline = 0.1
 
-        self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
-        self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425"
-        self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
-        self.model_region_dir_p = dir_models + "/eynollah-main-regions-aug-scaling_20210425"
-        self.model_region_dir_p2 = dir_models + "/eynollah-main-regions-aug-rotation_20210425"
-        #"/modelens_full_lay_1_3_031124"
-        #"/modelens_full_lay_13__3_19_241024"
-        #"/model_full_lay_13_241024"
-        #"/modelens_full_lay_13_17_231024"
-        #"/modelens_full_lay_1_2_221024"
-        #"/eynollah-full-regions-1column_20210425"
-        self.model_region_dir_fully_np = dir_models + "/modelens_full_lay_1__4_3_091124"
-        #self.model_region_dir_fully = dir_models + "/eynollah-full-regions-3+column_20210425"
-        self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915"
-        self.model_region_dir_p_ens = dir_models + "/eynollah-main-regions-ensembled_20210425"
-        self.model_region_dir_p_ens_light = dir_models + "/eynollah-main-regions_20220314"
-        self.model_region_dir_p_ens_light_only_images_extraction = (dir_models +
-            "/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18"
-        )
-        self.model_reading_order_dir = (dir_models +
-            "/model_eynollah_reading_order_20250824"
-            #"/model_mb_ro_aug_ens_11"
-            #"/model_step_3200000_mb_ro"
-            #"/model_ens_reading_order_machine_based"
-            #"/model_mb_ro_aug_ens_8"
-            #"/model_ens_reading_order_machine_based"
-        )
-        #"/modelens_12sp_elay_0_3_4__3_6_n"
-        #"/modelens_earlylayout_12spaltige_2_3_5_6_7_8"
-        #"/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
-        #"/modelens_1_2_4_5_early_lay_1_2_spaltige"
-        #"/model_3_eraly_layout_no_patches_1_2_spaltige"
-        self.model_region_dir_p_1_2_sp_np = dir_models + "/modelens_e_l_all_sp_0_1_2_3_4_171024"
-        ##self.model_region_dir_fully_new = dir_models + "/model_2_full_layout_new_trans"
-        #"/modelens_full_lay_1_3_031124"
-        #"/modelens_full_lay_13__3_19_241024"
-        #"/model_full_lay_13_241024"
-        #"/modelens_full_lay_13_17_231024"
-        #"/modelens_full_lay_1_2_221024"
-        #"/modelens_full_layout_24_till_28"
-        #"/model_2_full_layout_new_trans"
-        self.model_region_dir_fully = dir_models + "/modelens_full_lay_1__4_3_091124"
-        if self.textline_light:
-            #"/modelens_textline_1_4_16092024"
-            #"/model_textline_ens_3_4_5_6_artificial"
-            #"/modelens_textline_1_3_4_20240915"
-            #"/model_textline_ens_3_4_5_6_artificial"
-            #"/modelens_textline_9_12_13_14_15"
-            #"/eynollah-textline_light_20210425"
-            self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
-        else:
-            #"/eynollah-textline_20210425"
-            self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
-        if self.ocr and self.tr:
-            self.model_ocr_dir = dir_models + "/model_eynollah_ocr_trocr_20250919"
-        elif self.ocr and not self.tr:
-            self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250930"
-        if self.tables:
-            if self.light_version:
-                self.model_table_dir = dir_models + "/modelens_table_0t4_201124"
-            else:
-                self.model_table_dir = dir_models + "/eynollah-tables_20210319"
 
         t_start = time.time()
         # #gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
@@ -356,28 +297,124 @@
             self.logger.warning("no GPU device available")
 
         self.logger.info("Loading models...")
-        self.model_page = self.our_load_model(self.model_page_dir)
-        self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
-        self.model_bin = self.our_load_model(self.model_dir_of_binarization)
-        if self.extract_only_images:
-            self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
-        else:
-            self.model_textline = self.our_load_model(self.model_textline_dir)
-            if self.light_version:
-                self.model_region = self.our_load_model(self.model_region_dir_p_ens_light)
-                self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np)
-            else:
-                self.model_region = self.our_load_model(self.model_region_dir_p_ens)
-                self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
-                self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
-                ###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new)
-                self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
-                self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
-            if self.reading_order_machine_based:
-                self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
-            if self.ocr and self.tr:
-                self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
+        self.setup_models(dir_models, model_versions)
+        self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
+
+    @staticmethod
+    def our_load_model(model_file, basedir=""):
+        if basedir:
+            model_file = os.path.join(basedir, model_file)
+        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
+            # prefer SavedModel over HDF5 format if it exists
+            model_file = model_file[:-3]
+        try:
+            model = load_model(model_file, compile=False)
+        except:
+            model = load_model(model_file, compile=False, custom_objects={
+                "PatchEncoder": PatchEncoder, "Patches": Patches})
+        return model
+
+    def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []):
+        self.model_versions = {
+            "enhancement": "eynollah-enhancement_20210425",
+            "binarization": "eynollah-binarization_20210425",
+            "col_classifier": "eynollah-column-classifier_20210425",
+            "page": "model_eynollah_page_extraction_20250915",
+            #?: "eynollah-main-regions-aug-scaling_20210425",
+            "region": ( # early layout
+                "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else
+                "eynollah-main-regions_20220314" if self.light_version else
+                "eynollah-main-regions-ensembled_20210425"),
+            "region_p2": ( # early layout, non-light, 2nd part
+                "eynollah-main-regions-aug-rotation_20210425"),
+            "region_1_2": ( # early layout, light, 1-or-2-column
+                #"modelens_12sp_elay_0_3_4__3_6_n"
+                #"modelens_earlylayout_12spaltige_2_3_5_6_7_8"
+                #"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
+                #"modelens_1_2_4_5_early_lay_1_2_spaltige"
+                #"model_3_eraly_layout_no_patches_1_2_spaltige"
+                "modelens_e_l_all_sp_0_1_2_3_4_171024"),
+            "region_fl_np": ( # full layout / no patches
+                #"modelens_full_lay_1_3_031124"
+                #"modelens_full_lay_13__3_19_241024"
+                #"model_full_lay_13_241024"
+                #"modelens_full_lay_13_17_231024"
+                #"modelens_full_lay_1_2_221024"
+                #"eynollah-full-regions-1column_20210425"
+                "modelens_full_lay_1__4_3_091124"),
+            "region_fl": ( # full layout / with patches
+                #"eynollah-full-regions-3+column_20210425"
+                ##"model_2_full_layout_new_trans"
+                #"modelens_full_lay_1_3_031124"
+                #"modelens_full_lay_13__3_19_241024"
+                #"model_full_lay_13_241024"
+                #"modelens_full_lay_13_17_231024"
+                #"modelens_full_lay_1_2_221024"
+                #"modelens_full_layout_24_till_28"
+                #"model_2_full_layout_new_trans"
+                "modelens_full_lay_1__4_3_091124"),
+            "reading_order": (
+                #"model_mb_ro_aug_ens_11"
+                #"model_step_3200000_mb_ro"
+                #"model_ens_reading_order_machine_based"
+                #"model_mb_ro_aug_ens_8"
+                #"model_ens_reading_order_machine_based"
+                "model_eynollah_reading_order_20250824"),
+            "textline": (
+                #"modelens_textline_1_4_16092024"
+                #"model_textline_ens_3_4_5_6_artificial"
+                #"modelens_textline_1_3_4_20240915"
+                #"model_textline_ens_3_4_5_6_artificial"
+                #"modelens_textline_9_12_13_14_15"
+                #"eynollah-textline_light_20210425"
+                "modelens_textline_0_1__2_4_16092024" if self.textline_light else
+                #"eynollah-textline_20210425"
+                "modelens_textline_0_1__2_4_16092024"),
+            "table": (
+                None if not self.tables else
+                "modelens_table_0t4_201124" if self.light_version else
+                "eynollah-tables_20210319"),
+            "ocr": (
+                None if not self.ocr else
+                "model_eynollah_ocr_trocr_20250919" if self.tr else
+                "model_eynollah_ocr_cnnrnn_20250930")
+        }
+        # override defaults from CLI
+        for key, val in model_versions:
+            assert key in self.model_versions, "unknown model category '%s'" % key
+            self.logger.warning("overriding default model %s version %s to %s",
+                                key, self.model_versions[key], val)
+            self.model_versions[key] = val
+        # load models, depending on modes
+        loadable = [
+            "col_classifier",
+            "binarization",
+            "page",
+            "region"
+        ]
+        if not self.extract_only_images:
+            loadable.append("textline")
+            if self.light_version:
+                loadable.append("region_1_2")
+            else:
+                loadable.append("region_p2")
+                # if self.allow_enhancement:?
+                loadable.append("enhancement")
+            if self.full_layout:
+                loadable.extend(["region_fl_np",
+                                 "region_fl"])
+            if self.reading_order_machine_based:
+                loadable.append("reading_order")
+            if self.tables:
+                loadable.append("table")
+        self.models = {name: self.our_load_model(self.model_versions[name], basedir)
+                       for name in loadable
+        }
+        if self.ocr:
+            ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"])
+            if self.tr:
+                self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
                 if torch.cuda.is_available():
                     self.logger.info("Using GPU acceleration")
                     self.device = torch.device("cuda:0")
@@ -386,54 +423,29 @@
                     self.device = torch.device("cpu")
                 #self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
                 self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
-        elif self.ocr and not self.tr:
-            model_ocr = load_model(self.model_ocr_dir , compile=False)
-            self.prediction_model = tf.keras.models.Model(
-                model_ocr.get_layer(name = "image").input,
-                model_ocr.get_layer(name = "dense2").output)
-            if not batch_size_ocr:
-                self.b_s_ocr = 8
-            else:
-                self.b_s_ocr = int(batch_size_ocr)
-            with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file:
+            else:
+                ocr_model = load_model(ocr_model_dir, compile=False)
+                self.models["ocr"] = tf.keras.models.Model(
+                    ocr_model.get_layer(name = "image").input,
+                    ocr_model.get_layer(name = "dense2").output)
+                with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file:
                     characters = json.load(config_file)
-            AUTOTUNE = tf.data.AUTOTUNE
                 # Mapping characters to integers.
                 char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
                 # Mapping integers back to original characters.
                 self.num_to_char = StringLookup(
                     vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
                 )
-        if self.tables:
-            self.model_table = self.our_load_model(self.model_table_dir)
-
-        self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
 
     def __del__(self):
         if hasattr(self, 'executor') and getattr(self, 'executor'):
             self.executor.shutdown()
-        for model_name in ['model_page',
-                           'model_classifier',
-                           'model_bin',
-                           'model_enhancement',
-                           'model_region',
-                           'model_region_1_2',
-                           'model_region_p2',
-                           'model_region_fl_np',
-                           'model_region_fl',
-                           'model_textline',
-                           'model_reading_order',
-                           'model_table',
-                           'model_ocr',
-                           'processor']:
-            if hasattr(self, model_name) and getattr(self, model_name):
-                delattr(self, model_name)
+            self.executor = None
+        if hasattr(self, 'models') and getattr(self, 'models'):
+            for model_name in list(self.models):
+                if self.models[model_name]:
+                    del self.models[model_name]
 
     def cache_images(self, image_filename=None, image_pil=None, dpi=None):
         ret = {}
@@ -480,8 +492,8 @@
     def predict_enhancement(self, img):
         self.logger.debug("enter predict_enhancement")
 
-        img_height_model = self.model_enhancement.layers[-1].output_shape[1]
-        img_width_model = self.model_enhancement.layers[-1].output_shape[2]
+        img_height_model = self.models["enhancement"].layers[-1].output_shape[1]
+        img_width_model = self.models["enhancement"].layers[-1].output_shape[2]
         if img.shape[0] < img_height_model:
             img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
         if img.shape[1] < img_width_model:
@@ -522,7 +534,7 @@
                     index_y_d = img_h - img_height_model
 
                 img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
-                label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
+                label_p_pred = self.models["enhancement"].predict(img_patch, verbose=0)
                 seg = label_p_pred[0, :, :, :] * 255
 
                 if i == 0 and j == 0:
@@ -697,7 +709,7 @@
         img_in[0, :, :, 1] = img_1ch[:, :]
         img_in[0, :, :, 2] = img_1ch[:, :]
 
-        label_p_pred = self.model_classifier.predict(img_in, verbose=0)
+        label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
         num_col = np.argmax(label_p_pred[0]) + 1
 
         self.logger.info("Found %s columns (%s)", num_col, label_p_pred)
@@ -715,7 +727,7 @@
         self.logger.info("Detected %s DPI", dpi)
         if self.input_binary:
             img = self.imread()
-            prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
+            prediction_bin = self.do_prediction(True, img, self.models["binarization"], n_batch_inference=5)
            prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
            prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
            img= np.copy(prediction_bin)
@@ -755,7 +767,7 @@
                 img_in[0, :, :, 1] = img_1ch[:, :]
                 img_in[0, :, :, 2] = img_1ch[:, :]
 
-                label_p_pred = self.model_classifier.predict(img_in, verbose=0)
+                label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
                 num_col = np.argmax(label_p_pred[0]) + 1
 
             elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
@@ -776,7 +788,7 @@
                 img_in[0, :, :, 1] = img_1ch[:, :]
                 img_in[0, :, :, 2] = img_1ch[:, :]
 
-                label_p_pred = self.model_classifier.predict(img_in, verbose=0)
+                label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
                 num_col = np.argmax(label_p_pred[0]) + 1
 
                 if num_col > self.num_col_upper:
@@ -1628,7 +1640,7 @@
         cont_page = []
         if not self.ignore_page_extraction:
             img = np.copy(self.image)#cv2.GaussianBlur(self.image, (5, 5), 0)
-            img_page_prediction = self.do_prediction(False, img, self.model_page)
+            img_page_prediction = self.do_prediction(False, img, self.models["page"])
             imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
             _, thresh = cv2.threshold(imgray, 0, 255, 0)
             ##thresh = cv2.dilate(thresh, KERNEL, iterations=3)
@@ -1676,7 +1688,7 @@
         else:
             img = self.imread()
         img = cv2.GaussianBlur(img, (5, 5), 0)
-        img_page_prediction = self.do_prediction(False, img, self.model_page)
+        img_page_prediction = self.do_prediction(False, img, self.models["page"])
         imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
         _, thresh = cv2.threshold(imgray, 0, 255, 0)
@@ -1702,7 +1714,7 @@
         self.logger.debug("enter extract_text_regions")
         img_height_h = img.shape[0]
         img_width_h = img.shape[1]
-        model_region = self.model_region_fl if patches else self.model_region_fl_np
+        model_region = self.models["region_fl"] if patches else self.models["region_fl_np"]
 
         if self.light_version:
             thresholding_for_fl_light_version = True
@@ -1737,7 +1749,7 @@
         self.logger.debug("enter extract_text_regions")
         img_height_h = img.shape[0]
         img_width_h = img.shape[1]
-        model_region = self.model_region_fl if patches else self.model_region_fl_np
+        model_region = self.models["region_fl"] if patches else self.models["region_fl_np"]
 
         if not patches:
             img = otsu_copy_binary(img)
@@ -1958,14 +1970,14 @@
         img_w = img_org.shape[1]
         img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w))
 
-        prediction_textline = self.do_prediction(use_patches, img, self.model_textline,
+        prediction_textline = self.do_prediction(use_patches, img, self.models["textline"],
                                                  marginal_of_patch_percent=0.15,
                                                  n_batch_inference=3,
                                                  thresholding_for_artificial_class_in_light_version=self.textline_light,
                                                  threshold_art_class_textline=self.threshold_art_class_textline)
         #if not self.textline_light:
             #if num_col_classifier==1:
-                #prediction_textline_nopatch = self.do_prediction(False, img, self.model_textline)
+                #prediction_textline_nopatch = self.do_prediction(False, img, self.models["textline"])
                 #prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0
 
         prediction_textline = resize_image(prediction_textline, img_h, img_w)
@@ -2036,7 +2048,7 @@
         #cv2.imwrite('prediction_textline2.png', prediction_textline[:,:,0])
 
-        prediction_textline_longshot = self.do_prediction(False, img, self.model_textline)
+        prediction_textline_longshot = self.do_prediction(False, img, self.models["textline"])
         prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w)
@@ -2069,7 +2081,7 @@
             img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new)
             img_resized = resize_image(img,img_h_new, img_w_new )
 
-            prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.model_region)
+            prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.models["region"])
             prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h )
 
             image_page, page_coord, cont_page = self.extract_page()
@@ -2185,7 +2197,7 @@
             #if self.input_binary:
                 #img_bin = np.copy(img_resized)
             ###if (not self.input_binary and self.full_layout) or (not self.input_binary and num_col_classifier >= 30):
-                ###prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
+                ###prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5)
                 ####print("inside bin ", time.time()-t_bin)
                 ###prediction_bin=prediction_bin[:,:,0]
@@ -2200,7 +2212,7 @@
             ###else:
                 ###img_bin = np.copy(img_resized)
 
             if (self.ocr and self.tr) and not self.input_binary:
-                prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
+                prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5)
                 prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
                 prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
                 prediction_bin = prediction_bin.astype(np.uint16)
@@ -2232,14 +2244,14 @@
                 self.logger.debug("resized to %dx%d for %d cols",
                                   img_resized.shape[1], img_resized.shape[0], num_col_classifier)
                 prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
-                    True, img_resized, self.model_region_1_2, n_batch_inference=1,
+                    True, img_resized, self.models["region_1_2"], n_batch_inference=1,
                     thresholding_for_some_classes_in_light_version=True,
                     threshold_art_class_layout=self.threshold_art_class_layout)
             else:
                 prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3))
                 confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
                 prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept(
-                    False, self.image_page_org_size, self.model_region_1_2, n_batch_inference=1,
+                    False, self.image_page_org_size, self.models["region_1_2"], n_batch_inference=1,
                     thresholding_for_artificial_class_in_light_version=True,
                     threshold_art_class_layout=self.threshold_art_class_layout)
                 ys = slice(*self.page_coord[0:2])
@@ -2253,10 +2265,10 @@
                 self.logger.debug("resized to %dx%d (new_h=%d) for %d cols",
                                   img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
                 prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
-                    True, img_resized, self.model_region_1_2, n_batch_inference=2,
+                    True, img_resized, self.models["region_1_2"], n_batch_inference=2,
                     thresholding_for_some_classes_in_light_version=True,
                     threshold_art_class_layout=self.threshold_art_class_layout)
-            ###prediction_regions_org = self.do_prediction(True, img_bin, self.model_region,
+            ###prediction_regions_org = self.do_prediction(True, img_bin, self.models["region"],
                 ###n_batch_inference=3,
                 ###thresholding_for_some_classes_in_light_version=True)
             #print("inside 3 ", time.time()-t_in)
@@ -2336,7 +2348,7 @@
             ratio_x=1
 
         img = resize_image(img_org, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
-        prediction_regions_org_y = self.do_prediction(True, img, self.model_region)
+        prediction_regions_org_y = self.do_prediction(True, img, self.models["region"])
         prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h )
 
         #plt.imshow(prediction_regions_org_y[:,:,0])
@@ -2351,7 +2363,7 @@
             _, _ = find_num_col(img_only_regions, num_col_classifier, self.tables, multiplier=6.0)
 
             img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]*(1.2 if is_image_enhanced else 1)))
-            prediction_regions_org = self.do_prediction(True, img, self.model_region)
+            prediction_regions_org = self.do_prediction(True, img, self.models["region"])
             prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
 
             prediction_regions_org=prediction_regions_org[:,:,0]
@@ -2359,7 +2371,7 @@
             img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]))
 
-            prediction_regions_org2 = self.do_prediction(True, img, self.model_region_p2, marginal_of_patch_percent=0.2)
+            prediction_regions_org2 = self.do_prediction(True, img, self.models["region_p2"], marginal_of_patch_percent=0.2)
             prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h )
 
             mask_zeros2 = (prediction_regions_org2[:,:,0] == 0)
@@ -2383,7 +2395,7 @@
                 if self.input_binary:
                     prediction_bin = np.copy(img_org)
                 else:
-                    prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
+                    prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5)
                     prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
                     prediction_bin = 255 * (prediction_bin[:,:,0]==0)
                     prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
@@ -2393,7 +2405,7 @@
                 img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
 
-                prediction_regions_org = self.do_prediction(True, img, self.model_region)
+                prediction_regions_org = self.do_prediction(True, img, self.models["region"])
                 prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
                 prediction_regions_org=prediction_regions_org[:,:,0]
@@ -2420,7 +2432,7 @@
         except:
             if self.input_binary:
                 prediction_bin = np.copy(img_org)
-                prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
+                prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5)
                 prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
                 prediction_bin = 255 * (prediction_bin[:,:,0]==0)
                 prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
@@ -2431,14 +2443,14 @@
             img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
 
-            prediction_regions_org = self.do_prediction(True, img, self.model_region)
+            prediction_regions_org = self.do_prediction(True, img, self.models["region"])
             prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
             prediction_regions_org=prediction_regions_org[:,:,0]
 
             #mask_lines_only=(prediction_regions_org[:,:]==3)*1
             #img = resize_image(img_org, int(img_org.shape[0]*1), int(img_org.shape[1]*1))
-            #prediction_regions_org = self.do_prediction(True, img, self.model_region)
+            #prediction_regions_org = self.do_prediction(True, img, self.models["region"])
             #prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
             #prediction_regions_org = prediction_regions_org[:,:,0]
             #prediction_regions_org[(prediction_regions_org[:,:] == 1) & (mask_zeros_y[:,:] == 1)]=0
@@ -2809,13 +2821,13 @@
         img_width_h = img_org.shape[1]
         patches = False
         if self.light_version:
-            prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_table)
+            prediction_table, _ = self.do_prediction_new_concept(patches, img, self.models["table"])
             prediction_table = prediction_table.astype(np.int16)
             return prediction_table[:,:,0]
         else:
             if num_col_classifier < 4 and num_col_classifier > 2:
-                prediction_table = self.do_prediction(patches, img, self.model_table)
-                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
+                prediction_table = self.do_prediction(patches, img, self.models["table"])
+                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"])
                 pre_updown = cv2.flip(pre_updown, -1)
 
                 prediction_table[:,:,0][pre_updown[:,:,0]==1]=1
@@ -2834,8 +2846,8 @@
                     xs = slice(w_start, w_start + img.shape[1])
                     img_new[ys, xs] = img
 
-                    prediction_ext = self.do_prediction(patches, img_new, self.model_table)
-                    pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
+                    prediction_ext = self.do_prediction(patches, img_new, self.models["table"])
+                    pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"])
                     pre_updown = cv2.flip(pre_updown, -1)
 
                     prediction_table = prediction_ext[ys, xs]
@@ -2856,8 +2868,8 @@
                     xs = slice(w_start, w_start + img.shape[1])
                     img_new[ys, xs] = img
 
-                    prediction_ext = self.do_prediction(patches, img_new, self.model_table)
-                    pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
+                    prediction_ext = self.do_prediction(patches, img_new, self.models["table"])
+                    pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"])
                     pre_updown = cv2.flip(pre_updown, -1)
 
                     prediction_table = prediction_ext[ys, xs]
@@ -2869,10 +2881,10 @@
                 prediction_table = np.zeros(img.shape)
                 img_w_half = img.shape[1] // 2
 
-                pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.model_table)
-                pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.model_table)
-                pre_full = self.do_prediction(patches, img[:,:,:], self.model_table)
-                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
+                pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.models["table"])
+                pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.models["table"])
+                pre_full = self.do_prediction(patches, img[:,:,:], self.models["table"])
+                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"])
                 pre_updown = cv2.flip(pre_updown, -1)
 
                 prediction_table_full_erode = cv2.erode(pre_full[:,:,0], KERNEL, iterations=4)
@@ -3474,18 +3486,6 @@
                     regions_without_separators_d, regions_fully, regions_without_separators,
                     polygons_of_marginals, contours_tables)
 
-    @staticmethod
-    def our_load_model(model_file):
-        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
-            # prefer SavedModel over HDF5 format if it exists
-            model_file = model_file[:-3]
-        try:
-            model = load_model(model_file, compile=False)
-        except:
-            model = load_model(model_file, compile=False, custom_objects={
-                "PatchEncoder": PatchEncoder, "Patches": Patches})
-        return model
-
     def do_order_of_regions_with_model(self, contours_only_text_parent, contours_only_text_parent_h, text_regions_p):
         height1 =672#448
@@ -3676,7 +3676,7 @@
                     tot_counter += 1
                     batch.append(j)
                     if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
-                        y_pr = self.model_reading_order.predict(input_1 , verbose=0)
+                        y_pr = self.models["reading_order"].predict(input_1 , verbose=0)
                         for jb, j in enumerate(batch):
                             if y_pr[jb][0]>=0.5:
                                 post_list.append(j)
@@ -4259,7 +4259,7 @@
                 gc.collect()
                 ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                     image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)),
-                    self.prediction_model, self.b_s_ocr, self.num_to_char, textline_light=True)
+                    self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True)
             else:
                 ocr_all_textlines = None
@@ -4768,27 +4768,27 @@
                 if len(all_found_textline_polygons):
                     ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons, all_box_coord,
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
 
                 if len(all_found_textline_polygons_marginals_left):
                     ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left,
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
 
                 if len(all_found_textline_polygons_marginals_right):
                     ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right,
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
 
                 if self.full_layout and len(all_found_textline_polygons):
                     ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, all_found_textline_polygons_h, all_box_coord_h,
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
 
                 if self.full_layout and len(polygons_of_drop_capitals):
                     ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines(
                         image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)),
-                        self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                        self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
             else:
                 if self.light_version:
@@ -4800,7 +4800,7 @@
                 gc.collect()
                 torch.cuda.empty_cache()
-                self.model_ocr.to(self.device)
+                self.models["ocr"].to(self.device)
 
                 ind_tot = 0
                 #cv2.imwrite('./img_out.png', image_page)
@@ -4837,7 +4837,7 @@
                         img_croped = img_poly_on_img[y:y+h, x:x+w, :]
                         #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
                         text_ocr = self.return_ocr_of_textline_without_common_section(
-                            img_croped, self.model_ocr, self.processor, self.device, w, h2w_ratio, ind_tot)
+                            img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot)
                         ocr_textline_in_textregion.append(text_ocr)
                         ind_tot = ind_tot +1
                 ocr_all_textlines.append(ocr_textline_in_textregion)