mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-08-29 11:59:55 +02:00
I have tried to address issues #163 and #161. The changes also improve marginal detection and enhance the isolation of headers.
This commit is contained in:
parent
5d447abcc4
commit
02a679a145
3 changed files with 275 additions and 34 deletions
|
@ -4,4 +4,5 @@ numpy <1.24.0
|
||||||
scikit-learn >= 0.23.2
|
scikit-learn >= 0.23.2
|
||||||
tensorflow < 2.13
|
tensorflow < 2.13
|
||||||
numba <= 0.58.1
|
numba <= 0.58.1
|
||||||
|
scikit-image
|
||||||
loky
|
loky
|
||||||
|
|
|
@ -235,6 +235,16 @@ def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out)
|
||||||
"-ncl",
|
"-ncl",
|
||||||
help="upper limit of columns in document image",
|
help="upper limit of columns in document image",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--threshold_art_class_layout",
|
||||||
|
"-tharl",
|
||||||
|
help="threshold of artifical class in the case of layout detection",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--threshold_art_class_textline",
|
||||||
|
"-thart",
|
||||||
|
help="threshold of artifical class in the case of textline detection",
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--skip_layout_and_reading_order",
|
"--skip_layout_and_reading_order",
|
||||||
"-slro/-noslro",
|
"-slro/-noslro",
|
||||||
|
@ -248,7 +258,7 @@ def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out)
|
||||||
help="Override log level globally to this",
|
help="Override log level globally to this",
|
||||||
)
|
)
|
||||||
|
|
||||||
def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, num_col_upper, num_col_lower, skip_layout_and_reading_order, ignore_page_extraction, log_level):
|
def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_deskewed, save_all, extract_only_images, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, num_col_upper, num_col_lower, threshold_art_class_textline, threshold_art_class_layout, skip_layout_and_reading_order, ignore_page_extraction, log_level):
|
||||||
initLogging()
|
initLogging()
|
||||||
if log_level:
|
if log_level:
|
||||||
getLogger('eynollah').setLevel(getLevelName(log_level))
|
getLogger('eynollah').setLevel(getLevelName(log_level))
|
||||||
|
@ -298,6 +308,8 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
|
||||||
num_col_upper=num_col_upper,
|
num_col_upper=num_col_upper,
|
||||||
num_col_lower=num_col_lower,
|
num_col_lower=num_col_lower,
|
||||||
skip_layout_and_reading_order=skip_layout_and_reading_order,
|
skip_layout_and_reading_order=skip_layout_and_reading_order,
|
||||||
|
threshold_art_class_textline=threshold_art_class_textline,
|
||||||
|
threshold_art_class_layout=threshold_art_class_layout,
|
||||||
)
|
)
|
||||||
if dir_in:
|
if dir_in:
|
||||||
eynollah.run(dir_in=dir_in, overwrite=overwrite)
|
eynollah.run(dir_in=dir_in, overwrite=overwrite)
|
||||||
|
|
|
@ -30,7 +30,7 @@ import numpy as np
|
||||||
from scipy.signal import find_peaks
|
from scipy.signal import find_peaks
|
||||||
from scipy.ndimage import gaussian_filter1d
|
from scipy.ndimage import gaussian_filter1d
|
||||||
from numba import cuda
|
from numba import cuda
|
||||||
|
from skimage.morphology import skeletonize
|
||||||
from ocrd import OcrdPage
|
from ocrd import OcrdPage
|
||||||
from ocrd_utils import getLogger, tf_disable_interactive_logs
|
from ocrd_utils import getLogger, tf_disable_interactive_logs
|
||||||
|
|
||||||
|
@ -200,6 +200,8 @@ class Eynollah:
|
||||||
do_ocr : bool = False,
|
do_ocr : bool = False,
|
||||||
num_col_upper : Optional[int] = None,
|
num_col_upper : Optional[int] = None,
|
||||||
num_col_lower : Optional[int] = None,
|
num_col_lower : Optional[int] = None,
|
||||||
|
threshold_art_class_layout: Optional[float] = None,
|
||||||
|
threshold_art_class_textline: Optional[float] = None,
|
||||||
skip_layout_and_reading_order : bool = False,
|
skip_layout_and_reading_order : bool = False,
|
||||||
logger : Optional[Logger] = None,
|
logger : Optional[Logger] = None,
|
||||||
):
|
):
|
||||||
|
@ -237,6 +239,17 @@ class Eynollah:
|
||||||
self.num_col_lower = int(num_col_lower)
|
self.num_col_lower = int(num_col_lower)
|
||||||
else:
|
else:
|
||||||
self.num_col_lower = num_col_lower
|
self.num_col_lower = num_col_lower
|
||||||
|
|
||||||
|
if threshold_art_class_layout:
|
||||||
|
self.threshold_art_class_layout = float(threshold_art_class_layout)
|
||||||
|
else:
|
||||||
|
self.threshold_art_class_layout = 0.1
|
||||||
|
|
||||||
|
if threshold_art_class_textline:
|
||||||
|
self.threshold_art_class_textline = float(threshold_art_class_textline)
|
||||||
|
else:
|
||||||
|
self.threshold_art_class_textline = 0.1
|
||||||
|
|
||||||
self.logger = logger if logger else getLogger('eynollah')
|
self.logger = logger if logger else getLogger('eynollah')
|
||||||
# for parallelization of CPU-intensive tasks:
|
# for parallelization of CPU-intensive tasks:
|
||||||
self.executor = ProcessPoolExecutor(max_workers=cpu_count(), timeout=1200)
|
self.executor = ProcessPoolExecutor(max_workers=cpu_count(), timeout=1200)
|
||||||
|
@ -784,7 +797,7 @@ class Eynollah:
|
||||||
self, patches, img, model,
|
self, patches, img, model,
|
||||||
n_batch_inference=1, marginal_of_patch_percent=0.1,
|
n_batch_inference=1, marginal_of_patch_percent=0.1,
|
||||||
thresholding_for_some_classes_in_light_version=False,
|
thresholding_for_some_classes_in_light_version=False,
|
||||||
thresholding_for_artificial_class_in_light_version=False):
|
thresholding_for_artificial_class_in_light_version=False, threshold_art_class_textline=0.1):
|
||||||
|
|
||||||
self.logger.debug("enter do_prediction")
|
self.logger.debug("enter do_prediction")
|
||||||
img_height_model = model.layers[-1].output_shape[1]
|
img_height_model = model.layers[-1].output_shape[1]
|
||||||
|
@ -802,10 +815,13 @@ class Eynollah:
|
||||||
if thresholding_for_artificial_class_in_light_version:
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
seg_art = label_p_pred[0,:,:,2]
|
seg_art = label_p_pred[0,:,:,2]
|
||||||
|
|
||||||
seg_art[seg_art<0.2] = 0
|
seg_art[seg_art<threshold_art_class_textline] = 0
|
||||||
seg_art[seg_art>0] =1
|
seg_art[seg_art>0] =1
|
||||||
|
|
||||||
|
skeleton_art = skeletonize(seg_art)
|
||||||
|
skeleton_art = skeleton_art*1
|
||||||
|
|
||||||
seg[seg_art==1]=2
|
seg[skeleton_art==1]=2
|
||||||
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||||
prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8)
|
prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8)
|
||||||
return prediction_true
|
return prediction_true
|
||||||
|
@ -896,14 +912,17 @@ class Eynollah:
|
||||||
if thresholding_for_artificial_class_in_light_version:
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
seg_art = label_p_pred[:,:,:,2]
|
seg_art = label_p_pred[:,:,:,2]
|
||||||
|
|
||||||
seg_art[seg_art<0.2] = 0
|
seg_art[seg_art<threshold_art_class_textline] = 0
|
||||||
seg_art[seg_art>0] =1
|
seg_art[seg_art>0] =1
|
||||||
|
|
||||||
seg[seg_art==1]=2
|
##seg[seg_art==1]=2
|
||||||
|
|
||||||
indexer_inside_batch = 0
|
indexer_inside_batch = 0
|
||||||
for i_batch, j_batch in zip(list_i_s, list_j_s):
|
for i_batch, j_batch in zip(list_i_s, list_j_s):
|
||||||
seg_in = seg[indexer_inside_batch]
|
seg_in = seg[indexer_inside_batch]
|
||||||
|
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
seg_in_art = seg_art[indexer_inside_batch]
|
||||||
|
|
||||||
index_y_u_in = list_y_u[indexer_inside_batch]
|
index_y_u_in = list_y_u[indexer_inside_batch]
|
||||||
index_y_d_in = list_y_d[indexer_inside_batch]
|
index_y_d_in = list_y_d[indexer_inside_batch]
|
||||||
|
@ -917,54 +936,107 @@ class Eynollah:
|
||||||
seg_in[0:-margin or None,
|
seg_in[0:-margin or None,
|
||||||
0:-margin or None,
|
0:-margin or None,
|
||||||
np.newaxis]
|
np.newaxis]
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
|
index_x_d_in + 0:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[0:-margin or None,
|
||||||
|
0:-margin or None]
|
||||||
|
|
||||||
elif i_batch == nxf - 1 and j_batch == nyf - 1:
|
elif i_batch == nxf - 1 and j_batch == nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
index_x_d_in + margin:index_x_u_in - 0] = \
|
index_x_d_in + margin:index_x_u_in - 0] = \
|
||||||
seg_in[margin:,
|
seg_in[margin:,
|
||||||
margin:,
|
margin:,
|
||||||
np.newaxis]
|
np.newaxis]
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
|
index_x_d_in + margin:index_x_u_in - 0, 1] = \
|
||||||
|
seg_in_art[margin:,
|
||||||
|
margin:]
|
||||||
|
|
||||||
elif i_batch == 0 and j_batch == nyf - 1:
|
elif i_batch == 0 and j_batch == nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
index_x_d_in + 0:index_x_u_in - margin] = \
|
index_x_d_in + 0:index_x_u_in - margin] = \
|
||||||
seg_in[margin:,
|
seg_in[margin:,
|
||||||
0:-margin or None,
|
0:-margin or None,
|
||||||
np.newaxis]
|
np.newaxis]
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
|
index_x_d_in + 0:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[margin:,
|
||||||
|
0:-margin or None]
|
||||||
|
|
||||||
elif i_batch == nxf - 1 and j_batch == 0:
|
elif i_batch == nxf - 1 and j_batch == 0:
|
||||||
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
index_x_d_in + margin:index_x_u_in - 0] = \
|
index_x_d_in + margin:index_x_u_in - 0] = \
|
||||||
seg_in[0:-margin or None,
|
seg_in[0:-margin or None,
|
||||||
margin:,
|
margin:,
|
||||||
np.newaxis]
|
np.newaxis]
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
|
index_x_d_in + margin:index_x_u_in - 0, 1] = \
|
||||||
|
seg_in_art[0:-margin or None,
|
||||||
|
margin:]
|
||||||
|
|
||||||
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
|
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
index_x_d_in + 0:index_x_u_in - margin] = \
|
index_x_d_in + 0:index_x_u_in - margin] = \
|
||||||
seg_in[margin:-margin or None,
|
seg_in[margin:-margin or None,
|
||||||
0:-margin or None,
|
0:-margin or None,
|
||||||
np.newaxis]
|
np.newaxis]
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
|
index_x_d_in + 0:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[margin:-margin or None,
|
||||||
|
0:-margin or None]
|
||||||
|
|
||||||
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
|
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
index_x_d_in + margin:index_x_u_in - 0] = \
|
index_x_d_in + margin:index_x_u_in - 0] = \
|
||||||
seg_in[margin:-margin or None,
|
seg_in[margin:-margin or None,
|
||||||
margin:,
|
margin:,
|
||||||
np.newaxis]
|
np.newaxis]
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
|
index_x_d_in + margin:index_x_u_in - 0, 1] = \
|
||||||
|
seg_in_art[margin:-margin or None,
|
||||||
|
margin:]
|
||||||
|
|
||||||
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
|
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
|
||||||
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
index_x_d_in + margin:index_x_u_in - margin] = \
|
index_x_d_in + margin:index_x_u_in - margin] = \
|
||||||
seg_in[0:-margin or None,
|
seg_in[0:-margin or None,
|
||||||
margin:-margin or None,
|
margin:-margin or None,
|
||||||
np.newaxis]
|
np.newaxis]
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
|
index_x_d_in + margin:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[0:-margin or None,
|
||||||
|
margin:-margin or None]
|
||||||
|
|
||||||
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
|
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
index_x_d_in + margin:index_x_u_in - margin] = \
|
index_x_d_in + margin:index_x_u_in - margin] = \
|
||||||
seg_in[margin:,
|
seg_in[margin:,
|
||||||
margin:-margin or None,
|
margin:-margin or None,
|
||||||
np.newaxis]
|
np.newaxis]
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
|
index_x_d_in + margin:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[margin:,
|
||||||
|
margin:-margin or None]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
index_x_d_in + margin:index_x_u_in - margin] = \
|
index_x_d_in + margin:index_x_u_in - margin] = \
|
||||||
seg_in[margin:-margin or None,
|
seg_in[margin:-margin or None,
|
||||||
margin:-margin or None,
|
margin:-margin or None,
|
||||||
np.newaxis]
|
np.newaxis]
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
|
index_x_d_in + margin:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[margin:-margin or None,
|
||||||
|
margin:-margin or None]
|
||||||
indexer_inside_batch += 1
|
indexer_inside_batch += 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -979,6 +1051,19 @@ class Eynollah:
|
||||||
img_patch[:] = 0
|
img_patch[:] = 0
|
||||||
|
|
||||||
prediction_true = prediction_true.astype(np.uint8)
|
prediction_true = prediction_true.astype(np.uint8)
|
||||||
|
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
kernel_min = np.ones((3, 3), np.uint8)
|
||||||
|
prediction_true[:,:,0][prediction_true[:,:,0]==2] = 0
|
||||||
|
|
||||||
|
skeleton_art = skeletonize(prediction_true[:,:,1])
|
||||||
|
skeleton_art = skeleton_art*1
|
||||||
|
|
||||||
|
skeleton_art = skeleton_art.astype('uint8')
|
||||||
|
|
||||||
|
skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1)
|
||||||
|
|
||||||
|
prediction_true[:,:,0][skeleton_art==1]=2
|
||||||
#del model
|
#del model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
return prediction_true
|
return prediction_true
|
||||||
|
@ -1117,7 +1202,7 @@ class Eynollah:
|
||||||
self, patches, img, model,
|
self, patches, img, model,
|
||||||
n_batch_inference=1, marginal_of_patch_percent=0.1,
|
n_batch_inference=1, marginal_of_patch_percent=0.1,
|
||||||
thresholding_for_some_classes_in_light_version=False,
|
thresholding_for_some_classes_in_light_version=False,
|
||||||
thresholding_for_artificial_class_in_light_version=False):
|
thresholding_for_artificial_class_in_light_version=False, threshold_art_class_textline=0.1, threshold_art_class_layout=0.1):
|
||||||
|
|
||||||
self.logger.debug("enter do_prediction_new_concept")
|
self.logger.debug("enter do_prediction_new_concept")
|
||||||
img_height_model = model.layers[-1].output_shape[1]
|
img_height_model = model.layers[-1].output_shape[1]
|
||||||
|
@ -1132,19 +1217,28 @@ class Eynollah:
|
||||||
label_p_pred = model.predict(img[np.newaxis], verbose=0)
|
label_p_pred = model.predict(img[np.newaxis], verbose=0)
|
||||||
seg = np.argmax(label_p_pred, axis=3)[0]
|
seg = np.argmax(label_p_pred, axis=3)[0]
|
||||||
|
|
||||||
if thresholding_for_artificial_class_in_light_version:
|
|
||||||
#seg_text = label_p_pred[0,:,:,1]
|
|
||||||
#seg_text[seg_text<0.2] =0
|
|
||||||
#seg_text[seg_text>0] =1
|
|
||||||
#seg[seg_text==1]=1
|
|
||||||
|
|
||||||
seg_art = label_p_pred[0,:,:,4]
|
|
||||||
seg_art[seg_art<0.2] =0
|
|
||||||
seg_art[seg_art>0] =1
|
|
||||||
seg[seg_art==1]=4
|
|
||||||
|
|
||||||
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||||
prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8)
|
prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8)
|
||||||
|
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
kernel_min = np.ones((3, 3), np.uint8)
|
||||||
|
seg_art = label_p_pred[0,:,:,4]
|
||||||
|
seg_art[seg_art<threshold_art_class_layout] =0
|
||||||
|
seg_art[seg_art>0] =1
|
||||||
|
#seg[seg_art==1]=4
|
||||||
|
seg_art = resize_image(seg_art, img_h_page, img_w_page).astype(np.uint8)
|
||||||
|
|
||||||
|
prediction_true[:,:,0][prediction_true[:,:,0]==4] = 0
|
||||||
|
|
||||||
|
skeleton_art = skeletonize(seg_art)
|
||||||
|
skeleton_art = skeleton_art*1
|
||||||
|
|
||||||
|
skeleton_art = skeleton_art.astype('uint8')
|
||||||
|
|
||||||
|
skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1)
|
||||||
|
|
||||||
|
prediction_true[:,:,0][skeleton_art==1] = 4
|
||||||
|
|
||||||
return prediction_true , resize_image(label_p_pred[0, :, :, 1] , img_h_page, img_w_page)
|
return prediction_true , resize_image(label_p_pred[0, :, :, 1] , img_h_page, img_w_page)
|
||||||
|
|
||||||
if img.shape[0] < img_height_model:
|
if img.shape[0] < img_height_model:
|
||||||
|
@ -1217,26 +1311,29 @@ class Eynollah:
|
||||||
|
|
||||||
if thresholding_for_some_classes_in_light_version:
|
if thresholding_for_some_classes_in_light_version:
|
||||||
seg_art = label_p_pred[:,:,:,4]
|
seg_art = label_p_pred[:,:,:,4]
|
||||||
seg_art[seg_art<0.2] =0
|
seg_art[seg_art<threshold_art_class_layout] =0
|
||||||
seg_art[seg_art>0] =1
|
seg_art[seg_art>0] =1
|
||||||
|
|
||||||
seg_line = label_p_pred[:,:,:,3]
|
seg_line = label_p_pred[:,:,:,3]
|
||||||
seg_line[seg_line>0.5] =1#seg_line[seg_line>0.1] =1
|
seg_line[seg_line>0.5] =1#seg_line[seg_line>0.1] =1
|
||||||
seg_line[seg_line<1] =0
|
seg_line[seg_line<1] =0
|
||||||
|
|
||||||
seg[seg_art==1]=4
|
##seg[seg_art==1]=4
|
||||||
seg[(seg_line==1) & (seg==0)]=3
|
seg[(seg_line==1) & (seg==0)]=3
|
||||||
if thresholding_for_artificial_class_in_light_version:
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
seg_art = label_p_pred[:,:,:,2]
|
seg_art = label_p_pred[:,:,:,2]
|
||||||
|
|
||||||
seg_art[seg_art<0.2] = 0
|
seg_art[seg_art<threshold_art_class_textline] = 0
|
||||||
seg_art[seg_art>0] =1
|
seg_art[seg_art>0] =1
|
||||||
|
|
||||||
seg[seg_art==1]=2
|
##seg[seg_art==1]=2
|
||||||
|
|
||||||
indexer_inside_batch = 0
|
indexer_inside_batch = 0
|
||||||
for i_batch, j_batch in zip(list_i_s, list_j_s):
|
for i_batch, j_batch in zip(list_i_s, list_j_s):
|
||||||
seg_in = seg[indexer_inside_batch]
|
seg_in = seg[indexer_inside_batch]
|
||||||
|
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
seg_in_art = seg_art[indexer_inside_batch]
|
||||||
|
|
||||||
index_y_u_in = list_y_u[indexer_inside_batch]
|
index_y_u_in = list_y_u[indexer_inside_batch]
|
||||||
index_y_d_in = list_y_d[indexer_inside_batch]
|
index_y_d_in = list_y_d[indexer_inside_batch]
|
||||||
|
@ -1255,6 +1352,12 @@ class Eynollah:
|
||||||
label_p_pred[0, 0:-margin or None,
|
label_p_pred[0, 0:-margin or None,
|
||||||
0:-margin or None,
|
0:-margin or None,
|
||||||
1]
|
1]
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
|
index_x_d_in + 0:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[0:-margin or None,
|
||||||
|
0:-margin or None]
|
||||||
|
|
||||||
elif i_batch == nxf - 1 and j_batch == nyf - 1:
|
elif i_batch == nxf - 1 and j_batch == nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
index_x_d_in + margin:index_x_u_in - 0] = \
|
index_x_d_in + margin:index_x_u_in - 0] = \
|
||||||
|
@ -1266,6 +1369,12 @@ class Eynollah:
|
||||||
label_p_pred[0, margin:,
|
label_p_pred[0, margin:,
|
||||||
margin:,
|
margin:,
|
||||||
1]
|
1]
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
|
index_x_d_in + margin:index_x_u_in - 0, 1] = \
|
||||||
|
seg_in_art[margin:,
|
||||||
|
margin:]
|
||||||
|
|
||||||
elif i_batch == 0 and j_batch == nyf - 1:
|
elif i_batch == 0 and j_batch == nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
index_x_d_in + 0:index_x_u_in - margin] = \
|
index_x_d_in + 0:index_x_u_in - margin] = \
|
||||||
|
@ -1277,6 +1386,13 @@ class Eynollah:
|
||||||
label_p_pred[0, margin:,
|
label_p_pred[0, margin:,
|
||||||
0:-margin or None,
|
0:-margin or None,
|
||||||
1]
|
1]
|
||||||
|
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
|
index_x_d_in + 0:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[margin:,
|
||||||
|
0:-margin or None]
|
||||||
|
|
||||||
elif i_batch == nxf - 1 and j_batch == 0:
|
elif i_batch == nxf - 1 and j_batch == 0:
|
||||||
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
index_x_d_in + margin:index_x_u_in - 0] = \
|
index_x_d_in + margin:index_x_u_in - 0] = \
|
||||||
|
@ -1288,6 +1404,12 @@ class Eynollah:
|
||||||
label_p_pred[0, 0:-margin or None,
|
label_p_pred[0, 0:-margin or None,
|
||||||
margin:,
|
margin:,
|
||||||
1]
|
1]
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
|
index_x_d_in + margin:index_x_u_in - 0, 1] = \
|
||||||
|
seg_in_art[0:-margin or None,
|
||||||
|
margin:]
|
||||||
|
|
||||||
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
|
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
index_x_d_in + 0:index_x_u_in - margin] = \
|
index_x_d_in + 0:index_x_u_in - margin] = \
|
||||||
|
@ -1299,6 +1421,11 @@ class Eynollah:
|
||||||
label_p_pred[0, margin:-margin or None,
|
label_p_pred[0, margin:-margin or None,
|
||||||
0:-margin or None,
|
0:-margin or None,
|
||||||
1]
|
1]
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
|
index_x_d_in + 0:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[margin:-margin or None,
|
||||||
|
0:-margin or None]
|
||||||
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
|
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
index_x_d_in + margin:index_x_u_in - 0] = \
|
index_x_d_in + margin:index_x_u_in - 0] = \
|
||||||
|
@ -1310,6 +1437,11 @@ class Eynollah:
|
||||||
label_p_pred[0, margin:-margin or None,
|
label_p_pred[0, margin:-margin or None,
|
||||||
margin:,
|
margin:,
|
||||||
1]
|
1]
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
|
index_x_d_in + margin:index_x_u_in - 0, 1] = \
|
||||||
|
seg_in_art[margin:-margin or None,
|
||||||
|
margin:]
|
||||||
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
|
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
|
||||||
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
index_x_d_in + margin:index_x_u_in - margin] = \
|
index_x_d_in + margin:index_x_u_in - margin] = \
|
||||||
|
@ -1321,6 +1453,11 @@ class Eynollah:
|
||||||
label_p_pred[0, 0:-margin or None,
|
label_p_pred[0, 0:-margin or None,
|
||||||
margin:-margin or None,
|
margin:-margin or None,
|
||||||
1]
|
1]
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
|
||||||
|
index_x_d_in + margin:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[0:-margin or None,
|
||||||
|
margin:-margin or None]
|
||||||
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
|
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
index_x_d_in + margin:index_x_u_in - margin] = \
|
index_x_d_in + margin:index_x_u_in - margin] = \
|
||||||
|
@ -1332,6 +1469,11 @@ class Eynollah:
|
||||||
label_p_pred[0, margin:,
|
label_p_pred[0, margin:,
|
||||||
margin:-margin or None,
|
margin:-margin or None,
|
||||||
1]
|
1]
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
|
||||||
|
index_x_d_in + margin:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[margin:,
|
||||||
|
margin:-margin or None]
|
||||||
else:
|
else:
|
||||||
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
index_x_d_in + margin:index_x_u_in - margin] = \
|
index_x_d_in + margin:index_x_u_in - margin] = \
|
||||||
|
@ -1343,6 +1485,11 @@ class Eynollah:
|
||||||
label_p_pred[0, margin:-margin or None,
|
label_p_pred[0, margin:-margin or None,
|
||||||
margin:-margin or None,
|
margin:-margin or None,
|
||||||
1]
|
1]
|
||||||
|
if thresholding_for_artificial_class_in_light_version or thresholding_for_some_classes_in_light_version:
|
||||||
|
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
|
||||||
|
index_x_d_in + margin:index_x_u_in - margin, 1] = \
|
||||||
|
seg_in_art[margin:-margin or None,
|
||||||
|
margin:-margin or None]
|
||||||
indexer_inside_batch += 1
|
indexer_inside_batch += 1
|
||||||
|
|
||||||
list_i_s = []
|
list_i_s = []
|
||||||
|
@ -1356,6 +1503,32 @@ class Eynollah:
|
||||||
img_patch[:] = 0
|
img_patch[:] = 0
|
||||||
|
|
||||||
prediction_true = prediction_true.astype(np.uint8)
|
prediction_true = prediction_true.astype(np.uint8)
|
||||||
|
|
||||||
|
if thresholding_for_artificial_class_in_light_version:
|
||||||
|
kernel_min = np.ones((3, 3), np.uint8)
|
||||||
|
prediction_true[:,:,0][prediction_true[:,:,0]==2] = 0
|
||||||
|
|
||||||
|
skeleton_art = skeletonize(prediction_true[:,:,1])
|
||||||
|
skeleton_art = skeleton_art*1
|
||||||
|
|
||||||
|
skeleton_art = skeleton_art.astype('uint8')
|
||||||
|
|
||||||
|
skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1)
|
||||||
|
|
||||||
|
prediction_true[:,:,0][skeleton_art==1]=2
|
||||||
|
|
||||||
|
if thresholding_for_some_classes_in_light_version:
|
||||||
|
kernel_min = np.ones((3, 3), np.uint8)
|
||||||
|
prediction_true[:,:,0][prediction_true[:,:,0]==4] = 0
|
||||||
|
|
||||||
|
skeleton_art = skeletonize(prediction_true[:,:,1])
|
||||||
|
skeleton_art = skeleton_art*1
|
||||||
|
|
||||||
|
skeleton_art = skeleton_art.astype('uint8')
|
||||||
|
|
||||||
|
skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1)
|
||||||
|
|
||||||
|
prediction_true[:,:,0][skeleton_art==1]=4
|
||||||
gc.collect()
|
gc.collect()
|
||||||
return prediction_true, confidence_matrix
|
return prediction_true, confidence_matrix
|
||||||
|
|
||||||
|
@ -1608,7 +1781,7 @@ class Eynollah:
|
||||||
prediction_textline = self.do_prediction(
|
prediction_textline = self.do_prediction(
|
||||||
use_patches, img, self.model_textline,
|
use_patches, img, self.model_textline,
|
||||||
marginal_of_patch_percent=0.15, n_batch_inference=3,
|
marginal_of_patch_percent=0.15, n_batch_inference=3,
|
||||||
thresholding_for_artificial_class_in_light_version=self.textline_light)
|
thresholding_for_artificial_class_in_light_version=self.textline_light, threshold_art_class_textline=self.threshold_art_class_textline)
|
||||||
#if not self.textline_light:
|
#if not self.textline_light:
|
||||||
#if num_col_classifier==1:
|
#if num_col_classifier==1:
|
||||||
#prediction_textline_nopatch = self.do_prediction(False, img, self.model_textline)
|
#prediction_textline_nopatch = self.do_prediction(False, img, self.model_textline)
|
||||||
|
@ -1622,7 +1795,55 @@ class Eynollah:
|
||||||
textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8')
|
textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8')
|
||||||
#textline_mask_tot_ea_art = cv2.dilate(textline_mask_tot_ea_art, KERNEL, iterations=1)
|
#textline_mask_tot_ea_art = cv2.dilate(textline_mask_tot_ea_art, KERNEL, iterations=1)
|
||||||
prediction_textline[:,:][textline_mask_tot_ea_art[:,:]==1]=2
|
prediction_textline[:,:][textline_mask_tot_ea_art[:,:]==1]=2
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8')
|
||||||
|
hor_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (8, 1))
|
||||||
|
|
||||||
|
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
|
||||||
|
##cv2.imwrite('textline_mask_tot_ea_art.png', textline_mask_tot_ea_art)
|
||||||
|
textline_mask_tot_ea_art = cv2.dilate(textline_mask_tot_ea_art, hor_kernel, iterations=1)
|
||||||
|
|
||||||
|
###cv2.imwrite('dil_textline_mask_tot_ea_art.png', dil_textline_mask_tot_ea_art)
|
||||||
|
|
||||||
|
textline_mask_tot_ea_art = textline_mask_tot_ea_art.astype('uint8')
|
||||||
|
|
||||||
|
#print(np.shape(dil_textline_mask_tot_ea_art), np.unique(dil_textline_mask_tot_ea_art), 'dil_textline_mask_tot_ea_art')
|
||||||
|
tsk = time.time()
|
||||||
|
skeleton_art_textline = skeletonize(textline_mask_tot_ea_art[:,:,0])
|
||||||
|
|
||||||
|
skeleton_art_textline = skeleton_art_textline*1
|
||||||
|
|
||||||
|
skeleton_art_textline = skeleton_art_textline.astype('uint8')
|
||||||
|
|
||||||
|
skeleton_art_textline = cv2.dilate(skeleton_art_textline, kernel, iterations=1)
|
||||||
|
|
||||||
|
#print(np.unique(skeleton_art_textline), np.shape(skeleton_art_textline))
|
||||||
|
|
||||||
|
#print(skeleton_art_textline, np.unique(skeleton_art_textline))
|
||||||
|
|
||||||
|
#cv2.imwrite('skeleton_art_textline.png', skeleton_art_textline)
|
||||||
|
|
||||||
|
|
||||||
|
prediction_textline[:,:,0][skeleton_art_textline[:,:]==1]=2
|
||||||
|
|
||||||
|
#cv2.imwrite('prediction_textline1.png', prediction_textline[:,:,0])
|
||||||
|
|
||||||
|
##hor_kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (4, 1))
|
||||||
|
##ver_kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 3))
|
||||||
|
##textline_mask_tot_ea_main = (prediction_textline[:,:]==1)*1
|
||||||
|
##textline_mask_tot_ea_main = textline_mask_tot_ea_main.astype('uint8')
|
||||||
|
|
||||||
|
##dil_textline_mask_tot_ea_main = cv2.erode(textline_mask_tot_ea_main, ver_kernel2, iterations=1)
|
||||||
|
|
||||||
|
##dil_textline_mask_tot_ea_main = cv2.dilate(textline_mask_tot_ea_main, hor_kernel2, iterations=1)
|
||||||
|
|
||||||
|
##dil_textline_mask_tot_ea_main = cv2.dilate(textline_mask_tot_ea_main, ver_kernel2, iterations=1)
|
||||||
|
|
||||||
|
##prediction_textline[:,:][dil_textline_mask_tot_ea_main[:,:]==1]=1
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
textline_mask_tot_ea_lines = (prediction_textline[:,:]==1)*1
|
textline_mask_tot_ea_lines = (prediction_textline[:,:]==1)*1
|
||||||
textline_mask_tot_ea_lines = textline_mask_tot_ea_lines.astype('uint8')
|
textline_mask_tot_ea_lines = textline_mask_tot_ea_lines.astype('uint8')
|
||||||
if not self.textline_light:
|
if not self.textline_light:
|
||||||
|
@ -1631,10 +1852,15 @@ class Eynollah:
|
||||||
prediction_textline[:,:][textline_mask_tot_ea_lines[:,:]==1]=1
|
prediction_textline[:,:][textline_mask_tot_ea_lines[:,:]==1]=1
|
||||||
if not self.textline_light:
|
if not self.textline_light:
|
||||||
prediction_textline[:,:][old_art[:,:]==1]=2
|
prediction_textline[:,:][old_art[:,:]==1]=2
|
||||||
|
|
||||||
|
#cv2.imwrite('prediction_textline2.png', prediction_textline[:,:,0])
|
||||||
|
|
||||||
prediction_textline_longshot = self.do_prediction(False, img, self.model_textline)
|
prediction_textline_longshot = self.do_prediction(False, img, self.model_textline)
|
||||||
prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w)
|
prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w)
|
||||||
|
|
||||||
|
|
||||||
|
#cv2.imwrite('prediction_textline.png', prediction_textline[:,:,0])
|
||||||
|
#sys.exit()
|
||||||
self.logger.debug('exit textline_contours')
|
self.logger.debug('exit textline_contours')
|
||||||
return ((prediction_textline[:, :, 0]==1).astype(np.uint8),
|
return ((prediction_textline[:, :, 0]==1).astype(np.uint8),
|
||||||
(prediction_textline_longshot_true_size[:, :, 0]==1).astype(np.uint8))
|
(prediction_textline_longshot_true_size[:, :, 0]==1).astype(np.uint8))
|
||||||
|
@ -1840,7 +2066,7 @@ class Eynollah:
|
||||||
textline_mask_tot_ea = resize_image(textline_mask_tot_ea,img_height_h, img_width_h )
|
textline_mask_tot_ea = resize_image(textline_mask_tot_ea,img_height_h, img_width_h )
|
||||||
|
|
||||||
#print(self.image_org.shape)
|
#print(self.image_org.shape)
|
||||||
#cv2.imwrite('out_13.png', self.image_page_org_size)
|
#cv2.imwrite('textline.png', textline_mask_tot_ea)
|
||||||
|
|
||||||
#plt.imshwo(self.image_page_org_size)
|
#plt.imshwo(self.image_page_org_size)
|
||||||
#plt.show()
|
#plt.show()
|
||||||
|
@ -1852,13 +2078,13 @@ class Eynollah:
|
||||||
img_resized.shape[1], img_resized.shape[0], num_col_classifier)
|
img_resized.shape[1], img_resized.shape[0], num_col_classifier)
|
||||||
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
|
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
|
||||||
True, img_resized, self.model_region_1_2, n_batch_inference=1,
|
True, img_resized, self.model_region_1_2, n_batch_inference=1,
|
||||||
thresholding_for_some_classes_in_light_version=True)
|
thresholding_for_some_classes_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout)
|
||||||
else:
|
else:
|
||||||
prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3))
|
prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3))
|
||||||
confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
|
confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
|
||||||
prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept(
|
prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept(
|
||||||
False, self.image_page_org_size, self.model_region_1_2, n_batch_inference=1,
|
False, self.image_page_org_size, self.model_region_1_2, n_batch_inference=1,
|
||||||
thresholding_for_artificial_class_in_light_version=True)
|
thresholding_for_artificial_class_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout)
|
||||||
ys = slice(*self.page_coord[0:2])
|
ys = slice(*self.page_coord[0:2])
|
||||||
xs = slice(*self.page_coord[2:4])
|
xs = slice(*self.page_coord[2:4])
|
||||||
prediction_regions_org[ys, xs] = prediction_regions_page
|
prediction_regions_org[ys, xs] = prediction_regions_page
|
||||||
|
@ -1871,7 +2097,7 @@ class Eynollah:
|
||||||
img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
|
img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
|
||||||
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
|
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
|
||||||
True, img_resized, self.model_region_1_2, n_batch_inference=2,
|
True, img_resized, self.model_region_1_2, n_batch_inference=2,
|
||||||
thresholding_for_some_classes_in_light_version=True)
|
thresholding_for_some_classes_in_light_version=True, threshold_art_class_layout=self.threshold_art_class_layout)
|
||||||
###prediction_regions_org = self.do_prediction(True, img_bin, self.model_region, n_batch_inference=3, thresholding_for_some_classes_in_light_version=True)
|
###prediction_regions_org = self.do_prediction(True, img_bin, self.model_region, n_batch_inference=3, thresholding_for_some_classes_in_light_version=True)
|
||||||
#print("inside 3 ", time.time()-t_in)
|
#print("inside 3 ", time.time()-t_in)
|
||||||
#plt.imshow(prediction_regions_org[:,:,0])
|
#plt.imshow(prediction_regions_org[:,:,0])
|
||||||
|
@ -3811,7 +4037,7 @@ class Eynollah:
|
||||||
if dilation_m1<6:
|
if dilation_m1<6:
|
||||||
dilation_m1 = 6
|
dilation_m1 = 6
|
||||||
#print(dilation_m1, 'dilation_m1')
|
#print(dilation_m1, 'dilation_m1')
|
||||||
dilation_m1 = 6
|
dilation_m1 = 4#6
|
||||||
dilation_m2 = int(dilation_m1/2.) +1
|
dilation_m2 = int(dilation_m1/2.) +1
|
||||||
|
|
||||||
for i in range(len(x_differential)):
|
for i in range(len(x_differential)):
|
||||||
|
@ -4322,6 +4548,8 @@ class Eynollah:
|
||||||
cnt_clean_rot_raw, hir_on_cnt_clean_rot = return_contours_of_image(textline_mask_tot_ea)
|
cnt_clean_rot_raw, hir_on_cnt_clean_rot = return_contours_of_image(textline_mask_tot_ea)
|
||||||
all_found_textline_polygons = filter_contours_area_of_image(
|
all_found_textline_polygons = filter_contours_area_of_image(
|
||||||
textline_mask_tot_ea, cnt_clean_rot_raw, hir_on_cnt_clean_rot, max_area=1, min_area=0.00001)
|
textline_mask_tot_ea, cnt_clean_rot_raw, hir_on_cnt_clean_rot, max_area=1, min_area=0.00001)
|
||||||
|
|
||||||
|
all_found_textline_polygons = all_found_textline_polygons[::-1]
|
||||||
|
|
||||||
all_found_textline_polygons=[ all_found_textline_polygons ]
|
all_found_textline_polygons=[ all_found_textline_polygons ]
|
||||||
|
|
||||||
|
@ -4329,8 +4557,8 @@ class Eynollah:
|
||||||
all_found_textline_polygons)
|
all_found_textline_polygons)
|
||||||
all_found_textline_polygons = self.filter_contours_inside_a_bigger_one(
|
all_found_textline_polygons = self.filter_contours_inside_a_bigger_one(
|
||||||
all_found_textline_polygons, None, textline_mask_tot_ea, type_contour="textline")
|
all_found_textline_polygons, None, textline_mask_tot_ea, type_contour="textline")
|
||||||
|
|
||||||
|
|
||||||
order_text_new = [0]
|
order_text_new = [0]
|
||||||
slopes =[0]
|
slopes =[0]
|
||||||
id_of_texts_tot =['region_0001']
|
id_of_texts_tot =['region_0001']
|
||||||
|
@ -4343,7 +4571,7 @@ class Eynollah:
|
||||||
polygons_lines_xml = []
|
polygons_lines_xml = []
|
||||||
contours_tables = []
|
contours_tables = []
|
||||||
ocr_all_textlines = None
|
ocr_all_textlines = None
|
||||||
conf_contours_textregions =None
|
conf_contours_textregions =[0]
|
||||||
pcgts = self.writer.build_pagexml_no_full_layout(
|
pcgts = self.writer.build_pagexml_no_full_layout(
|
||||||
cont_page, page_coord, order_text_new, id_of_texts_tot,
|
cont_page, page_coord, order_text_new, id_of_texts_tot,
|
||||||
all_found_textline_polygons, page_coord, polygons_of_images, polygons_of_marginals,
|
all_found_textline_polygons, page_coord, polygons_of_images, polygons_of_marginals,
|
||||||
|
@ -4905,7 +5133,7 @@ class Eynollah_ocr:
|
||||||
self.b_s = int(batch_size)
|
self.b_s = int(batch_size)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.model_ocr_dir = dir_models + "/model_step_1050000_ocr"#"/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn"
|
self.model_ocr_dir = dir_models + "/model_ens_ocrcnn_125_225"#"/model_step_125000_ocr"#"/model_step_25000_ocr"#"/model_step_1050000_ocr"#"/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn"
|
||||||
model_ocr = load_model(self.model_ocr_dir , compile=False)
|
model_ocr = load_model(self.model_ocr_dir , compile=False)
|
||||||
|
|
||||||
self.prediction_model = tf.keras.models.Model(
|
self.prediction_model = tf.keras.models.Model(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue