image enhancer is integrated

This commit is contained in:
vahidrezanezhad 2025-06-01 15:53:04 +02:00
parent 928a548b70
commit cc36694dfd
3 changed files with 830 additions and 229 deletions


@@ -3,6 +3,7 @@ import click
from ocrd_utils import initLogging, getLevelName, getLogger
from eynollah.eynollah import Eynollah, Eynollah_ocr
from eynollah.sbb_binarize import SbbBinarizer
from eynollah.image_enhancer import Enhancer
@click.group()
def main():
@@ -70,6 +71,74 @@ def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out)
@main.command()
@click.option(
"--image",
"-i",
help="image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--out",
"-o",
help="directory to write output xml data",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--overwrite",
"-O",
help="overwrite (instead of skipping) if output xml exists",
is_flag=True,
)
@click.option(
"--dir_in",
"-di",
help="directory of images",
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--model",
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--num_col_upper",
"-ncu",
help="lower limit of columns in document image",
)
@click.option(
"--num_col_lower",
"-ncl",
help="upper limit of columns in document image",
)
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this",
)
def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_lower, log_level):
initLogging()
if log_level:
getLogger('enhancement').setLevel(getLevelName(log_level))
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
enhancer_object = Enhancer(
model,
logger=getLogger('enhancement'),
dir_out=out,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
)
if dir_in:
enhancer_object.run(dir_in=dir_in, overwrite=overwrite)
else:
enhancer_object.run(image_filename=image, overwrite=overwrite)
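# A minimal sketch of using the Enhancer class directly, equivalent to what this
# subcommand wires up (the paths below are placeholders):
#
#   from eynollah.image_enhancer import Enhancer
#   enhancer = Enhancer(
#       "/path/to/models",        # -m/--model: directory of models
#       dir_out="/path/to/out",   # -o/--out: enhanced images are written here as PNG
#       num_col_upper=None,       # optional column-count bounds, as in the options above
#       num_col_lower=None,
#   )
#   enhancer.run(dir_in="/path/to/images", overwrite=False)
#   # or, for a single file:
#   enhancer.run(image_filename="/path/to/page.tif", overwrite=True)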
@main.command()
@click.option(


@@ -3612,25 +3612,12 @@ class Eynollah:
inference_bs = 3
cv2.imwrite('textregions.png', text_regions_p*50)
cv2.imwrite('sep.png', (text_regions_p[:,:]==6)*255)
ver_kernel = np.ones((5, 1), dtype=np.uint8)
hor_kernel = np.ones((1, 5), dtype=np.uint8)
#separators = (text_regions_p[:,:]==6)*1
#text_regions_p[text_regions_p[:,:]==6] = 0
#separators = separators.astype('uint8')
#separators = cv2.erode(separators , hor_kernel, iterations=1)
#text_regions_p[separators[:,:]==1] = 6
#cv2.imwrite('sep_new.png', (text_regions_p[:,:]==6)*255)
min_cont_size_to_be_dilated = 10
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
cx_conts, cy_conts, x_min_conts, x_max_conts, y_min_conts, y_max_conts, _ = find_new_features_of_contours(contours_only_text_parent)
args_cont_located = np.array(range(len(contours_only_text_parent)))
@@ -3672,7 +3659,6 @@ class Eynollah:
text_regions_p_textregions_dilated = cv2.dilate(text_regions_p_textregions_dilated , ver_kernel, iterations=5)
text_regions_p_textregions_dilated[text_regions_p[:,:]>1] = 0
cv2.imwrite('text_regions_p_textregions_dilated.png', text_regions_p_textregions_dilated*255)
contours_only_dilated, hir_on_text_dilated = return_contours_of_image(text_regions_p_textregions_dilated)
contours_only_dilated = return_parent_contours(contours_only_dilated, hir_on_text_dilated)
@@ -3723,21 +3709,20 @@ class Eynollah:
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,
int(x_min_main[j]):int(x_max_main[j])] = 1
co_text_all_org = contours_only_text_parent + contours_only_text_parent_h
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
co_text_all = contours_only_dilated + contours_only_text_parent_h
else:
co_text_all = contours_only_text_parent + contours_only_text_parent_h
else:
co_text_all_org = contours_only_text_parent
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
co_text_all = contours_only_dilated
else:
co_text_all = contours_only_text_parent
if not len(co_text_all):
return [], []
print(len(co_text_all), "co_text_all")
print(len(co_text_all_org), "co_text_all_org")
labels_con = np.zeros((int(y_len /6.), int(x_len/6.), len(co_text_all)), dtype=bool)
co_text_all = [(i/6).astype(int) for i in co_text_all]
for i in range(len(co_text_all)):
@@ -3805,7 +3790,7 @@ class Eynollah:
ordered = [i[0] for i in ordered]
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
org_contours_indexes = []
for ind in range(len(ordered)):
region_with_curr_order = ordered[ind]
@@ -3823,215 +3808,6 @@ class Eynollah:
else:
region_ids = ['region_%04d' % i for i in range(len(co_text_all_org))]
return ordered, region_ids
####def return_start_and_end_of_common_text_of_textline_ocr(self, textline_image, ind_tot):
####width = np.shape(textline_image)[1]
####height = np.shape(textline_image)[0]
####common_window = int(0.2*width)
####width1 = int ( width/2. - common_window )
####width2 = int ( width/2. + common_window )
####img_sum = np.sum(textline_image[:,:,0], axis=0)
####sum_smoothed = gaussian_filter1d(img_sum, 3)
####peaks_real, _ = find_peaks(sum_smoothed, height=0)
####if len(peaks_real)>70:
####peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
####arg_sort = np.argsort(sum_smoothed[peaks_real])
####arg_sort4 =arg_sort[::-1][:4]
####peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
####argsort_sorted = np.argsort(peaks_sort_4)
####first_4_sorted = peaks_sort_4[argsort_sorted]
####y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
#####print(first_4_sorted,'first_4_sorted')
####arg_sortnew = np.argsort(y_4_sorted)
####peaks_final =np.sort( first_4_sorted[arg_sortnew][2:] )
#####plt.figure(ind_tot)
#####plt.imshow(textline_image)
#####plt.plot([peaks_final[0], peaks_final[0]], [0, height-1])
#####plt.plot([peaks_final[1], peaks_final[1]], [0, height-1])
#####plt.savefig('./'+str(ind_tot)+'.png')
####return peaks_final[0], peaks_final[1]
####else:
####pass
##def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image, ind_tot):
##width = np.shape(textline_image)[1]
##height = np.shape(textline_image)[0]
##common_window = int(0.06*width)
##width1 = int ( width/2. - common_window )
##width2 = int ( width/2. + common_window )
##img_sum = np.sum(textline_image[:,:,0], axis=0)
##sum_smoothed = gaussian_filter1d(img_sum, 3)
##peaks_real, _ = find_peaks(sum_smoothed, height=0)
##if len(peaks_real)>70:
###print(len(peaks_real), 'len(peaks_real)')
##peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
##arg_max = np.argmax(sum_smoothed[peaks_real])
##peaks_final = peaks_real[arg_max]
###plt.figure(ind_tot)
###plt.imshow(textline_image)
###plt.plot([peaks_final, peaks_final], [0, height-1])
####plt.plot([peaks_final[1], peaks_final[1]], [0, height-1])
###plt.savefig('./'+str(ind_tot)+'.png')
##return peaks_final
##else:
##return None
###def return_start_and_end_of_common_text_of_textline_ocr_new_splitted(
###self, peaks_real, sum_smoothed, start_split, end_split):
###peaks_real = peaks_real[(peaks_real<end_split) & (peaks_real>start_split)]
###arg_sort = np.argsort(sum_smoothed[peaks_real])
###arg_sort4 =arg_sort[::-1][:4]
###peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
###argsort_sorted = np.argsort(peaks_sort_4)
###first_4_sorted = peaks_sort_4[argsort_sorted]
###y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
####print(first_4_sorted,'first_4_sorted')
###arg_sortnew = np.argsort(y_4_sorted)
###peaks_final =np.sort( first_4_sorted[arg_sortnew][3:] )
###return peaks_final[0]
###def return_start_and_end_of_common_text_of_textline_ocr_new(self, textline_image, ind_tot):
###width = np.shape(textline_image)[1]
###height = np.shape(textline_image)[0]
###common_window = int(0.15*width)
###width1 = int ( width/2. - common_window )
###width2 = int ( width/2. + common_window )
###mid = int(width/2.)
###img_sum = np.sum(textline_image[:,:,0], axis=0)
###sum_smoothed = gaussian_filter1d(img_sum, 3)
###peaks_real, _ = find_peaks(sum_smoothed, height=0)
###if len(peaks_real)>70:
###peak_start = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(
###peaks_real, sum_smoothed, width1, mid+2)
###peak_end = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(
###peaks_real, sum_smoothed, mid-2, width2)
####plt.figure(ind_tot)
####plt.imshow(textline_image)
####plt.plot([peak_start, peak_start], [0, height-1])
####plt.plot([peak_end, peak_end], [0, height-1])
####plt.savefig('./'+str(ind_tot)+'.png')
###return peak_start, peak_end
###else:
###pass
##def return_ocr_of_textline_without_common_section(
##self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot):
##if h2w_ratio > 0.05:
##pixel_values = processor(textline_image, return_tensors="pt").pixel_values
##generated_ids = model_ocr.generate(pixel_values.to(device))
##generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
##else:
###width = np.shape(textline_image)[1]
###height = np.shape(textline_image)[0]
###common_window = int(0.3*width)
###width1 = int ( width/2. - common_window )
###width2 = int ( width/2. + common_window )
##split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section(
##textline_image, ind_tot)
##if split_point:
##image1 = textline_image[:, :split_point,:]# image.crop((0, 0, width2, height))
##image2 = textline_image[:, split_point:,:]#image.crop((width1, 0, width, height))
###pixel_values1 = processor(image1, return_tensors="pt").pixel_values
###pixel_values2 = processor(image2, return_tensors="pt").pixel_values
##pixel_values_merged = processor([image1,image2], return_tensors="pt").pixel_values
##generated_ids_merged = model_ocr.generate(pixel_values_merged.to(device))
##generated_text_merged = processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
###print(generated_text_merged,'generated_text_merged')
###generated_ids1 = model_ocr.generate(pixel_values1.to(device))
###generated_ids2 = model_ocr.generate(pixel_values2.to(device))
###generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
###generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
###generated_text = generated_text1 + ' ' + generated_text2
##generated_text = generated_text_merged[0] + ' ' + generated_text_merged[1]
###print(generated_text1,'generated_text1')
###print(generated_text2, 'generated_text2')
###print('########################################')
##else:
##pixel_values = processor(textline_image, return_tensors="pt").pixel_values
##generated_ids = model_ocr.generate(pixel_values.to(device))
##generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
###print(generated_text,'generated_text')
###print('########################################')
##return generated_text
###def return_ocr_of_textline(
###self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot):
###if h2w_ratio > 0.05:
###pixel_values = processor(textline_image, return_tensors="pt").pixel_values
###generated_ids = model_ocr.generate(pixel_values.to(device))
###generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
###else:
####width = np.shape(textline_image)[1]
####height = np.shape(textline_image)[0]
####common_window = int(0.3*width)
####width1 = int ( width/2. - common_window )
####width2 = int ( width/2. + common_window )
###try:
###width1, width2 = self.return_start_and_end_of_common_text_of_textline_ocr_new(textline_image, ind_tot)
###image1 = textline_image[:, :width2,:]# image.crop((0, 0, width2, height))
###image2 = textline_image[:, width1:,:]#image.crop((width1, 0, width, height))
###pixel_values1 = processor(image1, return_tensors="pt").pixel_values
###pixel_values2 = processor(image2, return_tensors="pt").pixel_values
###generated_ids1 = model_ocr.generate(pixel_values1.to(device))
###generated_ids2 = model_ocr.generate(pixel_values2.to(device))
###generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
###generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
####print(generated_text1,'generated_text1')
####print(generated_text2, 'generated_text2')
####print('########################################')
###match = sq(None, generated_text1, generated_text2).find_longest_match(
###0, len(generated_text1), 0, len(generated_text2))
###generated_text = generated_text1 + generated_text2[match.b+match.size:]
###except:
###pixel_values = processor(textline_image, return_tensors="pt").pixel_values
###generated_ids = model_ocr.generate(pixel_values.to(device))
###generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
###return generated_text
def return_list_of_contours_with_desired_order(self, ls_cons, sorted_indexes):
return [ls_cons[sorted_indexes[index]] for index in range(len(sorted_indexes))]
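# Illustrative: the helper above simply reorders a list by the given index order, e.g.
# return_list_of_contours_with_desired_order(['c0', 'c1', 'c2'], [2, 0, 1]) -> ['c2', 'c0', 'c1']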


@@ -0,0 +1,756 @@
"""
Image enhancer. The output can be written at the same scale as the input or at a newly predicted scale.
"""
from logging import Logger
from difflib import SequenceMatcher as sq
from PIL import Image, ImageDraw, ImageFont
import math
import os
import sys
import time
from typing import Optional
import atexit
import warnings
from functools import partial
from pathlib import Path
from multiprocessing import cpu_count
import gc
import copy
from loky import ProcessPoolExecutor
import xml.etree.ElementTree as ET
import cv2
import numpy as np
from ocrd import OcrdPage
from ocrd_utils import getLogger, tf_disable_interactive_logs
import statistics
import tensorflow as tf  # used below for GPU memory-growth configuration
from skimage.morphology import skeletonize  # used by do_prediction()
from tensorflow.keras.models import load_model
from .utils.resize import resize_image
from .utils.pil_cv2 import pil2cv  # used by cache_images() for PIL input
from .utils import (
crop_image_inside_box
)
DPI_THRESHOLD = 298
KERNEL = np.ones((5, 5), np.uint8)
class Enhancer:
def __init__(
self,
dir_models : str,
dir_out : Optional[str] = None,
num_col_upper : Optional[int] = None,
num_col_lower : Optional[int] = None,
logger : Optional[Logger] = None,
):
self.dir_out = dir_out
self.input_binary = False
self.light_version = False
if num_col_upper:
self.num_col_upper = int(num_col_upper)
else:
self.num_col_upper = num_col_upper
if num_col_lower:
self.num_col_lower = int(num_col_lower)
else:
self.num_col_lower = num_col_lower
self.logger = logger if logger else getLogger('enhancement')
# for parallelization of CPU-intensive tasks:
self.executor = ProcessPoolExecutor(max_workers=cpu_count(), timeout=1200)
atexit.register(self.executor.shutdown)
self.dir_models = dir_models
self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
self.model_page_dir = dir_models + "/eynollah-page-extraction_20210425"
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except Exception:
self.logger.warning("no GPU device available")
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
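# For reference, the constructor above expects this layout under the models directory
# (names taken verbatim from the three paths above):
#   <dir_models>/eynollah-enhancement_20210425        -> self.model_enhancement
#   <dir_models>/eynollah-column-classifier_20210425  -> self.model_classifier
#   <dir_models>/eynollah-page-extraction_20210425    -> self.model_page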
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
ret = {}
t_c0 = time.time()
if image_filename:
ret['img'] = cv2.imread(image_filename)
if self.light_version:
self.dpi = 100
else:
self.dpi = 0#check_dpi(image_filename)
else:
ret['img'] = pil2cv(image_pil)
if self.light_version:
self.dpi = 100
else:
self.dpi = 0#check_dpi(image_pil)
ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY)
for prefix in ('', '_grayscale'):
ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8)
self._imgs = ret
if dpi is not None:
self.dpi = dpi
def reset_file_name_dir(self, image_filename):
t_c = time.time()
self.cache_images(image_filename=image_filename)
self.output_filename = os.path.join(self.dir_out, Path(image_filename).stem +'.png')
def imread(self, grayscale=False, uint8=True):
key = 'img'
if grayscale:
key += '_grayscale'
if uint8:
key += '_uint8'
return self._imgs[key].copy()
def isNaN(self, num):
return num != num
@staticmethod
def our_load_model(model_file):
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except Exception:
# fall back for models with custom layers (PatchEncoder/Patches must be importable here)
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement")
img_height_model = self.model_enhancement.layers[-1].output_shape[1]
img_width_model = self.model_enhancement.layers[-1].output_shape[2]
if img.shape[0] < img_height_model:
# cv2.resize expects (width, height): pad the short dimension up to the model's input size
img = cv2.resize(img, (img.shape[1], img_height_model), interpolation=cv2.INTER_NEAREST)
if img.shape[1] < img_width_model:
img = cv2.resize(img, (img_width_model, img.shape[0]), interpolation=cv2.INTER_NEAREST)
margin = int(0.1 * img_width_model)
width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin
img = img / 255.
img_h = img.shape[0]
img_w = img.shape[1]
prediction_true = np.zeros((img_h, img_w, 3))
nxf = img_w / float(width_mid)
nyf = img_h / float(height_mid)
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
for i in range(nxf):
for j in range(nyf):
if i == 0:
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
else:
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
if j == 0:
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
else:
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_x_u > img_w:
index_x_u = img_w
index_x_d = img_w - img_width_model
if index_y_u > img_h:
index_y_u = img_h
index_y_d = img_h - img_height_model
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
seg = label_p_pred[0, :, :, :] * 255
if i == 0 and j == 0:
prediction_true[index_y_d + 0:index_y_u - margin,
index_x_d + 0:index_x_u - margin] = \
seg[0:-margin or None,
0:-margin or None]
elif i == nxf - 1 and j == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - 0,
index_x_d + margin:index_x_u - 0] = \
seg[margin:,
margin:]
elif i == 0 and j == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - 0,
index_x_d + 0:index_x_u - margin] = \
seg[margin:,
0:-margin or None]
elif i == nxf - 1 and j == 0:
prediction_true[index_y_d + 0:index_y_u - margin,
index_x_d + margin:index_x_u - 0] = \
seg[0:-margin or None,
margin:]
elif i == 0 and j != 0 and j != nyf - 1:
prediction_true[index_y_d + margin:index_y_u - margin,
index_x_d + 0:index_x_u - margin] = \
seg[margin:-margin or None,
0:-margin or None]
elif i == nxf - 1 and j != 0 and j != nyf - 1:
prediction_true[index_y_d + margin:index_y_u - margin,
index_x_d + margin:index_x_u - 0] = \
seg[margin:-margin or None,
margin:]
elif i != 0 and i != nxf - 1 and j == 0:
prediction_true[index_y_d + 0:index_y_u - margin,
index_x_d + margin:index_x_u - margin] = \
seg[0:-margin or None,
margin:-margin or None]
elif i != 0 and i != nxf - 1 and j == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - 0,
index_x_d + margin:index_x_u - margin] = \
seg[margin:,
margin:-margin or None]
else:
prediction_true[index_y_d + margin:index_y_u - margin,
index_x_d + margin:index_x_u - margin] = \
seg[margin:-margin or None,
margin:-margin or None]
prediction_true = prediction_true.astype(int)
return prediction_true
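# Illustrative sketch of the tiling arithmetic above, assuming a 672x672 model input
# (the real values come from the model's output shape):
#   margin     = int(0.1 * 672)  = 67   pixels of overlap on each side
#   width_mid  = 672 - 2 * 67    = 538  effective horizontal step
#   nxf        = ceil(img_w / width_mid), nyf = ceil(img_h / height_mid)
# Each patch prediction is pasted back without its outer margin, except at the image
# borders, which is what the i/j edge cases above handle.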
def calculate_width_height_by_columns(self, img, num_col, width_early, label_p_pred):
self.logger.debug("enter calculate_width_height_by_columns")
if num_col == 1 and width_early < 1100:
img_w_new = 2000
elif num_col == 1 and width_early >= 2500:
img_w_new = 2000
elif num_col == 1 and width_early >= 1100 and width_early < 2500:
img_w_new = width_early
elif num_col == 2 and width_early < 2000:
img_w_new = 2400
elif num_col == 2 and width_early >= 3500:
img_w_new = 2400
elif num_col == 2 and width_early >= 2000 and width_early < 3500:
img_w_new = width_early
elif num_col == 3 and width_early < 2000:
img_w_new = 3000
elif num_col == 3 and width_early >= 4000:
img_w_new = 3000
elif num_col == 3 and width_early >= 2000 and width_early < 4000:
img_w_new = width_early
elif num_col == 4 and width_early < 2500:
img_w_new = 4000
elif num_col == 4 and width_early >= 5000:
img_w_new = 4000
elif num_col == 4 and width_early >= 2500 and width_early < 5000:
img_w_new = width_early
elif num_col == 5 and width_early < 3700:
img_w_new = 5000
elif num_col == 5 and width_early >= 7000:
img_w_new = 5000
elif num_col == 5 and width_early >= 3700 and width_early < 7000:
img_w_new = width_early
elif num_col == 6 and width_early < 4500:
img_w_new = 6500 # 5400
else:
img_w_new = width_early
img_h_new = img_w_new * img.shape[0] // img.shape[1]
if label_p_pred[0][int(num_col - 1)] < 0.9 and img_w_new < width_early:
img_new = np.copy(img)
num_column_is_classified = False
#elif label_p_pred[0][int(num_col - 1)] < 0.8 and img_h_new >= 8000:
elif img_h_new >= 8000:
img_new = np.copy(img)
num_column_is_classified = False
else:
img_new = resize_image(img, img_h_new, img_w_new)
num_column_is_classified = True
return img_new, num_column_is_classified
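# Worked example of the sizing rule above (illustrative numbers): a single-column page
# 900 px wide and 1400 px high hits the "width_early < 1100" branch, so img_w_new = 2000
# and img_h_new = 2000 * 1400 // 900 = 3111; the resize is only applied if the classifier
# confidence and the 8000 px height cap allow it.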
def early_page_for_num_of_column_classification(self,img_bin):
self.logger.debug("enter early_page_for_num_of_column_classification")
if self.input_binary:
img = np.copy(img_bin).astype(np.uint8)
else:
img = self.imread()
img = cv2.GaussianBlur(img, (5, 5), 0)
img_page_prediction = self.do_prediction(False, img, self.model_page)
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
thresh = cv2.dilate(thresh, KERNEL, iterations=3)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
if len(contours)>0:
cnt_size = np.array([cv2.contourArea(contours[j])
for j in range(len(contours))])
cnt = contours[np.argmax(cnt_size)]
box = cv2.boundingRect(cnt)
else:
box = [0, 0, img.shape[1], img.shape[0]]
cropped_page, page_coord = crop_image_inside_box(box, img)
self.logger.debug("exit early_page_for_num_of_column_classification")
return cropped_page, page_coord
def calculate_width_height_by_columns_1_2(self, img, num_col, width_early, label_p_pred):
self.logger.debug("enter calculate_width_height_by_columns")
if num_col == 1:
img_w_new = 1000
else:
img_w_new = 1300
img_h_new = img_w_new * img.shape[0] // img.shape[1]
if label_p_pred[0][int(num_col - 1)] < 0.9 and img_w_new < width_early:
img_new = np.copy(img)
num_column_is_classified = False
#elif label_p_pred[0][int(num_col - 1)] < 0.8 and img_h_new >= 8000:
elif img_h_new >= 8000:
img_new = np.copy(img)
num_column_is_classified = False
else:
img_new = resize_image(img, img_h_new, img_w_new)
num_column_is_classified = True
return img_new, num_column_is_classified
def resize_and_enhance_image_with_column_classifier(self, light_version):
self.logger.debug("enter resize_and_enhance_image_with_column_classifier")
dpi = 0 #self.dpi -- DPI detection is disabled here, so the resize branch below always runs
self.logger.info("Detected %s DPI", dpi)
if self.input_binary:
img = self.imread()
# note: this branch needs a binarization model (self.model_bin), which is not loaded in
# __init__; it is unreachable while self.input_binary is hard-coded to False
prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
img= np.copy(prediction_bin)
img_bin = prediction_bin
else:
img = self.imread()
self.h_org, self.w_org = img.shape[:2]
img_bin = None
width_early = img.shape[1]
t1 = time.time()
_, page_coord = self.early_page_for_num_of_column_classification(img_bin)
self.image_page_org_size = img[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3], :]
self.page_coord = page_coord
if self.num_col_upper and not self.num_col_lower:
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]
elif self.num_col_lower and not self.num_col_upper:
num_col = self.num_col_lower
label_p_pred = [np.ones(6)]
elif not self.num_col_upper and not self.num_col_lower:
if self.input_binary:
img_in = np.copy(img)
img_in = img_in / 255.0
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = img_in.reshape(1, 448, 448, 3)
else:
img_1ch = self.imread(grayscale=True)
width_early = img_1ch.shape[1]
img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
img_in[0, :, :, 0] = img_1ch[:, :]
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
if self.input_binary:
img_in = np.copy(img)
img_in = img_in / 255.0
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = img_in.reshape(1, 448, 448, 3)
else:
img_1ch = self.imread(grayscale=True)
width_early = img_1ch.shape[1]
img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
img_in[0, :, :, 0] = img_1ch[:, :]
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
if num_col > self.num_col_upper:
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]
if num_col < self.num_col_lower:
num_col = self.num_col_lower
label_p_pred = [np.ones(6)]
else:
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]
self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5))
if dpi < DPI_THRESHOLD:
if light_version and num_col in (1,2):
img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2(
img, num_col, width_early, label_p_pred)
else:
img_new, num_column_is_classified = self.calculate_width_height_by_columns(
img, num_col, width_early, label_p_pred)
if light_version:
image_res = np.copy(img_new)
else:
image_res = self.predict_enhancement(img_new)
is_image_enhanced = True
else:
num_column_is_classified = True
image_res = np.copy(img)
is_image_enhanced = False
self.logger.debug("exit resize_and_enhance_image_with_column_classifier")
return is_image_enhanced, img, image_res, num_col, num_column_is_classified, img_bin
def do_prediction(
self, patches, img, model,
n_batch_inference=1, marginal_of_patch_percent=0.1,
thresholding_for_some_classes_in_light_version=False,
thresholding_for_artificial_class_in_light_version=False, thresholding_for_fl_light_version=False, threshold_art_class_textline=0.1):
self.logger.debug("enter do_prediction")
img_height_model = model.layers[-1].output_shape[1]
img_width_model = model.layers[-1].output_shape[2]
if not patches:
img_h_page = img.shape[0]
img_w_page = img.shape[1]
img = img / float(255.0)
img = resize_image(img, img_height_model, img_width_model)
label_p_pred = model.predict(img[np.newaxis], verbose=0)
seg = np.argmax(label_p_pred, axis=3)[0]
if thresholding_for_artificial_class_in_light_version:
seg_art = label_p_pred[0,:,:,2]
seg_art[seg_art<threshold_art_class_textline] = 0
seg_art[seg_art>0] =1
skeleton_art = skeletonize(seg_art)
skeleton_art = skeleton_art*1
seg[skeleton_art==1]=2
if thresholding_for_fl_light_version:
seg_header = label_p_pred[0,:,:,2]
seg_header[seg_header<0.2] = 0
seg_header[seg_header>0] =1
seg[seg_header==1]=2
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8)
return prediction_true
if img.shape[0] < img_height_model:
img = resize_image(img, img_height_model, img.shape[1])
if img.shape[1] < img_width_model:
img = resize_image(img, img.shape[0], img_width_model)
self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
margin = int(marginal_of_patch_percent * img_height_model)
width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin
img = img / 255.
#img = img.astype(np.float16)
img_h = img.shape[0]
img_w = img.shape[1]
prediction_true = np.zeros((img_h, img_w, 3))
mask_true = np.zeros((img_h, img_w))
nxf = img_w / float(width_mid)
nyf = img_h / float(height_mid)
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3))
for i in range(nxf):
for j in range(nyf):
if i == 0:
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
else:
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
if j == 0:
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
else:
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_x_u > img_w:
index_x_u = img_w
index_x_d = img_w - img_width_model
if index_y_u > img_h:
index_y_u = img_h
index_y_d = img_h - img_height_model
list_i_s.append(i)
list_j_s.append(j)
list_x_u.append(index_x_u)
list_x_d.append(index_x_d)
list_y_d.append(index_y_d)
list_y_u.append(index_y_u)
img_patch[batch_indexer,:,:,:] = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
batch_indexer += 1
if (batch_indexer == n_batch_inference or
# last batch
i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape))
label_p_pred = model.predict(img_patch, verbose=0)
seg = np.argmax(label_p_pred, axis=3)
if thresholding_for_some_classes_in_light_version:
seg_not_base = label_p_pred[:,:,:,4]
seg_not_base[seg_not_base>0.03] =1
seg_not_base[seg_not_base<1] =0
seg_line = label_p_pred[:,:,:,3]
seg_line[seg_line>0.1] =1
seg_line[seg_line<1] =0
seg_background = label_p_pred[:,:,:,0]
seg_background[seg_background>0.25] =1
seg_background[seg_background<1] =0
seg[seg_not_base==1]=4
seg[seg_background==1]=0
seg[(seg_line==1) & (seg==0)]=3
if thresholding_for_artificial_class_in_light_version:
seg_art = label_p_pred[:,:,:,2]
seg_art[seg_art<threshold_art_class_textline] = 0
seg_art[seg_art>0] =1
##seg[seg_art==1]=2
indexer_inside_batch = 0
for i_batch, j_batch in zip(list_i_s, list_j_s):
seg_in = seg[indexer_inside_batch]
if thresholding_for_artificial_class_in_light_version:
seg_in_art = seg_art[indexer_inside_batch]
index_y_u_in = list_y_u[indexer_inside_batch]
index_y_d_in = list_y_d[indexer_inside_batch]
index_x_u_in = list_x_u[indexer_inside_batch]
index_x_d_in = list_x_d[indexer_inside_batch]
if i_batch == 0 and j_batch == 0:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin] = \
seg_in[0:-margin or None,
0:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin, 1] = \
seg_in_art[0:-margin or None,
0:-margin or None]
elif i_batch == nxf - 1 and j_batch == nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - 0] = \
seg_in[margin:,
margin:,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - 0, 1] = \
seg_in_art[margin:,
margin:]
elif i_batch == 0 and j_batch == nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + 0:index_x_u_in - margin] = \
seg_in[margin:,
0:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + 0:index_x_u_in - margin, 1] = \
seg_in_art[margin:,
0:-margin or None]
elif i_batch == nxf - 1 and j_batch == 0:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0] = \
seg_in[0:-margin or None,
margin:,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0, 1] = \
seg_in_art[0:-margin or None,
margin:]
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin] = \
seg_in[margin:-margin or None,
0:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + 0:index_x_u_in - margin, 1] = \
seg_in_art[margin:-margin or None,
0:-margin or None]
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0] = \
seg_in[margin:-margin or None,
margin:,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - 0, 1] = \
seg_in_art[margin:-margin or None,
margin:]
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin] = \
seg_in[0:-margin or None,
margin:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + 0:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin, 1] = \
seg_in_art[0:-margin or None,
margin:-margin or None]
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - margin] = \
seg_in[margin:,
margin:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + margin:index_x_u_in - margin, 1] = \
seg_in_art[margin:,
margin:-margin or None]
else:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin] = \
seg_in[margin:-margin or None,
margin:-margin or None,
np.newaxis]
if thresholding_for_artificial_class_in_light_version:
prediction_true[index_y_d_in + margin:index_y_u_in - margin,
index_x_d_in + margin:index_x_u_in - margin, 1] = \
seg_in_art[margin:-margin or None,
margin:-margin or None]
indexer_inside_batch += 1
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
img_patch[:] = 0
prediction_true = prediction_true.astype(np.uint8)
if thresholding_for_artificial_class_in_light_version:
kernel_min = np.ones((3, 3), np.uint8)
prediction_true[:,:,0][prediction_true[:,:,0]==2] = 0
skeleton_art = skeletonize(prediction_true[:,:,1])
skeleton_art = skeleton_art*1
skeleton_art = skeleton_art.astype('uint8')
skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1)
prediction_true[:,:,0][skeleton_art==1]=2
#del model
gc.collect()
return prediction_true
def run_enhancement(self, light_version):
t_in = time.time()
self.logger.info("Resizing and enhancing image...")
is_image_enhanced, img_org, img_res, num_col_classifier, num_column_is_classified, img_bin = \
self.resize_and_enhance_image_with_column_classifier(light_version)
self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ')
return img_res, is_image_enhanced, num_col_classifier, num_column_is_classified
def run_single(self):
t0 = time.time()
img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(light_version=False)
return img_res
def run(self, image_filename : Optional[str] = None, dir_in : Optional[str] = None, overwrite : bool = False):
"""
Enhance a single image or all images in a directory and write the enhanced results to dir_out.
"""
self.logger.debug("enter run")
t0_tot = time.time()
if dir_in:
self.ls_imgs = os.listdir(dir_in)
elif image_filename:
self.ls_imgs = [image_filename]
else:
raise ValueError("run requires either a single image filename or a directory")
for img_filename in self.ls_imgs:
self.logger.info(img_filename)
t0 = time.time()
self.reset_file_name_dir(os.path.join(dir_in or "", img_filename))
#print("text region early -11 in %.1fs", time.time() - t0)
if os.path.exists(self.output_filename):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", self.output_filename)
else:
self.logger.warning("will skip input for existing output file '%s'", self.output_filename)
continue
image_enhanced = self.run_single()
img_enhanced_org_scale = resize_image(image_enhanced, self.h_org, self.w_org)
cv2.imwrite(self.output_filename, img_enhanced_org_scale)
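# For reference: with -o /data/out, an input named e.g. "scan_0001.tif" is written to
# "/data/out/scan_0001.png" (see reset_file_name_dir above), resized back to the original
# page size via resize_image(image_enhanced, self.h_org, self.w_org).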