Mirror of https://github.com/qurator-spk/eynollah.git (synced 2025-06-08 19:59:56 +02:00)

Commit f7e5fb917f: resolve the merge conflict between the machine-based reading order branch and the extract-only-images branch
10 changed files with 3578 additions and 888 deletions
@@ -1,8 +0,0 @@
# ocrd includes opencv, numpy, shapely, click
ocrd >= 2.23.3
numpy <1.24.0
scikit-learn >= 0.23.2
tensorflow == 2.12.1
imutils >= 0.5.3
matplotlib
setuptools >= 50
@@ -2,15 +2,95 @@ import sys
import click
from ocrd_utils import initLogging, setOverrideLogLevel
from eynollah.eynollah import Eynollah
from eynollah.sbb_binarize import SbbBinarizer

@click.group()
def main():
    pass

@click.command()
@main.command()
@click.option(
    "--dir_xml",
    "-dx",
    help="directory of GT page-xml files",
    type=click.Path(exists=True, file_okay=False),
)

@click.option(
    "--dir_out_modal_image",
    "-domi",
    help="directory where ground truth images would be written",
    type=click.Path(exists=True, file_okay=False),
)

@click.option(
    "--dir_out_classes",
    "-docl",
    help="directory where ground truth classes would be written",
    type=click.Path(exists=True, file_okay=False),
)

@click.option(
    "--input_height",
    "-ih",
    help="input height",
)
@click.option(
    "--input_width",
    "-iw",
    help="input width",
)
@click.option(
    "--min_area_size",
    "-min",
    help="min area size of regions considered for reading order training.",
)

def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size):
    xml_files_ind = os.listdir(dir_xml)

@main.command()
@click.option('--patches/--no-patches', default=True, help='by enabling this parameter you let the model see the image in patches.')

@click.option('--model_dir', '-m', type=click.Path(exists=True, file_okay=False), required=True, help='directory containing models for prediction')

@click.argument('input_image')

@click.argument('output_image')
@click.option(
    "--dir_in",
    "-di",
    help="directory of images",
    type=click.Path(exists=True, file_okay=False),
)
@click.option(
    "--dir_out",
    "-do",
    help="directory where the binarized images will be written",
    type=click.Path(exists=True, file_okay=False),
)

def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out):
    if not dir_out and (dir_in):
        print("Error: You used -di but did not set -do")
        sys.exit(1)
    elif dir_out and not (dir_in):
        print("Error: You used -do to write out binarized images but have not set -di")
        sys.exit(1)
    SbbBinarizer(model_dir).run(image_path=input_image, use_patches=patches, save=output_image, dir_in=dir_in, dir_out=dir_out)


@main.command()
@click.option(
    "--image",
    "-i",
    help="image filename",
    type=click.Path(exists=True, dir_okay=False),
)

@click.option(
    "--out",
    "-o",
@@ -140,36 +220,41 @@ from eynollah.eynollah import Eynollah
    help="if this parameter is set to true, this tool will ignore page extraction",
)
@click.option(
    "--log-level",
    "--reading_order_machine_based/--heuristic_reading_order",
    "-romb/-hro",
    is_flag=True,
    help="if this parameter is set to true, this tool will apply machine-based reading order detection",
)
@click.option(
    "--do_ocr",
    "-ocr/-noocr",
    is_flag=True,
    help="if this parameter is set to true, this tool will try to do OCR",
)
@click.option(
|
||||
"--num_col_upper",
|
||||
"-ncu",
|
||||
help="lower limit of columns in document image",
|
||||
)
|
||||
@click.option(
|
||||
"--num_col_lower",
|
||||
"-ncl",
|
||||
help="upper limit of columns in document image",
|
||||
)
@click.option(
    "--skip_layout_and_reading_order",
    "-slro/-noslro",
    is_flag=True,
    help="if this parameter is set to true, this tool will ignore layout detection and reading order. It means that textline detection will be done within the printspace and the contours of textlines will be written to the xml output file.",
)
@click.option(
    "--log_level",
    "-l",
    type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
    help="Override log level globally to this",
)
def main(
    image,
    out,
    dir_in,
    model,
    save_images,
    save_layout,
    save_deskewed,
    save_all,
    extract_only_images,
    save_page,
    enable_plotting,
    allow_enhancement,
    curved_line,
    textline_light,
    full_layout,
    tables,
    right2left,
    input_binary,
    allow_scaling,
    headers_off,
    light_version,
    ignore_page_extraction,
    log_level
):

def layout(image, out, dir_in, model, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, num_col_upper, num_col_lower, skip_layout_and_reading_order, ignore_page_extraction, log_level):
    if log_level:
        setOverrideLogLevel(log_level)
    initLogging()
@@ -182,8 +267,11 @@ def main(
    if textline_light and not light_version:
        print('Error: You used -tll to enable light textline detection but -light is not enabled')
        sys.exit(1)

    if extract_only_images and (allow_enhancement or allow_scaling or light_version or curved_line or textline_light or full_layout or tables or right2left or headers_off):
        print('Error: You used -eoi which cannot be enabled alongside light_version -light or allow_scaling -as or allow_enhancement -ae or curved_line -cl or textline_light -tll or full_layout -fl or tables -tab or right2left -r2l or headers_off -ho')
    if light_version and not textline_light:
        print('Error: You used -light without -tll. Light version needs light textline to be enabled.')
        sys.exit(1)
    eynollah = Eynollah(
        image_filename=image,
@@ -208,6 +296,11 @@ def main(
        headers_off=headers_off,
        light_version=light_version,
        ignore_page_extraction=ignore_page_extraction,
        reading_order_machine_based=reading_order_machine_based,
        do_ocr=do_ocr,
        num_col_upper=num_col_upper,
        num_col_lower=num_col_lower,
        skip_layout_and_reading_order=skip_layout_and_reading_order,
    )
    if dir_in:
        eynollah.run()
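For orientation, this is roughly how the two subcommands above would be invoked from the shell. This is a sketch only: it assumes the package installs a console entry point named eynollah, that the model directories already exist locally, and that the layout command also exposes -o/--out and -m/--model options that are not fully visible in this excerpt.

    # binarize a single image with the SbbBinarizer models
    eynollah binarization -m ./models_binarization input.png output.png

    # layout analysis on a folder of images, with machine-based reading order and OCR enabled
    eynollah layout -di ./images -o ./out -m ./models_layout -light -tll -romb -ocr

The short forms -di, -light, -tll, -romb and -ocr mirror the click declarations shown in this diff; anything else in the example is an assumption based on the surrounding code.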
File diff suppressed because it is too large.

src/eynollah/sbb_binarize.py (new file, 383 lines)
@@ -0,0 +1,383 @@
"""
Tool to load model and binarize a given image.
"""

import sys
from glob import glob
from os import environ, devnull
from os.path import join
from warnings import catch_warnings, simplefilter
import os

import numpy as np
from PIL import Image
import cv2
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
sys.stderr = open(devnull, 'w')
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend
sys.stderr = stderr

import logging

def resize_image(img_in, input_height, input_width):
    return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)

class SbbBinarizer:

    def __init__(self, model_dir, logger=None):
        self.model_dir = model_dir
        self.log = logger if logger else logging.getLogger('SbbBinarizer')

        self.start_new_session()

        self.model_files = glob(self.model_dir+"/*/", recursive = True)

        self.models = []
        for model_file in self.model_files:
            self.models.append(self.load_model(model_file))

    def start_new_session(self):
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        self.session = tf.compat.v1.Session(config=config)  # tf.InteractiveSession()
        tensorflow_backend.set_session(self.session)

    def end_session(self):
        tensorflow_backend.clear_session()
        self.session.close()
        del self.session

    def load_model(self, model_name):
        model = load_model(join(self.model_dir, model_name), compile=False)
        model_height = model.layers[len(model.layers)-1].output_shape[1]
        model_width = model.layers[len(model.layers)-1].output_shape[2]
        n_classes = model.layers[len(model.layers)-1].output_shape[3]
        return model, model_height, model_width, n_classes

    def predict(self, model_in, img, use_patches, n_batch_inference=5):
        tensorflow_backend.set_session(self.session)
        model, model_height, model_width, n_classes = model_in
|
||||
|
||||
img_org_h = img.shape[0]
|
||||
img_org_w = img.shape[1]
|
||||
|
||||
if img.shape[0] < model_height and img.shape[1] >= model_width:
|
||||
img_padded = np.zeros(( model_height, img.shape[1], img.shape[2] ))
|
||||
|
||||
index_start_h = int( abs( img.shape[0] - model_height) /2.)
|
||||
index_start_w = 0
|
||||
|
||||
img_padded [ index_start_h: index_start_h+img.shape[0], :, : ] = img[:,:,:]
|
||||
|
||||
elif img.shape[0] >= model_height and img.shape[1] < model_width:
|
||||
img_padded = np.zeros(( img.shape[0], model_width, img.shape[2] ))
|
||||
|
||||
index_start_h = 0
|
||||
index_start_w = int( abs( img.shape[1] - model_width) /2.)
|
||||
|
||||
img_padded [ :, index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:]
|
||||
|
||||
|
||||
elif img.shape[0] < model_height and img.shape[1] < model_width:
|
||||
img_padded = np.zeros(( model_height, model_width, img.shape[2] ))
|
||||
|
||||
index_start_h = int( abs( img.shape[0] - model_height) /2.)
|
||||
index_start_w = int( abs( img.shape[1] - model_width) /2.)
|
||||
|
||||
img_padded [ index_start_h: index_start_h+img.shape[0], index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:]
|
||||
|
||||
else:
|
||||
index_start_h = 0
|
||||
index_start_w = 0
|
||||
img_padded = np.copy(img)
|
||||
|
||||
|
||||
img = np.copy(img_padded)
|
||||
|
||||
|
||||
|
||||
if use_patches:
|
||||
|
||||
margin = int(0.1 * model_width)
|
||||
|
||||
width_mid = model_width - 2 * margin
|
||||
height_mid = model_height - 2 * margin
|
||||
|
||||
|
||||
img = img / float(255.0)
|
||||
|
||||
img_h = img.shape[0]
|
||||
img_w = img.shape[1]
|
||||
|
||||
prediction_true = np.zeros((img_h, img_w, 3))
|
||||
mask_true = np.zeros((img_h, img_w))
|
||||
nxf = img_w / float(width_mid)
|
||||
nyf = img_h / float(height_mid)
|
||||
|
||||
if nxf > int(nxf):
|
||||
nxf = int(nxf) + 1
|
||||
else:
|
||||
nxf = int(nxf)
|
||||
|
||||
if nyf > int(nyf):
|
||||
nyf = int(nyf) + 1
|
||||
else:
|
||||
nyf = int(nyf)
|
||||
|
||||
|
||||
list_i_s = []
|
||||
list_j_s = []
|
||||
list_x_u = []
|
||||
list_x_d = []
|
||||
list_y_u = []
|
||||
list_y_d = []
|
||||
|
||||
batch_indexer = 0
|
||||
|
||||
img_patch = np.zeros((n_batch_inference, model_height, model_width,3))
|
||||
|
||||
for i in range(nxf):
|
||||
for j in range(nyf):
|
||||
|
||||
if i == 0:
|
||||
index_x_d = i * width_mid
|
||||
index_x_u = index_x_d + model_width
|
||||
elif i > 0:
|
||||
index_x_d = i * width_mid
|
||||
index_x_u = index_x_d + model_width
|
||||
|
||||
if j == 0:
|
||||
index_y_d = j * height_mid
|
||||
index_y_u = index_y_d + model_height
|
||||
elif j > 0:
|
||||
index_y_d = j * height_mid
|
||||
index_y_u = index_y_d + model_height
|
||||
|
||||
if index_x_u > img_w:
|
||||
index_x_u = img_w
|
||||
index_x_d = img_w - model_width
|
||||
if index_y_u > img_h:
|
||||
index_y_u = img_h
|
||||
index_y_d = img_h - model_height
|
||||
|
||||
|
||||
list_i_s.append(i)
|
||||
list_j_s.append(j)
|
||||
list_x_u.append(index_x_u)
|
||||
list_x_d.append(index_x_d)
|
||||
list_y_d.append(index_y_d)
|
||||
list_y_u.append(index_y_u)
|
||||
|
||||
|
||||
img_patch[batch_indexer,:,:,:] = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
|
||||
batch_indexer = batch_indexer + 1
|
||||
|
||||
|
||||
|
||||
if batch_indexer == n_batch_inference:
|
||||
|
||||
label_p_pred = model.predict(img_patch,verbose=0)
|
||||
|
||||
seg = np.argmax(label_p_pred, axis=3)
|
||||
|
||||
#print(seg.shape, len(seg), len(list_i_s))
|
||||
|
||||
indexer_inside_batch = 0
|
||||
for i_batch, j_batch in zip(list_i_s, list_j_s):
|
||||
seg_in = seg[indexer_inside_batch,:,:]
|
||||
seg_color = np.repeat(seg_in[:, :, np.newaxis], 3, axis=2)
|
||||
|
||||
index_y_u_in = list_y_u[indexer_inside_batch]
|
||||
index_y_d_in = list_y_d[indexer_inside_batch]
|
||||
|
||||
index_x_u_in = list_x_u[indexer_inside_batch]
|
||||
index_x_d_in = list_x_d[indexer_inside_batch]
|
||||
|
||||
if i_batch == 0 and j_batch == 0:
|
||||
seg_color = seg_color[0 : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
|
||||
elif i_batch == nxf - 1 and j_batch == nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - 0, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
|
||||
elif i_batch == 0 and j_batch == nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - 0, 0 : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
|
||||
elif i_batch == nxf - 1 and j_batch == 0:
|
||||
seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
|
||||
prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
|
||||
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
|
||||
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
|
||||
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
|
||||
seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
|
||||
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
|
||||
else:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
|
||||
|
||||
indexer_inside_batch = indexer_inside_batch +1
|
||||
|
||||
|
||||
list_i_s = []
|
||||
list_j_s = []
|
||||
list_x_u = []
|
||||
list_x_d = []
|
||||
list_y_u = []
|
||||
list_y_d = []
|
||||
|
||||
batch_indexer = 0
|
||||
|
||||
img_patch = np.zeros((n_batch_inference, model_height, model_width,3))
|
||||
|
||||
elif i==(nxf-1) and j==(nyf-1):
|
||||
label_p_pred = model.predict(img_patch,verbose=0)
|
||||
|
||||
seg = np.argmax(label_p_pred, axis=3)
|
||||
|
||||
#print(seg.shape, len(seg), len(list_i_s))
|
||||
|
||||
indexer_inside_batch = 0
|
||||
for i_batch, j_batch in zip(list_i_s, list_j_s):
|
||||
seg_in = seg[indexer_inside_batch,:,:]
|
||||
seg_color = np.repeat(seg_in[:, :, np.newaxis], 3, axis=2)
|
||||
|
||||
index_y_u_in = list_y_u[indexer_inside_batch]
|
||||
index_y_d_in = list_y_d[indexer_inside_batch]
|
||||
|
||||
index_x_u_in = list_x_u[indexer_inside_batch]
|
||||
index_x_d_in = list_x_d[indexer_inside_batch]
|
||||
|
||||
if i_batch == 0 and j_batch == 0:
|
||||
seg_color = seg_color[0 : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
|
||||
elif i_batch == nxf - 1 and j_batch == nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - 0, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
|
||||
elif i_batch == 0 and j_batch == nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - 0, 0 : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
|
||||
elif i_batch == nxf - 1 and j_batch == 0:
|
||||
seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
|
||||
prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
|
||||
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
|
||||
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
|
||||
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
|
||||
seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
|
||||
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
|
||||
else:
|
||||
seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
|
||||
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
|
||||
|
||||
indexer_inside_batch = indexer_inside_batch +1
|
||||
|
||||
|
||||
list_i_s = []
|
||||
list_j_s = []
|
||||
list_x_u = []
|
||||
list_x_d = []
|
||||
list_y_u = []
|
||||
list_y_d = []
|
||||
|
||||
batch_indexer = 0
|
||||
|
||||
img_patch = np.zeros((n_batch_inference, model_height, model_width,3))
|
||||
|
||||
|
||||
|
||||
prediction_true = prediction_true[index_start_h: index_start_h+img_org_h, index_start_w: index_start_w+img_org_w,:]
|
||||
prediction_true = prediction_true.astype(np.uint8)
|
||||
|
||||
else:
|
||||
img_h_page = img.shape[0]
|
||||
img_w_page = img.shape[1]
|
||||
img = img / float(255.0)
|
||||
img = resize_image(img, model_height, model_width)
|
||||
|
||||
label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
|
||||
|
||||
seg = np.argmax(label_p_pred, axis=3)[0]
|
||||
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||
prediction_true = resize_image(seg_color, img_h_page, img_w_page)
|
||||
prediction_true = prediction_true.astype(np.uint8)
|
||||
return prediction_true[:,:,0]
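A condensed sketch of the overlapping-tile scheme that predict() implements above: the page is covered by model-sized windows that overlap by a margin on each side, and only the inner part of each window's prediction is written back, which avoids seams at tile borders. The structure mirrors the logic of this diff, but the helper names and the simplified border handling are illustrative assumptions, not the project's API; the page is assumed to be at least one tile large (predict() guarantees this by padding first).

import numpy as np

def tile_positions(img_size, tile_size, margin):
    """Return (start, end) index pairs covering one axis of the image."""
    stride = tile_size - 2 * margin
    n = int(np.ceil(img_size / float(stride)))
    spans = []
    for k in range(n):
        start = k * stride
        end = start + tile_size
        if end > img_size:               # clamp the last tile to the image border
            end = img_size
            start = img_size - tile_size
        spans.append((start, end))
    return spans

def stitch(pred_tiles, img_h, img_w, tile, margin):
    """Paste per-tile class maps into a full-page map, trimming the overlap margins."""
    out = np.zeros((img_h, img_w), dtype=np.uint8)
    for (y0, y1), (x0, x1), seg in pred_tiles:
        # keep the margin only where the tile touches the image border
        ty0 = 0 if y0 == 0 else margin
        tx0 = 0 if x0 == 0 else margin
        ty1 = tile if y1 == img_h else tile - margin
        tx1 = tile if x1 == img_w else tile - margin
        out[y0 + ty0:y0 + ty1, x0 + tx0:x0 + tx1] = seg[ty0:ty1, tx0:tx1]
    return out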

    def run(self, image=None, image_path=None, save=None, use_patches=False, dir_in=None, dir_out=None):
        print(dir_in,'dir_in')
        if not dir_in:
            if (image is not None and image_path is not None) or \
               (image is None and image_path is None):
                raise ValueError("Must pass either an OpenCV image or an image_path")
            if image_path is not None:
                image = cv2.imread(image_path)
            img_last = 0
            for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
                self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))

                res = self.predict(model, image, use_patches)

                img_fin = np.zeros((res.shape[0], res.shape[1], 3))
                res[:, :][res[:, :] == 0] = 2
                res = res - 1
                res = res * 255
                img_fin[:, :, 0] = res
                img_fin[:, :, 1] = res
                img_fin[:, :, 2] = res

                img_fin = img_fin.astype(np.uint8)
                img_fin = (res[:, :] == 0) * 255
                img_last = img_last + img_fin

            kernel = np.ones((5, 5), np.uint8)
            img_last[:, :][img_last[:, :] > 0] = 255
            img_last = (img_last[:, :] == 0) * 255
            if save:
                cv2.imwrite(save, img_last)
            return img_last
        else:
            ls_imgs = os.listdir(dir_in)
            for image_name in ls_imgs:
                image_stem = image_name.split('.')[0]
                print(image_name,'image_name')
                image = cv2.imread(os.path.join(dir_in,image_name) )
                img_last = 0
                for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
                    self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))

                    res = self.predict(model, image, use_patches)

                    img_fin = np.zeros((res.shape[0], res.shape[1], 3))
                    res[:, :][res[:, :] == 0] = 2
                    res = res - 1
                    res = res * 255
                    img_fin[:, :, 0] = res
                    img_fin[:, :, 1] = res
                    img_fin[:, :, 2] = res

                    img_fin = img_fin.astype(np.uint8)
                    img_fin = (res[:, :] == 0) * 255
                    img_last = img_last + img_fin

                kernel = np.ones((5, 5), np.uint8)
                img_last[:, :][img_last[:, :] > 0] = 255
                img_last = (img_last[:, :] == 0) * 255

                cv2.imwrite(os.path.join(dir_out,image_stem+'.png'), img_last)
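For orientation, this is how the class above would typically be driven from Python. It is a minimal sketch: the model directory layout (one saved Keras model per subfolder) follows from the glob pattern in __init__, and the file names are placeholders, not anything this diff documents.

import cv2
from eynollah.sbb_binarize import SbbBinarizer

binarizer = SbbBinarizer('/path/to/binarization_models')   # each subfolder holds one saved model
img = cv2.imread('page.tif')                               # BGR image as loaded by OpenCV
bin_map = binarizer.run(image=img, use_patches=True)       # binarized page as a numpy array
cv2.imwrite('page.bin.png', bin_map.astype('uint8'))
binarizer.end_session()                                    # free the TensorFlow session when done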

@@ -7,7 +7,7 @@ import cv2
import imutils
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d

import time
from .is_nan import isNaN
from .contour import (contours_in_same_horizon,
                      find_new_features_of_contours,
@@ -775,9 +775,8 @@ def put_drop_out_from_only_drop_model(layout_no_patch, layout1):

    return layout_no_patch

def putt_bb_of_drop_capitals_of_model_in_patches_in_layout(layout_in_patch):

    drop_only = (layout_in_patch[:, :, 0] == 4) * 1
def putt_bb_of_drop_capitals_of_model_in_patches_in_layout(layout_in_patch, drop_capital_label):
    drop_only = (layout_in_patch[:, :, 0] == drop_capital_label) * 1
    contours_drop, hir_on_drop = return_contours_of_image(drop_only)
    contours_drop_parent = return_parent_contours(contours_drop, hir_on_drop)

@@ -786,13 +785,18 @@ def putt_bb_of_drop_capitals_of_model_in_patches_in_layout(layout_in_patch):

    contours_drop_parent = [contours_drop_parent[jz] for jz in range(len(contours_drop_parent)) if areas_cnt_text[jz] > 0.00001]

    areas_cnt_text = [areas_cnt_text[jz] for jz in range(len(areas_cnt_text)) if areas_cnt_text[jz] > 0.001]
    areas_cnt_text = [areas_cnt_text[jz] for jz in range(len(areas_cnt_text)) if areas_cnt_text[jz] > 0.00001]

    contours_drop_parent_final = []

    for jj in range(len(contours_drop_parent)):
        x, y, w, h = cv2.boundingRect(contours_drop_parent[jj])
        layout_in_patch[y : y + h, x : x + w, 0] = 4

        if ( ( areas_cnt_text[jj] * float(drop_only.shape[0] * drop_only.shape[1]) ) / float(w*h) ) > 0.4:

            layout_in_patch[y : y + h, x : x + w, 0] = drop_capital_label
        else:
            layout_in_patch[y : y + h, x : x + w, 0][layout_in_patch[y : y + h, x : x + w, 0] == drop_capital_label] = 1#drop_capital_label

    return layout_in_patch
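The new 0.4 threshold above reads as a fill-ratio test: areas_cnt_text appears to hold contour areas normalized by the total image area, so multiplying by drop_only.shape[0] * drop_only.shape[1] recovers the contour's pixel area, and dividing by the bounding-box area w*h gives how much of the box the candidate drop capital actually covers. A tiny worked example under that assumption, with hypothetical numbers:

# hypothetical numbers, only to illustrate the ratio in the new condition
image_pixels = 2000 * 1500          # drop_only.shape[0] * drop_only.shape[1]
normalized_area = 0.0006            # one entry of areas_cnt_text
w, h = 60, 55                       # bounding box of the candidate drop capital
fill_ratio = (normalized_area * image_pixels) / float(w * h)
print(fill_ratio)                   # ~0.545 > 0.4, so the box keeps drop_capital_label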

@@ -1200,17 +1204,12 @@ def order_of_regions(textline_mask, contours_main, contours_header, y_ref):
        top = peaks_neg_new[i]
        down = peaks_neg_new[i + 1]

        # print(top,down,'topdown')

        indexes_in = matrix_of_orders[:, 0][(matrix_of_orders[:, 3] >= top) & ((matrix_of_orders[:, 3] < down))]
        cxs_in = matrix_of_orders[:, 2][(matrix_of_orders[:, 3] >= top) & ((matrix_of_orders[:, 3] < down))]
        cys_in = matrix_of_orders[:, 3][(matrix_of_orders[:, 3] >= top) & ((matrix_of_orders[:, 3] < down))]
        types_of_text = matrix_of_orders[:, 1][(matrix_of_orders[:, 3] >= top) & ((matrix_of_orders[:, 3] < down))]
        index_types_of_text = matrix_of_orders[:, 4][(matrix_of_orders[:, 3] >= top) & ((matrix_of_orders[:, 3] < down))]

        # print(top,down)
        # print(cys_in,'cyyyins')
        # print(indexes_in,'indexes')
        sorted_inside = np.argsort(cxs_in)

        ind_in_int = indexes_in[sorted_inside]

@@ -1224,11 +1223,17 @@ def order_of_regions(textline_mask, contours_main, contours_header, y_ref):

    ##matrix_of_orders[:len_main,4]=final_indexers_sorted[:]

    # print(peaks_neg_new,'peaks')
    # print(final_indexers_sorted,'indexsorted')
    # print(final_types,'types')
    # print(final_index_type,'final_index_type')

    # This fix is applied if the sum of the lengths of contours and contours_h does not match final_indexers_sorted. However, this is not the optimal solution.
    if (len(cy_main)+len(cy_header) ) == len(final_index_type):
        pass
    else:
        indexes_missed = set(list( np.array( range((len(cy_main)+len(cy_header) ) )) )) - set(final_indexers_sorted)
        for ind_missed in indexes_missed:
            final_indexers_sorted.append(ind_missed)
            final_types.append(1)
            final_index_type.append(ind_missed)

    return final_indexers_sorted, matrix_of_orders, final_types, final_index_type
|
||||
|
||||
def combine_hor_lines_and_delete_cross_points_and_get_lines_features_back_new(img_p_in_ver, img_in_hor,num_col_classifier):
|
||||
|
@ -1338,7 +1343,7 @@ def return_points_with_boundies(peaks_neg_fin, first_point, last_point):
|
|||
return peaks_neg_tot
|
||||
|
||||
def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables, pixel_lines, contours_h=None):
|
||||
|
||||
t_ins_c0 = time.time()
|
||||
separators_closeup=( (region_pre_p[:,:,:]==pixel_lines))*1
|
||||
|
||||
separators_closeup[0:110,:,:]=0
|
||||
|
@ -1352,84 +1357,47 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
|||
|
||||
|
||||
separators_closeup_new=np.zeros((separators_closeup.shape[0] ,separators_closeup.shape[1] ))
|
||||
|
||||
|
||||
|
||||
##_,separators_closeup_n=self.combine_hor_lines_and_delete_cross_points_and_get_lines_features_back(region_pre_p[:,:,0])
|
||||
separators_closeup_n=np.copy(separators_closeup)
|
||||
|
||||
separators_closeup_n=separators_closeup_n.astype(np.uint8)
|
||||
##plt.imshow(separators_closeup_n[:,:,0])
|
||||
##plt.show()
|
||||
|
||||
separators_closeup_n_binary=np.zeros(( separators_closeup_n.shape[0],separators_closeup_n.shape[1]) )
|
||||
separators_closeup_n_binary[:,:]=separators_closeup_n[:,:,0]
|
||||
|
||||
separators_closeup_n_binary[:,:][separators_closeup_n_binary[:,:]!=0]=1
|
||||
#separators_closeup_n_binary[:,:][separators_closeup_n_binary[:,:]==0]=255
|
||||
#separators_closeup_n_binary[:,:][separators_closeup_n_binary[:,:]==-255]=0
|
||||
|
||||
|
||||
#separators_closeup_n_binary=(separators_closeup_n_binary[:,:]==2)*1
|
||||
|
||||
#gray = cv2.cvtColor(separators_closeup_n, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
###
|
||||
|
||||
#print(separators_closeup_n_binary.shape)
|
||||
|
||||
gray_early=np.repeat(separators_closeup_n_binary[:, :, np.newaxis], 3, axis=2)
|
||||
gray_early=gray_early.astype(np.uint8)
|
||||
|
||||
#print(gray_early.shape,'burda')
|
||||
imgray_e = cv2.cvtColor(gray_early, cv2.COLOR_BGR2GRAY)
|
||||
#print('burda2')
|
||||
ret_e, thresh_e = cv2.threshold(imgray_e, 0, 255, 0)
|
||||
|
||||
#print('burda3')
|
||||
contours_line_e,hierarchy_e=cv2.findContours(thresh_e,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
#slope_lines_e,dist_x_e, x_min_main_e ,x_max_main_e ,cy_main_e,slope_lines_org_e,y_min_main_e, y_max_main_e, cx_main_e=self.find_features_of_lines(contours_line_e)
|
||||
|
||||
slope_linese,dist_xe, x_min_maine ,x_max_maine ,cy_maine,slope_lines_orge,y_min_maine, y_max_maine, cx_maine=find_features_of_lines(contours_line_e)
|
||||
|
||||
dist_ye=y_max_maine-y_min_maine
|
||||
#print(y_max_maine-y_min_maine,'y')
|
||||
#print(dist_xe,'x')
|
||||
|
||||
|
||||
args_e=np.array(range(len(contours_line_e)))
|
||||
args_hor_e=args_e[(dist_ye<=50) & (dist_xe>=3*dist_ye)]
|
||||
|
||||
#print(args_hor_e,'jidi',len(args_hor_e),'jilva')
|
||||
|
||||
cnts_hor_e=[]
|
||||
for ce in args_hor_e:
|
||||
cnts_hor_e.append(contours_line_e[ce])
|
||||
#print(len(slope_linese),'lieee')
|
||||
|
||||
figs_e=np.zeros(thresh_e.shape)
|
||||
figs_e=cv2.fillPoly(figs_e,pts=cnts_hor_e,color=(1,1,1))
|
||||
|
||||
#plt.imshow(figs_e)
|
||||
#plt.show()
|
||||
|
||||
###
|
||||
|
||||
separators_closeup_n_binary=cv2.fillPoly(separators_closeup_n_binary,pts=cnts_hor_e,color=(0,0,0))
|
||||
|
||||
gray = cv2.bitwise_not(separators_closeup_n_binary)
|
||||
gray=gray.astype(np.uint8)
|
||||
|
||||
|
||||
#plt.imshow(gray)
|
||||
#plt.show()
|
||||
|
||||
|
||||
bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, \
|
||||
cv2.THRESH_BINARY, 15, -2)
|
||||
##plt.imshow(bw[:,:])
|
||||
##plt.show()
|
||||
|
||||
|
||||
horizontal = np.copy(bw)
|
||||
vertical = np.copy(bw)
|
||||
|
||||
|
@ -1447,16 +1415,7 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
|||
horizontal = cv2.dilate(horizontal,kernel,iterations = 2)
|
||||
horizontal = cv2.erode(horizontal,kernel,iterations = 2)
|
||||
|
||||
|
||||
###
|
||||
#print(np.unique(horizontal),'uni')
|
||||
horizontal=cv2.fillPoly(horizontal,pts=cnts_hor_e,color=(255,255,255))
|
||||
###
|
||||
|
||||
|
||||
|
||||
#plt.imshow(horizontal)
|
||||
#plt.show()
|
||||
|
||||
rows = vertical.shape[0]
|
||||
verticalsize = rows // 30
|
||||
|
@ -1467,35 +1426,21 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
|||
vertical = cv2.dilate(vertical, verticalStructure)
|
||||
|
||||
vertical = cv2.dilate(vertical,kernel,iterations = 1)
|
||||
# Show extracted vertical lines
|
||||
|
||||
horizontal,special_separators=combine_hor_lines_and_delete_cross_points_and_get_lines_features_back_new(vertical,horizontal,num_col_classifier)
|
||||
|
||||
|
||||
#plt.imshow(horizontal)
|
||||
#plt.show()
|
||||
#print(vertical.shape,np.unique(vertical),'verticalvertical')
|
||||
separators_closeup_new[:,:][vertical[:,:]!=0]=1
|
||||
separators_closeup_new[:,:][horizontal[:,:]!=0]=1
|
||||
|
||||
##plt.imshow(separators_closeup_new)
|
||||
##plt.show()
|
||||
##separators_closeup_n
|
||||
vertical=np.repeat(vertical[:, :, np.newaxis], 3, axis=2)
|
||||
vertical=vertical.astype(np.uint8)
|
||||
|
||||
##plt.plot(vertical[:,:,0].sum(axis=0))
|
||||
##plt.show()
|
||||
|
||||
#plt.plot(vertical[:,:,0].sum(axis=1))
|
||||
#plt.show()
|
||||
|
||||
imgray = cv2.cvtColor(vertical, cv2.COLOR_BGR2GRAY)
|
||||
ret, thresh = cv2.threshold(imgray, 0, 255, 0)
|
||||
|
||||
contours_line_vers,hierarchy=cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
|
||||
slope_lines,dist_x, x_min_main ,x_max_main ,cy_main,slope_lines_org,y_min_main, y_max_main, cx_main=find_features_of_lines(contours_line_vers)
|
||||
#print(slope_lines,'vertical')
|
||||
|
||||
args=np.array( range(len(slope_lines) ))
|
||||
args_ver=args[slope_lines==1]
|
||||
dist_x_ver=dist_x[slope_lines==1]
|
||||
|
@ -1508,9 +1453,6 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
|||
len_y=separators_closeup.shape[0]/3.0
|
||||
|
||||
|
||||
#plt.imshow(horizontal)
|
||||
#plt.show()
|
||||
|
||||
horizontal=np.repeat(horizontal[:, :, np.newaxis], 3, axis=2)
|
||||
horizontal=horizontal.astype(np.uint8)
|
||||
imgray = cv2.cvtColor(horizontal, cv2.COLOR_BGR2GRAY)
|
||||
|
@ -1578,8 +1520,6 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
|||
|
||||
matrix_of_lines_ch[len(cy_main_hor):,9]=1
|
||||
|
||||
|
||||
|
||||
if contours_h is not None:
|
||||
slope_lines_head,dist_x_head, x_min_main_head ,x_max_main_head ,cy_main_head,slope_lines_org_head,y_min_main_head, y_max_main_head, cx_main_head=find_features_of_lines(contours_h)
|
||||
matrix_l_n=np.zeros((matrix_of_lines_ch.shape[0]+len(cy_main_head),matrix_of_lines_ch.shape[1]))
|
||||
|
@ -1625,8 +1565,6 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
|||
|
||||
args_big_parts=np.array(range(len(splitter_y_new_diff))) [ splitter_y_new_diff>22 ]
|
||||
|
||||
|
||||
|
||||
regions_without_separators=return_regions_without_separators(region_pre_p)
|
||||
|
||||
|
||||
|
@ -1636,19 +1574,8 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
|||
peaks_neg_fin_fin=[]
|
||||
|
||||
for itiles in args_big_parts:
|
||||
|
||||
|
||||
regions_without_separators_tile=regions_without_separators[int(splitter_y_new[itiles]):int(splitter_y_new[itiles+1]),:,0]
|
||||
#image_page_background_zero_tile=image_page_background_zero[int(splitter_y_new[itiles]):int(splitter_y_new[itiles+1]),:]
|
||||
|
||||
#print(regions_without_separators_tile.shape)
|
||||
##plt.imshow(regions_without_separators_tile)
|
||||
##plt.show()
|
||||
|
||||
#num_col, peaks_neg_fin=self.find_num_col(regions_without_separators_tile,multiplier=6.0)
|
||||
|
||||
#regions_without_separators_tile=cv2.erode(regions_without_separators_tile,kernel,iterations = 3)
|
||||
#
|
||||
|
||||
try:
|
||||
num_col, peaks_neg_fin = find_num_col(regions_without_separators_tile, num_col_classifier, tables, multiplier=7.0)
|
||||
except:
|
||||
|
@ -1666,9 +1593,6 @@ def find_number_of_columns_in_document(region_pre_p, num_col_classifier, tables,
|
|||
peaks_neg_fin=peaks_neg_fin[peaks_neg_fin<=(vertical.shape[1]-500)]
|
||||
peaks_neg_fin_fin=peaks_neg_fin[:]
|
||||
|
||||
#print(peaks_neg_fin_fin,'peaks_neg_fin_fintaza')
|
||||
|
||||
|
||||
return num_col_fin, peaks_neg_fin_fin,matrix_of_lines_ch,splitter_y_new,separators_closeup_n
|
||||
|
||||
|
||||
|
|

@@ -263,7 +263,7 @@ def get_textregion_contours_in_org_image(cnts, img, slope_first):

    return cnts_org

def get_textregion_contours_in_org_image_light(cnts, img, slope_first):
def get_textregion_contours_in_org_image_light_old(cnts, img, slope_first):

    h_o = img.shape[0]
    w_o = img.shape[1]

@@ -278,14 +278,7 @@ def get_textregion_contours_in_org_image_light(cnts, img, slope_first):
        img_copy = np.zeros(img.shape)
        img_copy = cv2.fillPoly(img_copy, pts=[cnts[i]], color=(1, 1, 1))

        # plt.imshow(img_copy)
        # plt.show()

        # print(img.shape,'img')
        img_copy = rotation_image_new(img_copy, -slope_first)
        ##print(img_copy.shape,'img_copy')
        # plt.imshow(img_copy)
        # plt.show()

        img_copy = img_copy.astype(np.uint8)
        imgray = cv2.cvtColor(img_copy, cv2.COLOR_BGR2GRAY)

@@ -300,6 +293,70 @@ def get_textregion_contours_in_org_image_light(cnts, img, slope_first):

    return cnts_org

def return_list_of_contours_with_desired_order(ls_cons, sorted_indexes):
    return [ls_cons[sorted_indexes[index]] for index in range(len(sorted_indexes))]

def do_back_rotation_and_get_cnt_back(queue_of_all_params, contours_par_per_process,indexes_r_con_per_pro, img, slope_first):
    contours_textregion_per_each_subprocess = []
    index_by_text_region_contours = []
    for mv in range(len(contours_par_per_process)):
        img_copy = np.zeros(img.shape)
        img_copy = cv2.fillPoly(img_copy, pts=[contours_par_per_process[mv]], color=(1, 1, 1))

        img_copy = rotation_image_new(img_copy, -slope_first)

        img_copy = img_copy.astype(np.uint8)
        imgray = cv2.cvtColor(img_copy, cv2.COLOR_BGR2GRAY)
        ret, thresh = cv2.threshold(imgray, 0, 255, 0)

        cont_int, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        cont_int[0][:, 0, 0] = cont_int[0][:, 0, 0] + np.abs(img_copy.shape[1] - img.shape[1])
        cont_int[0][:, 0, 1] = cont_int[0][:, 0, 1] + np.abs(img_copy.shape[0] - img.shape[0])
        # print(np.shape(cont_int[0]))
        contours_textregion_per_each_subprocess.append(cont_int[0]*6)
        index_by_text_region_contours.append(indexes_r_con_per_pro[mv])

    queue_of_all_params.put([contours_textregion_per_each_subprocess, index_by_text_region_contours])

def get_textregion_contours_in_org_image_light(cnts, img, slope_first):
    num_cores = cpu_count()
    queue_of_all_params = Queue()
    processes = []
    nh = np.linspace(0, len(cnts), num_cores + 1)
    indexes_by_text_con = np.array(range(len(cnts)))

    h_o = img.shape[0]
    w_o = img.shape[1]

    img = cv2.resize(img, (int(img.shape[1]/6.), int(img.shape[0]/6.)), interpolation=cv2.INTER_NEAREST)
    ##cnts = list( (np.array(cnts)/2).astype(np.int16) )
    #cnts = cnts/2
    cnts = [(i/ 6).astype(np.int32) for i in cnts]

    for i in range(num_cores):
        contours_par_per_process = cnts[int(nh[i]) : int(nh[i + 1])]
        indexes_text_con_per_process = indexes_by_text_con[int(nh[i]) : int(nh[i + 1])]
        processes.append(Process(target=do_back_rotation_and_get_cnt_back, args=(queue_of_all_params, contours_par_per_process, indexes_text_con_per_process, img, slope_first)))

    for i in range(num_cores):
        processes[i].start()

    cnts_org = []
    all_index_text_con = []
    for i in range(num_cores):
        list_all_par = queue_of_all_params.get(True)
        contours_for_subprocess = list_all_par[0]
        indexes_for_subprocess = list_all_par[1]
        for j in range(len(contours_for_subprocess)):
            cnts_org.append(contours_for_subprocess[j])
            all_index_text_con.append(indexes_for_subprocess[j])
    for i in range(num_cores):
        processes[i].join()

    cnts_org = return_list_of_contours_with_desired_order(cnts_org, all_index_text_con)

    return cnts_org
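The new function above (and the deskew changes later in this diff) follow the same hand-rolled fan-out/fan-in pattern: split the work into cpu_count() slices, run each slice in a Process that puts its partial result on a shared Queue, drain the Queue, then join. A compact, generic sketch of that pattern (not the project's exact code; the squared-numbers workload is only a placeholder):

from multiprocessing import Process, Queue, cpu_count
import numpy as np

def worker(queue, items):
    # each subprocess computes its partial result and reports it once
    queue.put([x * x for x in items])

def parallel_map_square(values):
    num_cores = cpu_count()
    queue = Queue()
    bounds = np.linspace(0, len(values), num_cores + 1)
    processes = [Process(target=worker,
                         args=(queue, values[int(bounds[i]):int(bounds[i + 1])]))
                 for i in range(num_cores)]
    for p in processes:
        p.start()
    results = []
    for _ in range(num_cores):           # drain the queue before joining, as the diff does,
        results.extend(queue.get(True))  # so large results cannot deadlock the children
    for p in processes:
        p.join()
    return results

if __name__ == '__main__':
    print(parallel_map_square(list(range(10))))

Because queue.get returns results in completion order rather than input order, the project code also carries the original indexes through the workers and re-sorts at the end via return_list_of_contours_with_desired_order.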

def return_contours_of_interested_textline(region_pre_p, pixel):

    # pixels of images are identified by 5
@@ -8,7 +8,7 @@ from .contour import find_new_features_of_contours, return_contours_of_intereste
from .resize import resize_image
from .rotate import rotate_image

def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=None):
def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, light_version=False, kernel=None):
    mask_marginals=np.zeros((text_with_lines.shape[0],text_with_lines.shape[1]))
    mask_marginals=mask_marginals.astype(np.uint8)
@ -49,27 +49,14 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=N
|
|||
if thickness_along_y_percent>=14:
|
||||
|
||||
text_with_lines_y_rev=-1*text_with_lines_y[:]
|
||||
#print(text_with_lines_y)
|
||||
#print(text_with_lines_y_rev)
|
||||
|
||||
|
||||
|
||||
|
||||
#plt.plot(text_with_lines_y)
|
||||
#plt.show()
|
||||
|
||||
|
||||
text_with_lines_y_rev=text_with_lines_y_rev-np.min(text_with_lines_y_rev)
|
||||
|
||||
#plt.plot(text_with_lines_y_rev)
|
||||
#plt.show()
|
||||
sigma_gaus=1
|
||||
region_sum_0= gaussian_filter1d(text_with_lines_y, sigma_gaus)
|
||||
|
||||
region_sum_0_rev=gaussian_filter1d(text_with_lines_y_rev, sigma_gaus)
|
||||
|
||||
#plt.plot(region_sum_0_rev)
|
||||
#plt.show()
|
||||
region_sum_0_updown=region_sum_0[len(region_sum_0)::-1]
|
||||
|
||||
first_nonzero=(next((i for i, x in enumerate(region_sum_0) if x), None))
|
||||
|
@ -78,43 +65,17 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=N
|
|||
|
||||
last_nonzero=len(region_sum_0)-last_nonzero
|
||||
|
||||
##img_sum_0_smooth_rev=-region_sum_0
|
||||
|
||||
|
||||
mid_point=(last_nonzero+first_nonzero)/2.
|
||||
|
||||
|
||||
one_third_right=(last_nonzero-mid_point)/3.0
|
||||
one_third_left=(mid_point-first_nonzero)/3.0
|
||||
|
||||
#img_sum_0_smooth_rev=img_sum_0_smooth_rev-np.min(img_sum_0_smooth_rev)
|
||||
|
||||
|
||||
|
||||
|
||||
peaks, _ = find_peaks(text_with_lines_y_rev, height=0)
|
||||
|
||||
|
||||
peaks=np.array(peaks)
|
||||
|
||||
|
||||
#print(region_sum_0[peaks])
|
||||
##plt.plot(region_sum_0)
|
||||
##plt.plot(peaks,region_sum_0[peaks],'*')
|
||||
##plt.show()
|
||||
#print(first_nonzero,last_nonzero,peaks)
|
||||
peaks=peaks[(peaks>first_nonzero) & ((peaks<last_nonzero))]
|
||||
|
||||
#print(first_nonzero,last_nonzero,peaks)
|
||||
|
||||
|
||||
#print(region_sum_0[peaks]<10)
|
||||
####peaks=peaks[region_sum_0[peaks]<25 ]
|
||||
|
||||
#print(region_sum_0[peaks])
|
||||
peaks=peaks[region_sum_0[peaks]<min_textline_thickness ]
|
||||
#print(peaks)
|
||||
#print(first_nonzero,last_nonzero,one_third_right,one_third_left)
|
||||
|
||||
|
||||
if num_col==1:
|
||||
peaks_right=peaks[peaks>mid_point]
|
||||
|
@ -137,9 +98,6 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=N
|
|||
|
||||
|
||||
|
||||
|
||||
#print(point_left,point_right)
|
||||
#print(text_regions.shape)
|
||||
if point_right>=mask_marginals.shape[1]:
|
||||
point_right=mask_marginals.shape[1]-1
|
||||
|
||||
|
@ -148,10 +106,8 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=N
|
|||
except:
|
||||
mask_marginals[:,:]=1
|
||||
|
||||
#print(mask_marginals.shape,point_left,point_right,'nadosh')
|
||||
mask_marginals_rotated=rotate_image(mask_marginals,-slope_deskew)
|
||||
|
||||
#print(mask_marginals_rotated.shape,'nadosh')
|
||||
mask_marginals_rotated_sum=mask_marginals_rotated.sum(axis=0)
|
||||
|
||||
mask_marginals_rotated_sum[mask_marginals_rotated_sum!=0]=1
|
||||
|
@ -168,11 +124,6 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=N
|
|||
max_point_of_right_marginal=text_regions.shape[1]-1
|
||||
|
||||
|
||||
#print(np.min(index_x_interest) ,np.max(index_x_interest),'minmaxnew')
|
||||
#print(mask_marginals_rotated.shape,text_regions.shape,'mask_marginals_rotated')
|
||||
#plt.imshow(mask_marginals)
|
||||
#plt.show()
|
||||
|
||||
#plt.imshow(mask_marginals_rotated)
|
||||
#plt.show()
|
||||
|
||||
|
@ -195,10 +146,9 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=N
|
|||
x_min_marginals_right=[]
|
||||
|
||||
for i in range(len(cx_text_only)):
|
||||
|
||||
x_width_mar=abs(x_min_text_only[i]-x_max_text_only[i])
|
||||
y_height_mar=abs(y_min_text_only[i]-y_max_text_only[i])
|
||||
#print(x_width_mar,y_height_mar,y_height_mar/x_width_mar,'y_height_mar')
|
||||
|
||||
if x_width_mar>16 and y_height_mar/x_width_mar<18:
|
||||
marginlas_should_be_main_text.append(polygons_of_marginals[i])
|
||||
if x_min_text_only[i]<(mid_point-one_third_left):
|
||||
|
@ -220,18 +170,13 @@ def get_marginals(text_with_lines, text_regions, num_col, slope_deskew, kernel=N
|
|||
x_min_marginals_right=[text_regions.shape[1]-1]
|
||||
|
||||
|
||||
|
||||
|
||||
#print(x_min_marginals_left[0],x_min_marginals_right[0],'margo')
|
||||
|
||||
#print(marginlas_should_be_main_text,'marginlas_should_be_main_text')
|
||||
text_regions=cv2.fillPoly(text_regions, pts =marginlas_should_be_main_text, color=(4,4))
|
||||
|
||||
#print(np.unique(text_regions))
|
||||
|
||||
#text_regions[:,:int(x_min_marginals_left[0])][text_regions[:,:int(x_min_marginals_left[0])]==1]=0
|
||||
#text_regions[:,int(x_min_marginals_right[0]):][text_regions[:,int(x_min_marginals_right[0]):]==1]=0
|
||||
|
||||
|
||||
|
||||
text_regions[:,:int(min_point_of_left_marginal)][text_regions[:,:int(min_point_of_left_marginal)]==1]=0
|
||||
text_regions[:,int(max_point_of_right_marginal):][text_regions[:,int(max_point_of_right_marginal):]==1]=0
|
||||
|
||||
|
|

@@ -3,7 +3,8 @@ import cv2
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import os

from multiprocessing import Process, Queue, cpu_count
from multiprocessing import Pool
from .rotate import rotate_image
from .contour import (
    return_parent_contours,
@@ -1569,8 +1570,21 @@ def separate_lines_new2(img_path, thetha, num_col, slope_region, plotter=None):
    # plt.show()
    return img_patch_ineterst_revised

def return_deskew_slop(img_patch_org, sigma_des, main_page=False, plotter=None):
def do_image_rotation(queue_of_all_params,angels_per_process, img_resized, sigma_des):
    angels_per_each_subprocess = []
    for mv in range(len(angels_per_process)):
        img_rot=rotate_image(img_resized,angels_per_process[mv])
        img_rot[img_rot!=0]=1
        try:
            var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
        except:
            var_spectrum=0
        angels_per_each_subprocess.append(var_spectrum)

    queue_of_all_params.put([angels_per_each_subprocess])
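The rest of this hunk (return_deskew_slop) parallelizes what is conceptually a coarse-to-fine grid search: rotate the text mask by a set of candidate angles, score each rotation with find_num_col_deskew, keep the best angle, then repeat on a narrower range around it. A serial sketch of that idea follows; the two-pass structure and the angle grids mirror the main-page branch of the diff, while the scorer (variance of the horizontal projection profile) is a simplified stand-in for find_num_col_deskew and is an assumption for illustration only.

import numpy as np
from scipy.ndimage import rotate

def score_rotation(mask, angle):
    # row sums become peaky when text lines are horizontal, so their variance rises
    rotated = rotate(mask, angle, reshape=False, order=0)
    profile = rotated.sum(axis=1)
    return float(np.var(profile))

def coarse_to_fine_deskew(mask, n_angles=100):
    coarse = np.array([-45.0, 0.0, 45.0, 90.0])
    best = max(coarse, key=lambda a: score_rotation(mask, a))
    fine = np.linspace(best - 22.5, best + 22.5, n_angles)
    best = max(fine, key=lambda a: score_rotation(mask, a))
    return best

In the diff, each pass's angle grid is simply split across cpu_count() worker processes via do_image_rotation above, which is why the same Queue/Process boilerplate reappears for every refinement stage.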
|
||||
|
||||
def return_deskew_slop(img_patch_org, sigma_des,n_tot_angles=100, main_page=False, plotter=None):
|
||||
num_cores = cpu_count()
|
||||
if main_page and plotter:
|
||||
plotter.save_plot_of_textline_density(img_patch_org)
|
||||
|
||||
|
@ -1603,22 +1617,44 @@ def return_deskew_slop(img_patch_org, sigma_des, main_page=False, plotter=None):
|
|||
#plt.imshow(img_resized)
|
||||
#plt.show()
|
||||
angels=np.array([-45, 0 , 45 , 90 , ])#np.linspace(-12,12,100)#np.array([0 , 45 , 90 , -45])
|
||||
|
||||
|
||||
queue_of_all_params = Queue()
|
||||
processes = []
|
||||
nh = np.linspace(0, len(angels), num_cores + 1)
|
||||
|
||||
for i in range(num_cores):
|
||||
angels_per_process = angels[int(nh[i]) : int(nh[i + 1])]
|
||||
processes.append(Process(target=do_image_rotation, args=(queue_of_all_params, angels_per_process, img_resized, sigma_des)))
|
||||
|
||||
for i in range(num_cores):
|
||||
processes[i].start()
|
||||
|
||||
var_res=[]
|
||||
for i in range(num_cores):
|
||||
list_all_par = queue_of_all_params.get(True)
|
||||
angles_for_subprocess = list_all_par[0]
|
||||
for j in range(len(angles_for_subprocess)):
|
||||
var_res.append(angles_for_subprocess[j])
|
||||
|
||||
for i in range(num_cores):
|
||||
processes[i].join()
|
||||
|
||||
for rot in angels:
|
||||
img_rot=rotate_image(img_resized,rot)
|
||||
#plt.imshow(img_rot)
|
||||
#plt.show()
|
||||
img_rot[img_rot!=0]=1
|
||||
#neg_peaks,var_spectrum=self.find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
#print(var_spectrum,'var_spectrum')
|
||||
try:
|
||||
var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
##print(rot,var_spectrum,'var_spectrum')
|
||||
except:
|
||||
var_spectrum=0
|
||||
var_res.append(var_spectrum)
|
||||
###for rot in angels:
|
||||
###img_rot=rotate_image(img_resized,rot)
|
||||
####plt.imshow(img_rot)
|
||||
####plt.show()
|
||||
###img_rot[img_rot!=0]=1
|
||||
####neg_peaks,var_spectrum=self.find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
####print(var_spectrum,'var_spectrum')
|
||||
###try:
|
||||
###var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
#####print(rot,var_spectrum,'var_spectrum')
|
||||
###except:
|
||||
###var_spectrum=0
|
||||
###var_res.append(var_spectrum)
|
||||
|
||||
|
||||
|
||||
try:
|
||||
var_res=np.array(var_res)
|
||||
ang_int=angels[np.argmax(var_res)]#angels_sorted[arg_final]#angels[arg_sort_early[arg_sort[arg_final]]]#angels[arg_fin]
|
||||
|
@ -1626,19 +1662,40 @@ def return_deskew_slop(img_patch_org, sigma_des, main_page=False, plotter=None):
|
|||
ang_int=0
|
||||
|
||||
|
||||
angels=np.linspace(ang_int-22.5,ang_int+22.5,100)
|
||||
angels=np.linspace(ang_int-22.5,ang_int+22.5,n_tot_angles)
|
||||
|
||||
queue_of_all_params = Queue()
|
||||
processes = []
|
||||
nh = np.linspace(0, len(angels), num_cores + 1)
|
||||
|
||||
for i in range(num_cores):
|
||||
angels_per_process = angels[int(nh[i]) : int(nh[i + 1])]
|
||||
processes.append(Process(target=do_image_rotation, args=(queue_of_all_params, angels_per_process, img_resized, sigma_des)))
|
||||
|
||||
for i in range(num_cores):
|
||||
processes[i].start()
|
||||
|
||||
var_res=[]
|
||||
for rot in angels:
|
||||
img_rot=rotate_image(img_resized,rot)
|
||||
##plt.imshow(img_rot)
|
||||
##plt.show()
|
||||
img_rot[img_rot!=0]=1
|
||||
try:
|
||||
var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
except:
|
||||
var_spectrum=0
|
||||
var_res.append(var_spectrum)
|
||||
for i in range(num_cores):
|
||||
list_all_par = queue_of_all_params.get(True)
|
||||
angles_for_subprocess = list_all_par[0]
|
||||
for j in range(len(angles_for_subprocess)):
|
||||
var_res.append(angles_for_subprocess[j])
|
||||
|
||||
for i in range(num_cores):
|
||||
processes[i].join()
|
||||
|
||||
##var_res=[]
|
||||
##for rot in angels:
|
||||
##img_rot=rotate_image(img_resized,rot)
|
||||
####plt.imshow(img_rot)
|
||||
####plt.show()
|
||||
##img_rot[img_rot!=0]=1
|
||||
##try:
|
||||
##var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
##except:
|
||||
##var_spectrum=0
|
||||
##var_res.append(var_spectrum)
|
||||
try:
|
||||
var_res=np.array(var_res)
|
||||
ang_int=angels[np.argmax(var_res)]#angels_sorted[arg_final]#angels[arg_sort_early[arg_sort[arg_final]]]#angels[arg_fin]
|
||||
|
@ -1649,25 +1706,47 @@ def return_deskew_slop(img_patch_org, sigma_des, main_page=False, plotter=None):
|
|||
|
||||
#plt.imshow(img_resized)
|
||||
#plt.show()
|
||||
angels=np.linspace(-12,12,100)#np.array([0 , 45 , 90 , -45])
|
||||
|
||||
|
||||
angels=np.linspace(-12,12,n_tot_angles)#np.array([0 , 45 , 90 , -45])
|
||||
|
||||
|
||||
queue_of_all_params = Queue()
|
||||
processes = []
|
||||
nh = np.linspace(0, len(angels), num_cores + 1)
|
||||
|
||||
for i in range(num_cores):
|
||||
angels_per_process = angels[int(nh[i]) : int(nh[i + 1])]
|
||||
processes.append(Process(target=do_image_rotation, args=(queue_of_all_params, angels_per_process, img_resized, sigma_des)))
|
||||
|
||||
for i in range(num_cores):
|
||||
processes[i].start()
|
||||
|
||||
var_res=[]
|
||||
for i in range(num_cores):
|
||||
list_all_par = queue_of_all_params.get(True)
|
||||
angles_for_subprocess = list_all_par[0]
|
||||
for j in range(len(angles_for_subprocess)):
|
||||
var_res.append(angles_for_subprocess[j])
|
||||
|
||||
for i in range(num_cores):
|
||||
processes[i].join()
|
||||
|
||||
for rot in angels:
|
||||
img_rot=rotate_image(img_resized,rot)
|
||||
#plt.imshow(img_rot)
|
||||
#plt.show()
|
||||
img_rot[img_rot!=0]=1
|
||||
#neg_peaks,var_spectrum=self.find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
#print(var_spectrum,'var_spectrum')
|
||||
try:
|
||||
var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
|
||||
except:
|
||||
var_spectrum=0
|
||||
##var_res=[]
|
||||
|
||||
var_res.append(var_spectrum)
|
||||
##for rot in angels:
|
||||
##img_rot=rotate_image(img_resized,rot)
|
||||
###plt.imshow(img_rot)
|
||||
###plt.show()
|
||||
##img_rot[img_rot!=0]=1
|
||||
###neg_peaks,var_spectrum=self.find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
###print(var_spectrum,'var_spectrum')
|
||||
##try:
|
||||
##var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
|
||||
##except:
|
||||
##var_spectrum=0
|
||||
|
||||
##var_res.append(var_spectrum)
|
||||
|
||||
|
||||
if plotter:
|
||||
|
@ -1680,18 +1759,39 @@ def return_deskew_slop(img_patch_org, sigma_des, main_page=False, plotter=None):
|
|||
|
||||
early_slope_edge=11
|
||||
if abs(ang_int)>early_slope_edge and ang_int<0:
|
||||
angels=np.linspace(-90,-12,100)
|
||||
angels=np.linspace(-90,-12,n_tot_angles)
|
||||
|
||||
queue_of_all_params = Queue()
|
||||
processes = []
|
||||
nh = np.linspace(0, len(angels), num_cores + 1)
|
||||
|
||||
for i in range(num_cores):
|
||||
angels_per_process = angels[int(nh[i]) : int(nh[i + 1])]
|
||||
processes.append(Process(target=do_image_rotation, args=(queue_of_all_params, angels_per_process, img_resized, sigma_des)))
|
||||
|
||||
for i in range(num_cores):
|
||||
processes[i].start()
|
||||
|
||||
var_res=[]
|
||||
for rot in angels:
|
||||
img_rot=rotate_image(img_resized,rot)
|
||||
##plt.imshow(img_rot)
|
||||
##plt.show()
|
||||
img_rot[img_rot!=0]=1
|
||||
try:
|
||||
var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
except:
|
||||
var_spectrum=0
|
||||
var_res.append(var_spectrum)
|
||||
for i in range(num_cores):
|
||||
list_all_par = queue_of_all_params.get(True)
|
||||
angles_for_subprocess = list_all_par[0]
|
||||
for j in range(len(angles_for_subprocess)):
|
||||
var_res.append(angles_for_subprocess[j])
|
||||
|
||||
for i in range(num_cores):
|
||||
processes[i].join()
|
||||
##var_res=[]
|
||||
##for rot in angels:
|
||||
##img_rot=rotate_image(img_resized,rot)
|
||||
####plt.imshow(img_rot)
|
||||
####plt.show()
|
||||
##img_rot[img_rot!=0]=1
|
||||
##try:
|
||||
##var_spectrum=find_num_col_deskew(img_rot,sigma_des,20.3 )
|
||||
##except:
|
||||
##var_spectrum=0
|
||||
##var_res.append(var_spectrum)
|
||||
try:
|
||||
var_res=np.array(var_res)
|
||||
ang_int=angels[np.argmax(var_res)]#angels_sorted[arg_final]#angels[arg_sort_early[arg_sort[arg_final]]]#angels[arg_fin]
|
||||
|
@ -1700,40 +1800,85 @@ def return_deskew_slop(img_patch_org, sigma_des, main_page=False, plotter=None):
|
|||
|
||||
    elif abs(ang_int) > early_slope_edge and ang_int > 0:
        angels = np.linspace(90, 12, n_tot_angles)

        queue_of_all_params = Queue()
        processes = []
        nh = np.linspace(0, len(angels), num_cores + 1)

        for i in range(num_cores):
            angels_per_process = angels[int(nh[i]) : int(nh[i + 1])]
            processes.append(Process(target=do_image_rotation, args=(queue_of_all_params, angels_per_process, img_resized, sigma_des)))

        for i in range(num_cores):
            processes[i].start()

        var_res = []
        for i in range(num_cores):
            list_all_par = queue_of_all_params.get(True)
            angles_for_subprocess = list_all_par[0]
            for j in range(len(angles_for_subprocess)):
                var_res.append(angles_for_subprocess[j])

        for i in range(num_cores):
            processes[i].join()

        try:
            var_res = np.array(var_res)
            ang_int = angels[np.argmax(var_res)]
        except:
            ang_int = 0

    else:
        angels = np.linspace(-25, 25, int(n_tot_angles / 2.) + 10)

        queue_of_all_params = Queue()
        processes = []
        nh = np.linspace(0, len(angels), num_cores + 1)

        for i in range(num_cores):
            angels_per_process = angels[int(nh[i]) : int(nh[i + 1])]
            processes.append(Process(target=do_image_rotation, args=(queue_of_all_params, angels_per_process, img_resized, sigma_des)))

        for i in range(num_cores):
            processes[i].start()

        var_res = []
        for i in range(num_cores):
            list_all_par = queue_of_all_params.get(True)
            angles_for_subprocess = list_all_par[0]
            for j in range(len(angles_for_subprocess)):
                var_res.append(angles_for_subprocess[j])

        for i in range(num_cores):
            processes[i].join()

        try:
            var_res = np.array(var_res)
            ang_int = angels[np.argmax(var_res)]
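Each of these hunks splits the candidate angles across worker processes with nh = np.linspace(0, len(angels), num_cores + 1). A minimal standalone sketch of that chunking pattern, with an example angle grid and cpu_count() standing in for whatever num_cores holds in the actual function:

    import numpy as np
    from multiprocessing import cpu_count

    angels = np.linspace(-25, 25, 60)   # example grid, not the values used above
    num_cores = cpu_count()
    nh = np.linspace(0, len(angels), num_cores + 1)
    # each worker gets a contiguous slice of the angle grid
    chunks = [angels[int(nh[i]):int(nh[i + 1])] for i in range(num_cores)]
    # every angle ends up in exactly one chunk
    assert sum(len(c) for c in chunks) == len(angels)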
@@ -1749,20 +1894,41 @@ def return_deskew_slop(img_patch_org, sigma_des, main_page=False, plotter=None):
    early_slope_edge = 22
    if abs(ang_int) > early_slope_edge and ang_int < 0:
        angels = np.linspace(-90, -25, int(n_tot_angles / 2.) + 10)

        queue_of_all_params = Queue()
        processes = []
        nh = np.linspace(0, len(angels), num_cores + 1)

        for i in range(num_cores):
            angels_per_process = angels[int(nh[i]) : int(nh[i + 1])]
            processes.append(Process(target=do_image_rotation, args=(queue_of_all_params, angels_per_process, img_resized, sigma_des)))

        for i in range(num_cores):
            processes[i].start()

        var_res = []
        for i in range(num_cores):
            list_all_par = queue_of_all_params.get(True)
            angles_for_subprocess = list_all_par[0]
            for j in range(len(angles_for_subprocess)):
                var_res.append(angles_for_subprocess[j])

        for i in range(num_cores):
            processes[i].join()

        try:
            var_res = np.array(var_res)
@@ -1772,23 +1938,45 @@ def return_deskew_slop(img_patch_org, sigma_des, main_page=False, plotter=None):
    elif abs(ang_int) > early_slope_edge and ang_int > 0:
        angels = np.linspace(90, 25, int(n_tot_angles / 2.) + 10)

        queue_of_all_params = Queue()
        processes = []
        nh = np.linspace(0, len(angels), num_cores + 1)

        for i in range(num_cores):
            angels_per_process = angels[int(nh[i]) : int(nh[i + 1])]
            processes.append(Process(target=do_image_rotation, args=(queue_of_all_params, angels_per_process, img_resized, sigma_des)))

        for i in range(num_cores):
            processes[i].start()

        var_res = []
        for i in range(num_cores):
            list_all_par = queue_of_all_params.get(True)
            angles_for_subprocess = list_all_par[0]
            for j in range(len(angles_for_subprocess)):
                var_res.append(angles_for_subprocess[j])

        for i in range(num_cores):
            processes[i].join()

        try:
            var_res = np.array(var_res)
            ang_int = angels[np.argmax(var_res)]
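The body of do_image_rotation, the target of the Process objects above, is not part of these hunks. A hedged sketch of what such a worker could look like, inferred from the removed sequential loops and from how the queue is read back; rotate_image and find_num_col_deskew are the helpers already used in this file, and the real implementation may differ:

    # Hedged sketch only: the actual do_image_rotation lives elsewhere in this module.
    # It is assumed to score each assigned angle and to put the list of scores on the
    # queue as the first element, which is how the gathering loops above read it back.
    def do_image_rotation(queue_of_all_params, angels_per_process, img_resized, sigma_des):
        var_res_per_worker = []
        for rot in angels_per_process:
            img_rot = rotate_image(img_resized, rot)   # helper used throughout this file
            img_rot[img_rot != 0] = 1                  # reduce to a binary mask
            try:
                var_spectrum = find_num_col_deskew(img_rot, sigma_des, 20.3)
            except Exception:
                var_spectrum = 0
            var_res_per_worker.append(var_spectrum)
        queue_of_all_params.put([var_res_per_worker])  # read back as list_all_par[0]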
@@ -72,7 +72,7 @@ def order_and_id_of_texts(found_polygons_text_region, found_polygons_text_region
    index_of_types_2 = index_of_types[kind_of_texts == 2]
    indexes_sorted_2 = indexes_sorted[kind_of_texts == 2]

    counter = EynollahIdCounter(region_idx=ref_point)
    for idx_textregion, _ in enumerate(found_polygons_text_region):
        id_of_texts.append(counter.next_region_id)
@@ -2,7 +2,7 @@
# pylint: disable=import-error
from pathlib import Path
import os.path

import xml.etree.ElementTree as ET
from .utils.xml import create_page_xml, xml_reading_order
from .utils.counter import EynollahIdCounter
@@ -12,6 +12,7 @@ from ocrd_models.ocrd_page import (
    CoordsType,
    PcGtsType,
    TextLineType,
    TextEquivType,
    TextRegionType,
    ImageRegionType,
    TableRegionType,
@@ -93,11 +94,13 @@ class EynollahXmlWriter():
                points_co += ' '
            coords.set_points(points_co[:-1])

    def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter, ocr_all_textlines_textregion):
        self.logger.debug('enter serialize_lines_in_region')
        for j in range(len(all_found_textline_polygons[region_idx])):
            coords = CoordsType()
            textline = TextLineType(id=counter.next_line_id, Coords=coords)
            if ocr_all_textlines_textregion:
                textline.set_TextEquiv([TextEquivType(Unicode=ocr_all_textlines_textregion[j])])
            text_region.add_TextLine(textline)
            region_bboxes = all_box_coord[region_idx]
            points_co = ''
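A hypothetical call against the extended signature; passing None for the new argument keeps the previous behaviour, while a list of per-line strings adds a TextEquiv to each TextLine:

    # Hypothetical call; every variable is assumed to be prepared by the caller as before.
    writer.serialize_lines_in_region(
        textregion, all_found_textline_polygons, mm, page_coord,
        all_box_coord, slopes, counter,
        ocr_all_textlines_textregion=None)  # or e.g. ["first line text", "second line text"]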
@@ -133,6 +136,29 @@ class EynollahXmlWriter():
                    points_co += str(int((contour_textline[0][1] + region_bboxes[0] + page_coord[0]) / self.scale_y))
                points_co += ' '
            coords.set_points(points_co[:-1])

    def serialize_lines_in_dropcapital(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter, ocr_all_textlines_textregion):
        self.logger.debug('enter serialize_lines_in_dropcapital')
        for j in range(1):
            coords = CoordsType()
            textline = TextLineType(id=counter.next_line_id, Coords=coords)
            if ocr_all_textlines_textregion:
                textline.set_TextEquiv([TextEquivType(Unicode=ocr_all_textlines_textregion[j])])
            text_region.add_TextLine(textline)
            #region_bboxes = all_box_coord[region_idx]
            points_co = ''
            for idx_contour_textline, contour_textline in enumerate(all_found_textline_polygons[j]):
                if len(contour_textline) == 2:
                    points_co += str(int((contour_textline[0] + page_coord[2]) / self.scale_x))
                    points_co += ','
                    points_co += str(int((contour_textline[1] + page_coord[0]) / self.scale_y))
                else:
                    points_co += str(int((contour_textline[0][0] + page_coord[2]) / self.scale_x))
                    points_co += ','
                    points_co += str(int((contour_textline[0][1] + page_coord[0]) / self.scale_y))
                points_co += ' '
            coords.set_points(points_co[:-1])

    def write_pagexml(self, pcgts):
        out_fname = os.path.join(self.dir_out, self.image_filename_stem) + ".xml"
@@ -140,7 +166,7 @@ class EynollahXmlWriter():
        with open(out_fname, 'w') as f:
            f.write(to_xml(pcgts))

    def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines):
        self.logger.debug('enter build_pagexml_no_full_layout')

        # create the file structure
@@ -159,7 +185,11 @@ class EynollahXmlWriter():
                                        Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord)),
                                        )
            page.add_TextRegion(textregion)
            if ocr_all_textlines:
                ocr_textlines = ocr_all_textlines[mm]
            else:
                ocr_textlines = None
            self.serialize_lines_in_region(textregion, all_found_textline_polygons, mm, page_coord, all_box_coord, slopes, counter, ocr_textlines)

        for mm in range(len(found_polygons_marginals)):
            marginal = TextRegionType(id=counter.next_region_id, type_='marginalia',
@@ -209,7 +239,7 @@ class EynollahXmlWriter():
        return pcgts

    def build_pagexml_full_layout(self, found_polygons_text_region, found_polygons_text_region_h, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h, found_polygons_text_region_img, found_polygons_tables, found_polygons_drop_capitals, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_h, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, ocr_all_textlines):
        self.logger.debug('enter build_pagexml_full_layout')

        # create the file structure
@@ -226,14 +256,24 @@ class EynollahXmlWriter():
            textregion = TextRegionType(id=counter.next_region_id, type_='paragraph',
                                        Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord)))
            page.add_TextRegion(textregion)

            if ocr_all_textlines:
                ocr_textlines = ocr_all_textlines[mm]
            else:
                ocr_textlines = None
            self.serialize_lines_in_region(textregion, all_found_textline_polygons, mm, page_coord, all_box_coord, slopes, counter, ocr_textlines)

        self.logger.debug('len(found_polygons_text_region_h) %s', len(found_polygons_text_region_h))
        for mm in range(len(found_polygons_text_region_h)):
            textregion = TextRegionType(id=counter.next_region_id, type_='header',
                                        Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region_h[mm], page_coord)))
            page.add_TextRegion(textregion)

            if ocr_all_textlines:
                ocr_textlines = ocr_all_textlines[mm]
            else:
                ocr_textlines = None
            self.serialize_lines_in_region(textregion, all_found_textline_polygons_h, mm, page_coord, all_box_coord_h, slopes_h, counter, ocr_textlines)

        for mm in range(len(found_polygons_marginals)):
            marginal = TextRegionType(id=counter.next_region_id, type_='marginalia',
@@ -242,8 +282,12 @@ class EynollahXmlWriter():
            self.serialize_lines_in_marginal(marginal, all_found_textline_polygons_marginals, mm, page_coord, all_box_coord_marginals, slopes_marginals, counter)

        for mm in range(len(found_polygons_drop_capitals)):
            dropcapital = TextRegionType(id=counter.next_region_id, type_='drop-capital',
                                         Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_drop_capitals[mm], page_coord)))
            page.add_TextRegion(dropcapital)
            all_box_coord_drop = None
            slopes_drop = None
            self.serialize_lines_in_dropcapital(dropcapital, [found_polygons_drop_capitals[mm]], mm, page_coord, all_box_coord_drop, slopes_drop, counter, ocr_all_textlines_textregion=None)

        for mm in range(len(found_polygons_text_region_img)):
            page.add_ImageRegion(ImageRegionType(id=counter.next_region_id, Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region_img[mm], page_coord))))
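The new ocr_all_textlines plumbing ends up as TextEquiv/Unicode elements on the serialized TextLine elements. A hypothetical spot-check of a written PAGE-XML file; the output path and the 2019-07-15 PAGE namespace are assumptions, not taken from this diff:

    import xml.etree.ElementTree as ET

    NS = {"pc": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15"}  # assumed version
    tree = ET.parse("out/example.xml")  # hypothetical output path
    for line in tree.iterfind(".//pc:TextLine", NS):
        unicode_el = line.find("./pc:TextEquiv/pc:Unicode", NS)
        if unicode_el is not None and unicode_el.text:
            print(line.get("id"), unicode_el.text)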