layout: refactor model setup, allow loading custom versions

- simplify definition of (defaults for) model versions
- unify loading of loadable models (depending on mode)
- use `self.models` dict instead of `self.model_*` attributes
- add `model_versions` kwarg / `--model_version` CLI option
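For example, overriding two model categories from the command line could look like this (a hypothetical invocation: `-mv` comes from this patch, the other short flags are the usual `layout` options, and both version strings appear among the defaults in the diff below):

    eynollah layout -di input/ -o output/ -m /path/to/models \
        -mv textline modelens_textline_0_1__2_4_16092024 \
        -mv reading_order model_eynollah_reading_order_20250824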
parent 374818de11
commit 4e9a1618c3

3 changed files with 191 additions and 182 deletions
@@ -25,6 +25,7 @@ f458e3e
   (so CUDA memory gets freed between tests if running on GPU)
 
 Added:
+* :fire: `layout` CLI: new option `--model_version` to override default choices
 * test coverage for OCR options in `layout`
 * test coverage for table detection in `layout`
 * CI linting with ruff
@@ -202,6 +202,13 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
     type=click.Path(exists=True, file_okay=False),
     required=True,
 )
+@click.option(
+    "--model_version",
+    "-mv",
+    help="override default versions of model categories",
+    type=(str, str),
+    multiple=True,
+)
 @click.option(
     "--save_images",
     "-si",
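Note on the option shape: with `type=(str, str)` and `multiple=True`, click consumes two tokens per `-mv` occurrence and passes the parameter as a tuple of `(category, version)` pairs, which is exactly the shape the new `model_versions` kwarg expects. A minimal standalone sketch of that behavior (the `demo` command is hypothetical, not part of this patch):

    import click

    @click.command()
    @click.option("--model_version", "-mv", type=(str, str), multiple=True)
    def demo(model_version):
        # "-mv textline v1 -mv table v2" arrives as (("textline", "v1"), ("table", "v2"))
        for category, version in model_version:
            click.echo(f"{category} -> {version}")

    if __name__ == "__main__":
        demo()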
@@ -373,7 +380,7 @@ def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_low
     help="Setup a basic console logger",
 )
 
-def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
+def layout(image, out, overwrite, dir_in, model, model_version, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, transformer_ocr, batch_size_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level, setup_logging):
     if setup_logging:
         console_handler = logging.StreamHandler(sys.stdout)
         console_handler.setLevel(logging.INFO)
@ -404,6 +411,7 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
|
|||
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||
eynollah = Eynollah(
|
||||
model,
|
||||
model_versions=model_version,
|
||||
extract_only_images=extract_only_images,
|
||||
enable_plotting=enable_plotting,
|
||||
allow_enhancement=allow_enhancement,
|
||||
|
|
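The same override hook is now available to Python callers through the constructor; a sketch under the assumption that the class is importable as below (the model directory and pinned version are placeholders, and the category/version pair mirrors the defaults in the diff that follows):

    from eynollah.eynollah import Eynollah  # import path is an assumption

    eynollah = Eynollah(
        "/path/to/models",  # placeholder basedir holding the model directories
        model_versions=[("textline", "modelens_textline_0_1__2_4_16092024")],
    )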
@@ -19,7 +19,7 @@ import math
 import os
 import sys
 import time
-from typing import Optional
+from typing import Dict, List, Optional, Tuple
 import atexit
 import warnings
 from functools import partial
@@ -180,7 +180,6 @@ class Patches(layers.Layer):
         })
         return config
 
-
 class PatchEncoder(layers.Layer):
     def __init__(self, **kwargs):
         super(PatchEncoder, self).__init__()
@@ -208,6 +207,7 @@ class Eynollah:
     def __init__(
         self,
         dir_models : str,
+        model_versions: List[Tuple[str, str]] = [],
         extract_only_images : bool =False,
         enable_plotting : bool = False,
         allow_enhancement : bool = False,
@@ -254,6 +254,10 @@ class Eynollah:
         self.skip_layout_and_reading_order = skip_layout_and_reading_order
         self.ocr = do_ocr
         self.tr = transformer_ocr
+        if not batch_size_ocr:
+            self.b_s_ocr = 8
+        else:
+            self.b_s_ocr = int(batch_size_ocr)
         if num_col_upper:
             self.num_col_upper = int(num_col_upper)
         else:
@@ -275,69 +279,6 @@ class Eynollah:
             self.threshold_art_class_textline = float(threshold_art_class_textline)
         else:
             self.threshold_art_class_textline = 0.1
 
-        self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
-        self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425"
-        self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
-        self.model_region_dir_p = dir_models + "/eynollah-main-regions-aug-scaling_20210425"
-        self.model_region_dir_p2 = dir_models + "/eynollah-main-regions-aug-rotation_20210425"
-        #"/modelens_full_lay_1_3_031124"
-        #"/modelens_full_lay_13__3_19_241024"
-        #"/model_full_lay_13_241024"
-        #"/modelens_full_lay_13_17_231024"
-        #"/modelens_full_lay_1_2_221024"
-        #"/eynollah-full-regions-1column_20210425"
-        self.model_region_dir_fully_np = dir_models + "/modelens_full_lay_1__4_3_091124"
-        #self.model_region_dir_fully = dir_models + "/eynollah-full-regions-3+column_20210425"
-        self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915"
-        self.model_region_dir_p_ens = dir_models + "/eynollah-main-regions-ensembled_20210425"
-        self.model_region_dir_p_ens_light = dir_models + "/eynollah-main-regions_20220314"
-        self.model_region_dir_p_ens_light_only_images_extraction = (dir_models +
-                                                                    "/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18"
-                                                                    )
-        self.model_reading_order_dir = (dir_models +
-                                        "/model_eynollah_reading_order_20250824"
-                                        #"/model_mb_ro_aug_ens_11"
-                                        #"/model_step_3200000_mb_ro"
-                                        #"/model_ens_reading_order_machine_based"
-                                        #"/model_mb_ro_aug_ens_8"
-                                        #"/model_ens_reading_order_machine_based"
-                                        )
-        #"/modelens_12sp_elay_0_3_4__3_6_n"
-        #"/modelens_earlylayout_12spaltige_2_3_5_6_7_8"
-        #"/modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
-        #"/modelens_1_2_4_5_early_lay_1_2_spaltige"
-        #"/model_3_eraly_layout_no_patches_1_2_spaltige"
-        self.model_region_dir_p_1_2_sp_np = dir_models + "/modelens_e_l_all_sp_0_1_2_3_4_171024"
-        ##self.model_region_dir_fully_new = dir_models + "/model_2_full_layout_new_trans"
-        #"/modelens_full_lay_1_3_031124"
-        #"/modelens_full_lay_13__3_19_241024"
-        #"/model_full_lay_13_241024"
-        #"/modelens_full_lay_13_17_231024"
-        #"/modelens_full_lay_1_2_221024"
-        #"/modelens_full_layout_24_till_28"
-        #"/model_2_full_layout_new_trans"
-        self.model_region_dir_fully = dir_models + "/modelens_full_lay_1__4_3_091124"
-        if self.textline_light:
-            #"/modelens_textline_1_4_16092024"
-            #"/model_textline_ens_3_4_5_6_artificial"
-            #"/modelens_textline_1_3_4_20240915"
-            #"/model_textline_ens_3_4_5_6_artificial"
-            #"/modelens_textline_9_12_13_14_15"
-            #"/eynollah-textline_light_20210425"
-            self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
-        else:
-            #"/eynollah-textline_20210425"
-            self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
-        if self.ocr and self.tr:
-            self.model_ocr_dir = dir_models + "/model_eynollah_ocr_trocr_20250919"
-        elif self.ocr and not self.tr:
-            self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250930"
-        if self.tables:
-            if self.light_version:
-                self.model_table_dir = dir_models + "/modelens_table_0t4_201124"
-            else:
-                self.model_table_dir = dir_models + "/eynollah-tables_20210319"
-
         t_start = time.time()
 
@@ -356,28 +297,124 @@ class Eynollah:
             self.logger.warning("no GPU device available")
 
         self.logger.info("Loading models...")
-
-        self.model_page = self.our_load_model(self.model_page_dir)
-        self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
-        self.model_bin = self.our_load_model(self.model_dir_of_binarization)
-        if self.extract_only_images:
-            self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
-        else:
-            self.model_textline = self.our_load_model(self.model_textline_dir)
+        self.setup_models(dir_models, model_versions)
+        self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
+
+    @staticmethod
+    def our_load_model(model_file, basedir=""):
+        if basedir:
+            model_file = os.path.join(basedir, model_file)
+        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
+            # prefer SavedModel over HDF5 format if it exists
+            model_file = model_file[:-3]
+        try:
+            model = load_model(model_file, compile=False)
+        except:
+            model = load_model(model_file, compile=False, custom_objects={
+                "PatchEncoder": PatchEncoder, "Patches": Patches})
+        return model
+
+    def setup_models(self, basedir: Path, model_versions: List[Tuple[str, str]] = []):
+        self.model_versions = {
+            "enhancement": "eynollah-enhancement_20210425",
+            "binarization": "eynollah-binarization_20210425",
+            "col_classifier": "eynollah-column-classifier_20210425",
+            "page": "model_eynollah_page_extraction_20250915",
+            #?: "eynollah-main-regions-aug-scaling_20210425",
+            "region": ( # early layout
+                "eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" if self.extract_only_images else
+                "eynollah-main-regions_20220314" if self.light_version else
+                "eynollah-main-regions-ensembled_20210425"),
+            "region_p2": ( # early layout, non-light, 2nd part
+                "eynollah-main-regions-aug-rotation_20210425"),
+            "region_1_2": ( # early layout, light, 1-or-2-column
+                #"modelens_12sp_elay_0_3_4__3_6_n"
+                #"modelens_earlylayout_12spaltige_2_3_5_6_7_8"
+                #"modelens_early12_sp_2_3_5_6_7_8_9_10_12_14_15_16_18"
+                #"modelens_1_2_4_5_early_lay_1_2_spaltige"
+                #"model_3_eraly_layout_no_patches_1_2_spaltige"
+                "modelens_e_l_all_sp_0_1_2_3_4_171024"),
+            "region_fl_np": ( # full layout / no patches
+                #"modelens_full_lay_1_3_031124"
+                #"modelens_full_lay_13__3_19_241024"
+                #"model_full_lay_13_241024"
+                #"modelens_full_lay_13_17_231024"
+                #"modelens_full_lay_1_2_221024"
+                #"eynollah-full-regions-1column_20210425"
+                "modelens_full_lay_1__4_3_091124"),
+            "region_fl": ( # full layout / with patches
+                #"eynollah-full-regions-3+column_20210425"
+                ##"model_2_full_layout_new_trans"
+                #"modelens_full_lay_1_3_031124"
+                #"modelens_full_lay_13__3_19_241024"
+                #"model_full_lay_13_241024"
+                #"modelens_full_lay_13_17_231024"
+                #"modelens_full_lay_1_2_221024"
+                #"modelens_full_layout_24_till_28"
+                #"model_2_full_layout_new_trans"
+                "modelens_full_lay_1__4_3_091124"),
+            "reading_order": (
+                #"model_mb_ro_aug_ens_11"
+                #"model_step_3200000_mb_ro"
+                #"model_ens_reading_order_machine_based"
+                #"model_mb_ro_aug_ens_8"
+                #"model_ens_reading_order_machine_based"
+                "model_eynollah_reading_order_20250824"),
+            "textline": (
+                #"modelens_textline_1_4_16092024"
+                #"model_textline_ens_3_4_5_6_artificial"
+                #"modelens_textline_1_3_4_20240915"
+                #"model_textline_ens_3_4_5_6_artificial"
+                #"modelens_textline_9_12_13_14_15"
+                #"eynollah-textline_light_20210425"
+                "modelens_textline_0_1__2_4_16092024" if self.textline_light else
+                #"eynollah-textline_20210425"
+                "modelens_textline_0_1__2_4_16092024"),
+            "table": (
+                None if not self.tables else
+                "modelens_table_0t4_201124" if self.light_version else
+                "eynollah-tables_20210319"),
+            "ocr": (
+                None if not self.ocr else
+                "model_eynollah_ocr_trocr_20250919" if self.tr else
+                "model_eynollah_ocr_cnnrnn_20250930")
+        }
+        # override defaults from CLI
+        for key, val in model_versions:
+            assert key in self.model_versions, "unknown model category '%s'" % key
+            self.logger.warning("overriding default model %s version %s to %s", key, self.model_versions[key], val)
+            self.model_versions[key] = val
+        # load models, depending on modes
+        loadable = [
+            "col_classifier",
+            "binarization",
+            "page",
+            "region"
+        ]
+        if not self.extract_only_images:
+            loadable.append("textline")
             if self.light_version:
-                self.model_region = self.our_load_model(self.model_region_dir_p_ens_light)
-                self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np)
+                loadable.append("region_1_2")
             else:
-                self.model_region = self.our_load_model(self.model_region_dir_p_ens)
-                self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
-                self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
-            ###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new)
-            self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
-            self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
+                loadable.append("region_p2")
+                # if self.allow_enhancement:?
+                loadable.append("enhancement")
+            if self.full_layout:
+                loadable.extend(["region_fl_np",
+                                 "region_fl"])
             if self.reading_order_machine_based:
-                self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
-            if self.ocr and self.tr:
-                self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
+                loadable.append("reading_order")
+            if self.tables:
+                loadable.append("table")
+
+        self.models = {name: self.our_load_model(self.model_versions[name], basedir)
+                       for name in loadable
+                       }
+
+        if self.ocr:
+            ocr_model_dir = os.path.join(basedir, self.model_versions["ocr"])
+            if self.tr:
+                self.models["ocr"] = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir)
                 if torch.cuda.is_available():
                     self.logger.info("Using GPU acceleration")
                     self.device = torch.device("cuda:0")
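Two things follow from the structure above: a `--model_version` pair only replaces the default version string, and whether that category is loaded at all is still decided by the mode flags via `loadable`. A stripped-down, self-contained illustration of that pattern (the names are illustrative, not from the patch):

    from typing import Dict, List, Tuple

    def resolve_versions(defaults: Dict[str, str],
                         overrides: List[Tuple[str, str]],
                         loadable: List[str]) -> Dict[str, str]:
        # apply CLI overrides onto the defaults, rejecting unknown categories
        for key, val in overrides:
            assert key in defaults, f"unknown model category '{key}'"
            defaults[key] = val
        # only categories enabled by the current mode get loaded
        return {name: defaults[name] for name in loadable}

    versions = resolve_versions(
        {"textline": "v_default", "table": "v_default"},
        [("table", "v_custom")],
        ["table"],  # e.g. only table detection enabled
    )
    assert versions == {"table": "v_custom"}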
@@ -386,54 +423,29 @@ class Eynollah:
                     self.device = torch.device("cpu")
                 #self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
                 self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
-        elif self.ocr and not self.tr:
-            model_ocr = load_model(self.model_ocr_dir , compile=False)
-
-            self.prediction_model = tf.keras.models.Model(
-                model_ocr.get_layer(name = "image").input,
-                model_ocr.get_layer(name = "dense2").output)
-            if not batch_size_ocr:
-                self.b_s_ocr = 8
-            else:
-                self.b_s_ocr = int(batch_size_ocr)
+            else:
+                ocr_model = load_model(ocr_model_dir, compile=False)
+                self.models["ocr"] = tf.keras.models.Model(
+                    ocr_model.get_layer(name = "image").input,
+                    ocr_model.get_layer(name = "dense2").output)
 
-            with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file:
+                with open(os.path.join(ocr_model_dir, "characters_org.txt"), "r") as config_file:
                 characters = json.load(config_file)
 
                 AUTOTUNE = tf.data.AUTOTUNE
 
                 # Mapping characters to integers.
                 char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
 
                 # Mapping integers back to original characters.
                 self.num_to_char = StringLookup(
                     vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
                 )
 
-        if self.tables:
-            self.model_table = self.our_load_model(self.model_table_dir)
-
-        self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)")
 
     def __del__(self):
         if hasattr(self, 'executor') and getattr(self, 'executor'):
             self.executor.shutdown()
-        for model_name in ['model_page',
-                           'model_classifier',
-                           'model_bin',
-                           'model_enhancement',
-                           'model_region',
-                           'model_region_1_2',
-                           'model_region_p2',
-                           'model_region_fl_np',
-                           'model_region_fl',
-                           'model_textline',
-                           'model_reading_order',
-                           'model_table',
-                           'model_ocr',
-                           'processor']:
-            if hasattr(self, model_name) and getattr(self, model_name):
-                delattr(self, model_name)
+        self.executor = None
+        if hasattr(self, 'models') and getattr(self, 'models'):
+            for model_name in list(self.models):
+                if self.models[model_name]:
+                    del self.models[model_name]
 
     def cache_images(self, image_filename=None, image_pil=None, dpi=None):
         ret = {}
 
@@ -480,8 +492,8 @@ class Eynollah:
     def predict_enhancement(self, img):
         self.logger.debug("enter predict_enhancement")
 
-        img_height_model = self.model_enhancement.layers[-1].output_shape[1]
-        img_width_model = self.model_enhancement.layers[-1].output_shape[2]
+        img_height_model = self.models["enhancement"].layers[-1].output_shape[1]
+        img_width_model = self.models["enhancement"].layers[-1].output_shape[2]
         if img.shape[0] < img_height_model:
             img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
         if img.shape[1] < img_width_model:
@@ -522,7 +534,7 @@ class Eynollah:
                     index_y_d = img_h - img_height_model
 
                 img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
-                label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
+                label_p_pred = self.models["enhancement"].predict(img_patch, verbose=0)
                 seg = label_p_pred[0, :, :, :] * 255
 
                 if i == 0 and j == 0:
@@ -697,7 +709,7 @@ class Eynollah:
         img_in[0, :, :, 1] = img_1ch[:, :]
         img_in[0, :, :, 2] = img_1ch[:, :]
 
-        label_p_pred = self.model_classifier.predict(img_in, verbose=0)
+        label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
         num_col = np.argmax(label_p_pred[0]) + 1
 
         self.logger.info("Found %s columns (%s)", num_col, label_p_pred)
@@ -715,7 +727,7 @@ class Eynollah:
         self.logger.info("Detected %s DPI", dpi)
         if self.input_binary:
             img = self.imread()
-            prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
+            prediction_bin = self.do_prediction(True, img, self.models["binarization"], n_batch_inference=5)
             prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
             prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
             img= np.copy(prediction_bin)
@@ -755,7 +767,7 @@ class Eynollah:
             img_in[0, :, :, 1] = img_1ch[:, :]
             img_in[0, :, :, 2] = img_1ch[:, :]
 
-            label_p_pred = self.model_classifier.predict(img_in, verbose=0)
+            label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
             num_col = np.argmax(label_p_pred[0]) + 1
 
         elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
@@ -776,7 +788,7 @@ class Eynollah:
             img_in[0, :, :, 1] = img_1ch[:, :]
             img_in[0, :, :, 2] = img_1ch[:, :]
 
-            label_p_pred = self.model_classifier.predict(img_in, verbose=0)
+            label_p_pred = self.models["col_classifier"].predict(img_in, verbose=0)
             num_col = np.argmax(label_p_pred[0]) + 1
 
             if num_col > self.num_col_upper:
@@ -1628,7 +1640,7 @@ class Eynollah:
         cont_page = []
         if not self.ignore_page_extraction:
             img = np.copy(self.image)#cv2.GaussianBlur(self.image, (5, 5), 0)
-            img_page_prediction = self.do_prediction(False, img, self.model_page)
+            img_page_prediction = self.do_prediction(False, img, self.models["page"])
             imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
             _, thresh = cv2.threshold(imgray, 0, 255, 0)
             ##thresh = cv2.dilate(thresh, KERNEL, iterations=3)
@@ -1676,7 +1688,7 @@ class Eynollah:
         else:
             img = self.imread()
         img = cv2.GaussianBlur(img, (5, 5), 0)
-        img_page_prediction = self.do_prediction(False, img, self.model_page)
+        img_page_prediction = self.do_prediction(False, img, self.models["page"])
 
         imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
         _, thresh = cv2.threshold(imgray, 0, 255, 0)
@@ -1702,7 +1714,7 @@ class Eynollah:
         self.logger.debug("enter extract_text_regions")
         img_height_h = img.shape[0]
         img_width_h = img.shape[1]
-        model_region = self.model_region_fl if patches else self.model_region_fl_np
+        model_region = self.models["region_fl"] if patches else self.models["region_fl_np"]
 
         if self.light_version:
             thresholding_for_fl_light_version = True
@@ -1737,7 +1749,7 @@ class Eynollah:
         self.logger.debug("enter extract_text_regions")
         img_height_h = img.shape[0]
         img_width_h = img.shape[1]
-        model_region = self.model_region_fl if patches else self.model_region_fl_np
+        model_region = self.models["region_fl"] if patches else self.models["region_fl_np"]
 
         if not patches:
             img = otsu_copy_binary(img)
@@ -1958,14 +1970,14 @@ class Eynollah:
         img_w = img_org.shape[1]
         img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w))
 
-        prediction_textline = self.do_prediction(use_patches, img, self.model_textline,
+        prediction_textline = self.do_prediction(use_patches, img, self.models["textline"],
                                                  marginal_of_patch_percent=0.15,
                                                  n_batch_inference=3,
                                                  thresholding_for_artificial_class_in_light_version=self.textline_light,
                                                  threshold_art_class_textline=self.threshold_art_class_textline)
         #if not self.textline_light:
             #if num_col_classifier==1:
-                #prediction_textline_nopatch = self.do_prediction(False, img, self.model_textline)
+                #prediction_textline_nopatch = self.do_prediction(False, img, self.models["textline"])
                 #prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0
 
         prediction_textline = resize_image(prediction_textline, img_h, img_w)
@@ -2036,7 +2048,7 @@ class Eynollah:
 
         #cv2.imwrite('prediction_textline2.png', prediction_textline[:,:,0])
 
-        prediction_textline_longshot = self.do_prediction(False, img, self.model_textline)
+        prediction_textline_longshot = self.do_prediction(False, img, self.models["textline"])
         prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w)
 
 
@@ -2069,7 +2081,7 @@ class Eynollah:
             img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new)
             img_resized = resize_image(img,img_h_new, img_w_new )
 
-            prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.model_region)
+            prediction_regions_org, _ = self.do_prediction_new_concept(True, img_resized, self.models["region"])
 
             prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h )
             image_page, page_coord, cont_page = self.extract_page()
@@ -2185,7 +2197,7 @@ class Eynollah:
             #if self.input_binary:
                 #img_bin = np.copy(img_resized)
             ###if (not self.input_binary and self.full_layout) or (not self.input_binary and num_col_classifier >= 30):
-                ###prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
+                ###prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5)
 
                 ####print("inside bin ", time.time()-t_bin)
                 ###prediction_bin=prediction_bin[:,:,0]
@@ -2200,7 +2212,7 @@ class Eynollah:
             ###else:
                 ###img_bin = np.copy(img_resized)
             if (self.ocr and self.tr) and not self.input_binary:
-                prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
+                prediction_bin = self.do_prediction(True, img_resized, self.models["binarization"], n_batch_inference=5)
                 prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
                 prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
                 prediction_bin = prediction_bin.astype(np.uint16)
@@ -2232,14 +2244,14 @@ class Eynollah:
                 self.logger.debug("resized to %dx%d for %d cols",
                                   img_resized.shape[1], img_resized.shape[0], num_col_classifier)
                 prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
-                    True, img_resized, self.model_region_1_2, n_batch_inference=1,
+                    True, img_resized, self.models["region_1_2"], n_batch_inference=1,
                     thresholding_for_some_classes_in_light_version=True,
                     threshold_art_class_layout=self.threshold_art_class_layout)
             else:
                 prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3))
                 confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
                 prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept(
-                    False, self.image_page_org_size, self.model_region_1_2, n_batch_inference=1,
+                    False, self.image_page_org_size, self.models["region_1_2"], n_batch_inference=1,
                     thresholding_for_artificial_class_in_light_version=True,
                     threshold_art_class_layout=self.threshold_art_class_layout)
                 ys = slice(*self.page_coord[0:2])
@@ -2253,10 +2265,10 @@ class Eynollah:
                 self.logger.debug("resized to %dx%d (new_h=%d) for %d cols",
                                   img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
                 prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
-                    True, img_resized, self.model_region_1_2, n_batch_inference=2,
+                    True, img_resized, self.models["region_1_2"], n_batch_inference=2,
                     thresholding_for_some_classes_in_light_version=True,
                     threshold_art_class_layout=self.threshold_art_class_layout)
-            ###prediction_regions_org = self.do_prediction(True, img_bin, self.model_region,
+            ###prediction_regions_org = self.do_prediction(True, img_bin, self.models["region"],
             ###n_batch_inference=3,
             ###thresholding_for_some_classes_in_light_version=True)
             #print("inside 3 ", time.time()-t_in)
@@ -2336,7 +2348,7 @@ class Eynollah:
             ratio_x=1
 
         img = resize_image(img_org, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
-        prediction_regions_org_y = self.do_prediction(True, img, self.model_region)
+        prediction_regions_org_y = self.do_prediction(True, img, self.models["region"])
         prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h )
 
         #plt.imshow(prediction_regions_org_y[:,:,0])
@@ -2351,7 +2363,7 @@ class Eynollah:
             _, _ = find_num_col(img_only_regions, num_col_classifier, self.tables, multiplier=6.0)
             img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]*(1.2 if is_image_enhanced else 1)))
 
-            prediction_regions_org = self.do_prediction(True, img, self.model_region)
+            prediction_regions_org = self.do_prediction(True, img, self.models["region"])
             prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
 
             prediction_regions_org=prediction_regions_org[:,:,0]
@@ -2359,7 +2371,7 @@ class Eynollah:
 
             img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]))
 
-            prediction_regions_org2 = self.do_prediction(True, img, self.model_region_p2, marginal_of_patch_percent=0.2)
+            prediction_regions_org2 = self.do_prediction(True, img, self.models["region_p2"], marginal_of_patch_percent=0.2)
             prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h )
 
             mask_zeros2 = (prediction_regions_org2[:,:,0] == 0)
@@ -2383,7 +2395,7 @@ class Eynollah:
                 if self.input_binary:
                     prediction_bin = np.copy(img_org)
                 else:
-                    prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
+                    prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5)
                     prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
                     prediction_bin = 255 * (prediction_bin[:,:,0]==0)
                     prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
@@ -2393,7 +2405,7 @@ class Eynollah:
 
                 img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
 
-                prediction_regions_org = self.do_prediction(True, img, self.model_region)
+                prediction_regions_org = self.do_prediction(True, img, self.models["region"])
                 prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
                 prediction_regions_org=prediction_regions_org[:,:,0]
 
@@ -2420,7 +2432,7 @@ class Eynollah:
         except:
             if self.input_binary:
                 prediction_bin = np.copy(img_org)
-                prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
+                prediction_bin = self.do_prediction(True, img_org, self.models["binarization"], n_batch_inference=5)
                 prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
                 prediction_bin = 255 * (prediction_bin[:,:,0]==0)
                 prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
|
@ -2431,14 +2443,14 @@ class Eynollah:
|
|||
|
||||
|
||||
img = resize_image(prediction_bin, int(img_org.shape[0]*ratio_y), int(img_org.shape[1]*ratio_x))
|
||||
prediction_regions_org = self.do_prediction(True, img, self.model_region)
|
||||
prediction_regions_org = self.do_prediction(True, img, self.models["region"])
|
||||
prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
|
||||
prediction_regions_org=prediction_regions_org[:,:,0]
|
||||
|
||||
#mask_lines_only=(prediction_regions_org[:,:]==3)*1
|
||||
#img = resize_image(img_org, int(img_org.shape[0]*1), int(img_org.shape[1]*1))
|
||||
|
||||
#prediction_regions_org = self.do_prediction(True, img, self.model_region)
|
||||
#prediction_regions_org = self.do_prediction(True, img, self.models["region"])
|
||||
#prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
|
||||
#prediction_regions_org = prediction_regions_org[:,:,0]
|
||||
#prediction_regions_org[(prediction_regions_org[:,:] == 1) & (mask_zeros_y[:,:] == 1)]=0
|
||||
|
@@ -2809,13 +2821,13 @@ class Eynollah:
         img_width_h = img_org.shape[1]
         patches = False
         if self.light_version:
-            prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_table)
+            prediction_table, _ = self.do_prediction_new_concept(patches, img, self.models["table"])
             prediction_table = prediction_table.astype(np.int16)
             return prediction_table[:,:,0]
         else:
             if num_col_classifier < 4 and num_col_classifier > 2:
-                prediction_table = self.do_prediction(patches, img, self.model_table)
-                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
+                prediction_table = self.do_prediction(patches, img, self.models["table"])
+                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"])
                 pre_updown = cv2.flip(pre_updown, -1)
 
                 prediction_table[:,:,0][pre_updown[:,:,0]==1]=1
@@ -2834,8 +2846,8 @@ class Eynollah:
                 xs = slice(w_start, w_start + img.shape[1])
                 img_new[ys, xs] = img
 
-                prediction_ext = self.do_prediction(patches, img_new, self.model_table)
-                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
+                prediction_ext = self.do_prediction(patches, img_new, self.models["table"])
+                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"])
                 pre_updown = cv2.flip(pre_updown, -1)
 
                 prediction_table = prediction_ext[ys, xs]
@@ -2856,8 +2868,8 @@ class Eynollah:
                 xs = slice(w_start, w_start + img.shape[1])
                 img_new[ys, xs] = img
 
-                prediction_ext = self.do_prediction(patches, img_new, self.model_table)
-                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
+                prediction_ext = self.do_prediction(patches, img_new, self.models["table"])
+                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.models["table"])
                 pre_updown = cv2.flip(pre_updown, -1)
 
                 prediction_table = prediction_ext[ys, xs]
@@ -2869,10 +2881,10 @@ class Eynollah:
                 prediction_table = np.zeros(img.shape)
                 img_w_half = img.shape[1] // 2
 
-                pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.model_table)
-                pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.model_table)
-                pre_full = self.do_prediction(patches, img[:,:,:], self.model_table)
-                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
+                pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.models["table"])
+                pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.models["table"])
+                pre_full = self.do_prediction(patches, img[:,:,:], self.models["table"])
+                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.models["table"])
                 pre_updown = cv2.flip(pre_updown, -1)
 
                 prediction_table_full_erode = cv2.erode(pre_full[:,:,0], KERNEL, iterations=4)
@@ -3474,18 +3486,6 @@ class Eynollah:
             regions_without_separators_d, regions_fully, regions_without_separators,
             polygons_of_marginals, contours_tables)
 
-    @staticmethod
-    def our_load_model(model_file):
-        if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
-            # prefer SavedModel over HDF5 format if it exists
-            model_file = model_file[:-3]
-        try:
-            model = load_model(model_file, compile=False)
-        except:
-            model = load_model(model_file, compile=False, custom_objects={
-                "PatchEncoder": PatchEncoder, "Patches": Patches})
-        return model
-
     def do_order_of_regions_with_model(self, contours_only_text_parent, contours_only_text_parent_h, text_regions_p):
 
         height1 =672#448
@@ -3676,7 +3676,7 @@ class Eynollah:
                     tot_counter += 1
                     batch.append(j)
                     if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
-                        y_pr = self.model_reading_order.predict(input_1 , verbose=0)
+                        y_pr = self.models["reading_order"].predict(input_1 , verbose=0)
                         for jb, j in enumerate(batch):
                             if y_pr[jb][0]>=0.5:
                                 post_list.append(j)
@@ -4259,7 +4259,7 @@ class Eynollah:
                 gc.collect()
                 ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                     image_page, all_found_textline_polygons, np.zeros((len(all_found_textline_polygons), 4)),
-                    self.prediction_model, self.b_s_ocr, self.num_to_char, textline_light=True)
+                    self.models["ocr"], self.b_s_ocr, self.num_to_char, textline_light=True)
             else:
                 ocr_all_textlines = None
 
@@ -4768,27 +4768,27 @@ class Eynollah:
             if len(all_found_textline_polygons):
                 ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
                     image_page, all_found_textline_polygons, all_box_coord,
-                    self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
 
             if len(all_found_textline_polygons_marginals_left):
                 ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines(
                     image_page, all_found_textline_polygons_marginals_left, all_box_coord_marginals_left,
-                    self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
 
             if len(all_found_textline_polygons_marginals_right):
                 ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines(
                     image_page, all_found_textline_polygons_marginals_right, all_box_coord_marginals_right,
-                    self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
 
             if self.full_layout and len(all_found_textline_polygons):
                 ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines(
                     image_page, all_found_textline_polygons_h, all_box_coord_h,
-                    self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
 
             if self.full_layout and len(polygons_of_drop_capitals):
                 ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines(
                     image_page, polygons_of_drop_capitals, np.zeros((len(polygons_of_drop_capitals), 4)),
-                    self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
+                    self.models["ocr"], self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
 
         else:
             if self.light_version:
@@ -4800,7 +4800,7 @@ class Eynollah:
                 gc.collect()
 
                 torch.cuda.empty_cache()
-                self.model_ocr.to(self.device)
+                self.models["ocr"].to(self.device)
 
                 ind_tot = 0
                 #cv2.imwrite('./img_out.png', image_page)
@@ -4837,7 +4837,7 @@ class Eynollah:
                         img_croped = img_poly_on_img[y:y+h, x:x+w, :]
                         #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
                         text_ocr = self.return_ocr_of_textline_without_common_section(
-                            img_croped, self.model_ocr, self.processor, self.device, w, h2w_ratio, ind_tot)
+                            img_croped, self.models["ocr"], self.processor, self.device, w, h2w_ratio, ind_tot)
                         ocr_textline_in_textregion.append(text_ocr)
                         ind_tot = ind_tot +1
                 ocr_all_textlines.append(ocr_textline_in_textregion)