From cc36694dfdab852e27780187f15da1155423bd02 Mon Sep 17 00:00:00 2001
From: vahidrezanezhad
Date: Sun, 1 Jun 2025 15:53:04 +0200
Subject: [PATCH] integrate image enhancer

---
 src/eynollah/cli.py            |  69 +++
 src/eynollah/eynollah.py       | 234 +---------
 src/eynollah/image_enhancer.py | 756 +++++++++++++++++++++++++++++++++
 3 files changed, 830 insertions(+), 229 deletions(-)
 create mode 100644 src/eynollah/image_enhancer.py

diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py
index 2d0d6f9..840bc4b 100644
--- a/src/eynollah/cli.py
+++ b/src/eynollah/cli.py
@@ -3,6 +3,7 @@ import click
 from ocrd_utils import initLogging, getLevelName, getLogger
 from eynollah.eynollah import Eynollah, Eynollah_ocr
 from eynollah.sbb_binarize import SbbBinarizer
+from eynollah.image_enhancer import Enhancer
 
 @click.group()
 def main():
@@ -70,6 +71,74 @@ def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out)
 
+@main.command()
+@click.option(
+    "--image",
+    "-i",
+    help="image filename",
+    type=click.Path(exists=True, dir_okay=False),
+)
+@click.option(
+    "--out",
+    "-o",
+    help="directory to write output image data",
+    type=click.Path(exists=True, file_okay=False),
+    required=True,
+)
+@click.option(
+    "--overwrite",
+    "-O",
+    help="overwrite (instead of skipping) if output file exists",
+    is_flag=True,
+)
+@click.option(
+    "--dir_in",
+    "-di",
+    help="directory of images",
+    type=click.Path(exists=True, file_okay=False),
+)
+@click.option(
+    "--model",
+    "-m",
+    help="directory of models",
+    type=click.Path(exists=True, file_okay=False),
+    required=True,
+)
+@click.option(
+    "--num_col_upper",
+    "-ncu",
+    help="upper limit of columns in document image",
+)
+@click.option(
+    "--num_col_lower",
+    "-ncl",
+    help="lower limit of columns in document image",
+)
+@click.option(
+    "--log_level",
+    "-l",
+    type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
+    help="Override log level globally to this",
+)
+def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_lower, log_level):
+    initLogging()
+    if log_level:
+        getLogger('enhancement').setLevel(getLevelName(log_level))
+    assert image or dir_in, "Either a single image -i or a dir_in -di is required"
+    enhancer_object = Enhancer(
+        model,
+        logger=getLogger('enhancement'),
+        dir_out=out,
+        num_col_upper=num_col_upper,
+        num_col_lower=num_col_lower,
+    )
+    if dir_in:
+        enhancer_object.run(dir_in=dir_in, overwrite=overwrite)
+    else:
+        enhancer_object.run(image_filename=image, overwrite=overwrite)
 
 @main.command()
 @click.option(
diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py
index 6c00329..cf540d3 100644
--- a/src/eynollah/eynollah.py
+++ b/src/eynollah/eynollah.py
@@ -3612,25 +3612,12 @@ class Eynollah:
 
         inference_bs = 3
-
-        cv2.imwrite('textregions.png', text_regions_p*50)
-        cv2.imwrite('sep.png', (text_regions_p[:,:]==6)*255)
-
         ver_kernel = np.ones((5, 1), dtype=np.uint8)
         hor_kernel = np.ones((1, 5), dtype=np.uint8)
-
-        #separators = (text_regions_p[:,:]==6)*1
-        #text_regions_p[text_regions_p[:,:]==6] = 0
-        #separators = separators.astype('uint8')
-
-        #separators = cv2.erode(separators , hor_kernel, iterations=1)
-        #text_regions_p[separators[:,:]==1] = 6
-
-        #cv2.imwrite('sep_new.png', (text_regions_p[:,:]==6)*255)
-
         min_cont_size_to_be_dilated = 10
-        if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
+        if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
             cx_conts, cy_conts, x_min_conts, x_max_conts, y_min_conts, y_max_conts, _ = 
find_new_features_of_contours(contours_only_text_parent) args_cont_located = np.array(range(len(contours_only_text_parent))) @@ -3672,7 +3659,6 @@ class Eynollah: text_regions_p_textregions_dilated = cv2.dilate(text_regions_p_textregions_dilated , ver_kernel, iterations=5) text_regions_p_textregions_dilated[text_regions_p[:,:]>1] = 0 - cv2.imwrite('text_regions_p_textregions_dilated.png', text_regions_p_textregions_dilated*255) contours_only_dilated, hir_on_text_dilated = return_contours_of_image(text_regions_p_textregions_dilated) contours_only_dilated = return_parent_contours(contours_only_dilated, hir_on_text_dilated) @@ -3723,21 +3709,20 @@ class Eynollah: img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12, int(x_min_main[j]):int(x_max_main[j])] = 1 co_text_all_org = contours_only_text_parent + contours_only_text_parent_h - if len(contours_only_text_parent)>min_cont_size_to_be_dilated: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: co_text_all = contours_only_dilated + contours_only_text_parent_h else: co_text_all = contours_only_text_parent + contours_only_text_parent_h else: co_text_all_org = contours_only_text_parent - if len(contours_only_text_parent)>min_cont_size_to_be_dilated: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: co_text_all = contours_only_dilated else: co_text_all = contours_only_text_parent if not len(co_text_all): return [], [] - print(len(co_text_all), "co_text_all") - print(len(co_text_all_org), "co_text_all_org") + labels_con = np.zeros((int(y_len /6.), int(x_len/6.), len(co_text_all)), dtype=bool) co_text_all = [(i/6).astype(int) for i in co_text_all] for i in range(len(co_text_all)): @@ -3805,7 +3790,7 @@ class Eynollah: ordered = [i[0] for i in ordered] - if len(contours_only_text_parent)>min_cont_size_to_be_dilated: + if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version: org_contours_indexes = [] for ind in range(len(ordered)): region_with_curr_order = ordered[ind] @@ -3823,215 +3808,6 @@ class Eynollah: else: region_ids = ['region_%04d' % i for i in range(len(co_text_all_org))] return ordered, region_ids - - - ####def return_start_and_end_of_common_text_of_textline_ocr(self, textline_image, ind_tot): - ####width = np.shape(textline_image)[1] - ####height = np.shape(textline_image)[0] - ####common_window = int(0.2*width) - - ####width1 = int ( width/2. - common_window ) - ####width2 = int ( width/2. 
+ common_window )
-
-        ####img_sum = np.sum(textline_image[:,:,0], axis=0)
-        ####sum_smoothed = gaussian_filter1d(img_sum, 3)
-
-        ####peaks_real, _ = find_peaks(sum_smoothed, height=0)
-        ####if len(peaks_real)>70:
-
-            ####peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
-
-            ####arg_sort = np.argsort(sum_smoothed[peaks_real])
-            ####arg_sort4 =arg_sort[::-1][:4]
-            ####peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
-            ####argsort_sorted = np.argsort(peaks_sort_4)
-
-            ####first_4_sorted = peaks_sort_4[argsort_sorted]
-            ####y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
-            #####print(first_4_sorted,'first_4_sorted')
-
-            ####arg_sortnew = np.argsort(y_4_sorted)
-            ####peaks_final =np.sort( first_4_sorted[arg_sortnew][2:] )
-
-            #####plt.figure(ind_tot)
-            #####plt.imshow(textline_image)
-            #####plt.plot([peaks_final[0], peaks_final[0]], [0, height-1])
-            #####plt.plot([peaks_final[1], peaks_final[1]], [0, height-1])
-            #####plt.savefig('./'+str(ind_tot)+'.png')
-
-            ####return peaks_final[0], peaks_final[1]
-        ####else:
-            ####pass
-
-    ##def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image, ind_tot):
-        ##width = np.shape(textline_image)[1]
-        ##height = np.shape(textline_image)[0]
-        ##common_window = int(0.06*width)
-
-        ##width1 = int ( width/2. - common_window )
-        ##width2 = int ( width/2. + common_window )
-
-        ##img_sum = np.sum(textline_image[:,:,0], axis=0)
-        ##sum_smoothed = gaussian_filter1d(img_sum, 3)
-
-        ##peaks_real, _ = find_peaks(sum_smoothed, height=0)
-        ##if len(peaks_real)>70:
-            ###print(len(peaks_real), 'len(peaks_real)')
-
-            ##peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
-
-            ##arg_max = np.argmax(sum_smoothed[peaks_real])
-            ##peaks_final = peaks_real[arg_max]
-
-            ###plt.figure(ind_tot)
-            ###plt.imshow(textline_image)
-            ###plt.plot([peaks_final, peaks_final], [0, height-1])
-            ####plt.plot([peaks_final[1], peaks_final[1]], [0, height-1])
-            ###plt.savefig('./'+str(ind_tot)+'.png')
-
-            ##return peaks_final
-        ##else:
-            ##return None
-
-    ###def return_start_and_end_of_common_text_of_textline_ocr_new_splitted(
-        ###self, peaks_real, sum_smoothed, start_split, end_split):
-
-        ###peaks_real = peaks_real[(peaks_real<end_split) & (peaks_real>start_split)]
-
-        ###arg_sort = np.argsort(sum_smoothed[peaks_real])
-        ###arg_sort4 =arg_sort[::-1][:4]
-        ###peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
-        ###argsort_sorted = np.argsort(peaks_sort_4)
-
-        ###first_4_sorted = peaks_sort_4[argsort_sorted]
-        ###y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
-        ####print(first_4_sorted,'first_4_sorted')
-
-        ###arg_sortnew = np.argsort(y_4_sorted)
-        ###peaks_final =np.sort( first_4_sorted[arg_sortnew][3:] )
-        ###return peaks_final[0]
-
-    ###def return_start_and_end_of_common_text_of_textline_ocr_new(self, textline_image, ind_tot):
-        ###width = np.shape(textline_image)[1]
-        ###height = np.shape(textline_image)[0]
-        ###common_window = int(0.15*width)
-
-        ###width1 = int ( width/2. - common_window )
-        ###width2 = int ( width/2. + common_window )
-        ###mid = int(width/2.)
- - ###img_sum = np.sum(textline_image[:,:,0], axis=0) - ###sum_smoothed = gaussian_filter1d(img_sum, 3) - - ###peaks_real, _ = find_peaks(sum_smoothed, height=0) - ###if len(peaks_real)>70: - ###peak_start = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted( - ###peaks_real, sum_smoothed, width1, mid+2) - ###peak_end = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted( - ###peaks_real, sum_smoothed, mid-2, width2) - - ####plt.figure(ind_tot) - ####plt.imshow(textline_image) - ####plt.plot([peak_start, peak_start], [0, height-1]) - ####plt.plot([peak_end, peak_end], [0, height-1]) - ####plt.savefig('./'+str(ind_tot)+'.png') - - ###return peak_start, peak_end - ###else: - ###pass - - ##def return_ocr_of_textline_without_common_section( - ##self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot): - - ##if h2w_ratio > 0.05: - ##pixel_values = processor(textline_image, return_tensors="pt").pixel_values - ##generated_ids = model_ocr.generate(pixel_values.to(device)) - ##generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - ##else: - ###width = np.shape(textline_image)[1] - ###height = np.shape(textline_image)[0] - ###common_window = int(0.3*width) - ###width1 = int ( width/2. - common_window ) - ###width2 = int ( width/2. + common_window ) - - ##split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section( - ##textline_image, ind_tot) - ##if split_point: - ##image1 = textline_image[:, :split_point,:]# image.crop((0, 0, width2, height)) - ##image2 = textline_image[:, split_point:,:]#image.crop((width1, 0, width, height)) - - ###pixel_values1 = processor(image1, return_tensors="pt").pixel_values - ###pixel_values2 = processor(image2, return_tensors="pt").pixel_values - - ##pixel_values_merged = processor([image1,image2], return_tensors="pt").pixel_values - ##generated_ids_merged = model_ocr.generate(pixel_values_merged.to(device)) - ##generated_text_merged = processor.batch_decode(generated_ids_merged, skip_special_tokens=True) - - ###print(generated_text_merged,'generated_text_merged') - - ###generated_ids1 = model_ocr.generate(pixel_values1.to(device)) - ###generated_ids2 = model_ocr.generate(pixel_values2.to(device)) - - ###generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0] - ###generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0] - - ###generated_text = generated_text1 + ' ' + generated_text2 - ##generated_text = generated_text_merged[0] + ' ' + generated_text_merged[1] - - ###print(generated_text1,'generated_text1') - ###print(generated_text2, 'generated_text2') - ###print('########################################') - ##else: - ##pixel_values = processor(textline_image, return_tensors="pt").pixel_values - ##generated_ids = model_ocr.generate(pixel_values.to(device)) - ##generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - - ###print(generated_text,'generated_text') - ###print('########################################') - ##return generated_text - - ###def return_ocr_of_textline( - ###self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot): - - ###if h2w_ratio > 0.05: - ###pixel_values = processor(textline_image, return_tensors="pt").pixel_values - ###generated_ids = model_ocr.generate(pixel_values.to(device)) - ###generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - ###else: - ####width = 
np.shape(textline_image)[1]
-            ####height = np.shape(textline_image)[0]
-            ####common_window = int(0.3*width)
-            ####width1 = int ( width/2. - common_window )
-            ####width2 = int ( width/2. + common_window )
-
-        ###try:
-            ###width1, width2 = self.return_start_and_end_of_common_text_of_textline_ocr_new(textline_image, ind_tot)
-
-            ###image1 = textline_image[:, :width2,:]# image.crop((0, 0, width2, height))
-            ###image2 = textline_image[:, width1:,:]#image.crop((width1, 0, width, height))
-
-            ###pixel_values1 = processor(image1, return_tensors="pt").pixel_values
-            ###pixel_values2 = processor(image2, return_tensors="pt").pixel_values
-
-            ###generated_ids1 = model_ocr.generate(pixel_values1.to(device))
-            ###generated_ids2 = model_ocr.generate(pixel_values2.to(device))
-
-            ###generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
-            ###generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
-            ####print(generated_text1,'generated_text1')
-            ####print(generated_text2, 'generated_text2')
-            ####print('########################################')
-
-            ###match = sq(None, generated_text1, generated_text2).find_longest_match(
-                ###0, len(generated_text1), 0, len(generated_text2))
-            ###generated_text = generated_text1 + generated_text2[match.b+match.size:]
-        ###except:
-            ###pixel_values = processor(textline_image, return_tensors="pt").pixel_values
-            ###generated_ids = model_ocr.generate(pixel_values.to(device))
-            ###generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        ###return generated_text
-
     def return_list_of_contours_with_desired_order(self, ls_cons, sorted_indexes):
         return [ls_cons[sorted_indexes[index]] for index in range(len(sorted_indexes))]
diff --git a/src/eynollah/image_enhancer.py b/src/eynollah/image_enhancer.py
new file mode 100644
index 0000000..71445f7
--- /dev/null
+++ b/src/eynollah/image_enhancer.py
@@ -0,0 +1,756 @@
+"""
+Image enhancer. The output can be written at the same scale as the input or at a newly predicted scale.
+""" + +from logging import Logger +from difflib import SequenceMatcher as sq +from PIL import Image, ImageDraw, ImageFont +import math +import os +import sys +import time +from typing import Optional +import atexit +import warnings +from functools import partial +from pathlib import Path +from multiprocessing import cpu_count +import gc +import copy +from loky import ProcessPoolExecutor +import xml.etree.ElementTree as ET +import cv2 +import numpy as np +from ocrd import OcrdPage +from ocrd_utils import getLogger, tf_disable_interactive_logs +import statistics +from tensorflow.keras.models import load_model +from .utils.resize import resize_image +from .utils import ( + crop_image_inside_box +) + +DPI_THRESHOLD = 298 +KERNEL = np.ones((5, 5), np.uint8) + + +class Enhancer: + def __init__( + self, + dir_models : str, + dir_out : Optional[str] = None, + num_col_upper : Optional[int] = None, + num_col_lower : Optional[int] = None, + logger : Optional[Logger] = None, + ): + self.dir_out = dir_out + self.input_binary = False + self.light_version = False + if num_col_upper: + self.num_col_upper = int(num_col_upper) + else: + self.num_col_upper = num_col_upper + if num_col_lower: + self.num_col_lower = int(num_col_lower) + else: + self.num_col_lower = num_col_lower + + self.logger = logger if logger else getLogger('enhancement') + # for parallelization of CPU-intensive tasks: + self.executor = ProcessPoolExecutor(max_workers=cpu_count(), timeout=1200) + atexit.register(self.executor.shutdown) + self.dir_models = dir_models + self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425" + self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425" + self.model_page_dir = dir_models + "/eynollah-page-extraction_20210425" + + try: + for device in tf.config.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(device, True) + except: + self.logger.warning("no GPU device available") + + self.model_page = self.our_load_model(self.model_page_dir) + self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) + self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) + + def cache_images(self, image_filename=None, image_pil=None, dpi=None): + ret = {} + t_c0 = time.time() + if image_filename: + ret['img'] = cv2.imread(image_filename) + if self.light_version: + self.dpi = 100 + else: + self.dpi = 0#check_dpi(image_filename) + else: + ret['img'] = pil2cv(image_pil) + if self.light_version: + self.dpi = 100 + else: + self.dpi = 0#check_dpi(image_pil) + ret['img_grayscale'] = cv2.cvtColor(ret['img'], cv2.COLOR_BGR2GRAY) + for prefix in ('', '_grayscale'): + ret[f'img{prefix}_uint8'] = ret[f'img{prefix}'].astype(np.uint8) + self._imgs = ret + if dpi is not None: + self.dpi = dpi + + def reset_file_name_dir(self, image_filename): + t_c = time.time() + self.cache_images(image_filename=image_filename) + self.output_filename = os.path.join(self.dir_out, Path(image_filename).stem +'.png') + + def imread(self, grayscale=False, uint8=True): + key = 'img' + if grayscale: + key += '_grayscale' + if uint8: + key += '_uint8' + return self._imgs[key].copy() + + def isNaN(self, num): + return num != num + + @staticmethod + def our_load_model(model_file): + if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): + # prefer SavedModel over HDF5 format if it exists + model_file = model_file[:-3] + try: + model = load_model(model_file, compile=False) + except: + model = load_model(model_file, compile=False, 
custom_objects={
+                # NOTE: PatchEncoder/Patches (custom layers of the
+                # transformer-based models) are not defined in this module;
+                # this fallback only works where they are in scope. None of
+                # the three models loaded here should require them.
+                "PatchEncoder": PatchEncoder, "Patches": Patches})
+        return model
+
+    def predict_enhancement(self, img):
+        self.logger.debug("enter predict_enhancement")
+
+        img_height_model = self.model_enhancement.layers[-1].output_shape[1]
+        img_width_model = self.model_enhancement.layers[-1].output_shape[2]
+        # pad the image up to the model input size if it is smaller
+        # (cv2.resize expects dsize as (width, height))
+        if img.shape[0] < img_height_model:
+            img = cv2.resize(img, (img.shape[1], img_height_model), interpolation=cv2.INTER_NEAREST)
+        if img.shape[1] < img_width_model:
+            img = cv2.resize(img, (img_width_model, img.shape[0]), interpolation=cv2.INTER_NEAREST)
+        margin = int(0.1 * img_width_model)
+        width_mid = img_width_model - 2 * margin
+        height_mid = img_height_model - 2 * margin
+        img = img / 255.
+        img_h = img.shape[0]
+        img_w = img.shape[1]
+
+        prediction_true = np.zeros((img_h, img_w, 3))
+        nxf = img_w / float(width_mid)
+        nyf = img_h / float(height_mid)
+        nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
+        nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
+
+        for i in range(nxf):
+            for j in range(nyf):
+                if i == 0:
+                    index_x_d = i * width_mid
+                    index_x_u = index_x_d + img_width_model
+                else:
+                    index_x_d = i * width_mid
+                    index_x_u = index_x_d + img_width_model
+                if j == 0:
+                    index_y_d = j * height_mid
+                    index_y_u = index_y_d + img_height_model
+                else:
+                    index_y_d = j * height_mid
+                    index_y_u = index_y_d + img_height_model
+
+                if index_x_u > img_w:
+                    index_x_u = img_w
+                    index_x_d = img_w - img_width_model
+                if index_y_u > img_h:
+                    index_y_u = img_h
+                    index_y_d = img_h - img_height_model
+
+                img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
+                label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
+                seg = label_p_pred[0, :, :, :] * 255
+
+                if i == 0 and j == 0:
+                    prediction_true[index_y_d + 0:index_y_u - margin,
+                                    index_x_d + 0:index_x_u - margin] = \
+                                        seg[0:-margin or None,
+                                            0:-margin or None]
+                elif i == nxf - 1 and j == nyf - 1:
+                    prediction_true[index_y_d + margin:index_y_u - 0,
+                                    index_x_d + margin:index_x_u - 0] = \
+                                        seg[margin:,
+                                            margin:]
+                elif i == 0 and j == nyf - 1:
+                    prediction_true[index_y_d + margin:index_y_u - 0,
+                                    index_x_d + 0:index_x_u - margin] = \
+                                        seg[margin:,
+                                            0:-margin or None]
+                elif i == nxf - 1 and j == 0:
+                    prediction_true[index_y_d + 0:index_y_u - margin,
+                                    index_x_d + margin:index_x_u - 0] = \
+                                        seg[0:-margin or None,
+                                            margin:]
+                elif i == 0 and j != 0 and j != nyf - 1:
+                    prediction_true[index_y_d + margin:index_y_u - margin,
+                                    index_x_d + 0:index_x_u - margin] = \
+                                        seg[margin:-margin or None,
+                                            0:-margin or None]
+                elif i == nxf - 1 and j != 0 and j != nyf - 1:
+                    prediction_true[index_y_d + margin:index_y_u - margin,
+                                    index_x_d + margin:index_x_u - 0] = \
+                                        seg[margin:-margin or None,
+                                            margin:]
+                elif i != 0 and i != nxf - 1 and j == 0:
+                    prediction_true[index_y_d + 0:index_y_u - margin,
+                                    index_x_d + margin:index_x_u - margin] = \
+                                        seg[0:-margin or None,
+                                            margin:-margin or None]
+                elif i != 0 and i != nxf - 1 and j == nyf - 1:
+                    prediction_true[index_y_d + margin:index_y_u - 0,
+                                    index_x_d + margin:index_x_u - margin] = \
+                                        seg[margin:,
+                                            margin:-margin or None]
+                else:
+                    prediction_true[index_y_d + margin:index_y_u - margin,
+                                    index_x_d + margin:index_x_u - margin] = \
+                                        seg[margin:-margin or None,
+                                            margin:-margin or None]
+
+        prediction_true = prediction_true.astype(int)
+        return prediction_true
+
+    def calculate_width_height_by_columns(self, img, num_col, width_early, label_p_pred):
+        self.logger.debug("enter calculate_width_height_by_columns")
+        if num_col == 1 and width_early < 1100:
img_w_new = 2000 + elif num_col == 1 and width_early >= 2500: + img_w_new = 2000 + elif num_col == 1 and width_early >= 1100 and width_early < 2500: + img_w_new = width_early + elif num_col == 2 and width_early < 2000: + img_w_new = 2400 + elif num_col == 2 and width_early >= 3500: + img_w_new = 2400 + elif num_col == 2 and width_early >= 2000 and width_early < 3500: + img_w_new = width_early + elif num_col == 3 and width_early < 2000: + img_w_new = 3000 + elif num_col == 3 and width_early >= 4000: + img_w_new = 3000 + elif num_col == 3 and width_early >= 2000 and width_early < 4000: + img_w_new = width_early + elif num_col == 4 and width_early < 2500: + img_w_new = 4000 + elif num_col == 4 and width_early >= 5000: + img_w_new = 4000 + elif num_col == 4 and width_early >= 2500 and width_early < 5000: + img_w_new = width_early + elif num_col == 5 and width_early < 3700: + img_w_new = 5000 + elif num_col == 5 and width_early >= 7000: + img_w_new = 5000 + elif num_col == 5 and width_early >= 3700 and width_early < 7000: + img_w_new = width_early + elif num_col == 6 and width_early < 4500: + img_w_new = 6500 # 5400 + else: + img_w_new = width_early + img_h_new = img_w_new * img.shape[0] // img.shape[1] + + if label_p_pred[0][int(num_col - 1)] < 0.9 and img_w_new < width_early: + img_new = np.copy(img) + num_column_is_classified = False + #elif label_p_pred[0][int(num_col - 1)] < 0.8 and img_h_new >= 8000: + elif img_h_new >= 8000: + img_new = np.copy(img) + num_column_is_classified = False + else: + img_new = resize_image(img, img_h_new, img_w_new) + num_column_is_classified = True + + return img_new, num_column_is_classified + + def early_page_for_num_of_column_classification(self,img_bin): + self.logger.debug("enter early_page_for_num_of_column_classification") + if self.input_binary: + img = np.copy(img_bin).astype(np.uint8) + else: + img = self.imread() + img = cv2.GaussianBlur(img, (5, 5), 0) + img_page_prediction = self.do_prediction(False, img, self.model_page) + + imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(imgray, 0, 255, 0) + thresh = cv2.dilate(thresh, KERNEL, iterations=3) + contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + if len(contours)>0: + cnt_size = np.array([cv2.contourArea(contours[j]) + for j in range(len(contours))]) + cnt = contours[np.argmax(cnt_size)] + box = cv2.boundingRect(cnt) + else: + box = [0, 0, img.shape[1], img.shape[0]] + cropped_page, page_coord = crop_image_inside_box(box, img) + + self.logger.debug("exit early_page_for_num_of_column_classification") + return cropped_page, page_coord + + def calculate_width_height_by_columns_1_2(self, img, num_col, width_early, label_p_pred): + self.logger.debug("enter calculate_width_height_by_columns") + if num_col == 1: + img_w_new = 1000 + else: + img_w_new = 1300 + img_h_new = img_w_new * img.shape[0] // img.shape[1] + + if label_p_pred[0][int(num_col - 1)] < 0.9 and img_w_new < width_early: + img_new = np.copy(img) + num_column_is_classified = False + #elif label_p_pred[0][int(num_col - 1)] < 0.8 and img_h_new >= 8000: + elif img_h_new >= 8000: + img_new = np.copy(img) + num_column_is_classified = False + else: + img_new = resize_image(img, img_h_new, img_w_new) + num_column_is_classified = True + + return img_new, num_column_is_classified + + def resize_and_enhance_image_with_column_classifier(self, light_version): + self.logger.debug("enter resize_and_enhance_image_with_column_classifier") + dpi = 0#self.dpi + self.logger.info("Detected 
%s DPI", dpi) + if self.input_binary: + img = self.imread() + prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5) + prediction_bin = 255 * (prediction_bin[:,:,0]==0) + prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8) + img= np.copy(prediction_bin) + img_bin = prediction_bin + else: + img = self.imread() + self.h_org, self.w_org = img.shape[:2] + img_bin = None + + width_early = img.shape[1] + t1 = time.time() + _, page_coord = self.early_page_for_num_of_column_classification(img_bin) + + self.image_page_org_size = img[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3], :] + self.page_coord = page_coord + + if self.num_col_upper and not self.num_col_lower: + num_col = self.num_col_upper + label_p_pred = [np.ones(6)] + elif self.num_col_lower and not self.num_col_upper: + num_col = self.num_col_lower + label_p_pred = [np.ones(6)] + elif not self.num_col_upper and not self.num_col_lower: + if self.input_binary: + img_in = np.copy(img) + img_in = img_in / 255.0 + img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST) + img_in = img_in.reshape(1, 448, 448, 3) + else: + img_1ch = self.imread(grayscale=True) + width_early = img_1ch.shape[1] + img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] + + img_1ch = img_1ch / 255.0 + img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST) + img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3)) + img_in[0, :, :, 0] = img_1ch[:, :] + img_in[0, :, :, 1] = img_1ch[:, :] + img_in[0, :, :, 2] = img_1ch[:, :] + + label_p_pred = self.model_classifier.predict(img_in, verbose=0) + num_col = np.argmax(label_p_pred[0]) + 1 + elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): + if self.input_binary: + img_in = np.copy(img) + img_in = img_in / 255.0 + img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST) + img_in = img_in.reshape(1, 448, 448, 3) + else: + img_1ch = self.imread(grayscale=True) + width_early = img_1ch.shape[1] + img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] + + img_1ch = img_1ch / 255.0 + img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST) + img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3)) + img_in[0, :, :, 0] = img_1ch[:, :] + img_in[0, :, :, 1] = img_1ch[:, :] + img_in[0, :, :, 2] = img_1ch[:, :] + + label_p_pred = self.model_classifier.predict(img_in, verbose=0) + num_col = np.argmax(label_p_pred[0]) + 1 + + if num_col > self.num_col_upper: + num_col = self.num_col_upper + label_p_pred = [np.ones(6)] + if num_col < self.num_col_lower: + num_col = self.num_col_lower + label_p_pred = [np.ones(6)] + else: + num_col = self.num_col_upper + label_p_pred = [np.ones(6)] + + self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5)) + + if dpi < DPI_THRESHOLD: + if light_version and num_col in (1,2): + img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2( + img, num_col, width_early, label_p_pred) + else: + img_new, num_column_is_classified = self.calculate_width_height_by_columns( + img, num_col, width_early, label_p_pred) + if light_version: + image_res = np.copy(img_new) + else: + image_res = self.predict_enhancement(img_new) + is_image_enhanced = True + + else: + num_column_is_classified = True + image_res = np.copy(img) + is_image_enhanced = False + + self.logger.debug("exit 
resize_and_enhance_image_with_column_classifier") + return is_image_enhanced, img, image_res, num_col, num_column_is_classified, img_bin + def do_prediction( + self, patches, img, model, + n_batch_inference=1, marginal_of_patch_percent=0.1, + thresholding_for_some_classes_in_light_version=False, + thresholding_for_artificial_class_in_light_version=False, thresholding_for_fl_light_version=False, threshold_art_class_textline=0.1): + + self.logger.debug("enter do_prediction") + img_height_model = model.layers[-1].output_shape[1] + img_width_model = model.layers[-1].output_shape[2] + + if not patches: + img_h_page = img.shape[0] + img_w_page = img.shape[1] + img = img / float(255.0) + img = resize_image(img, img_height_model, img_width_model) + + label_p_pred = model.predict(img[np.newaxis], verbose=0) + seg = np.argmax(label_p_pred, axis=3)[0] + + if thresholding_for_artificial_class_in_light_version: + seg_art = label_p_pred[0,:,:,2] + + seg_art[seg_art0] =1 + + skeleton_art = skeletonize(seg_art) + skeleton_art = skeleton_art*1 + + seg[skeleton_art==1]=2 + + if thresholding_for_fl_light_version: + seg_header = label_p_pred[0,:,:,2] + + seg_header[seg_header<0.2] = 0 + seg_header[seg_header>0] =1 + + seg[seg_header==1]=2 + + seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) + prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8) + return prediction_true + + if img.shape[0] < img_height_model: + img = resize_image(img, img_height_model, img.shape[1]) + if img.shape[1] < img_width_model: + img = resize_image(img, img.shape[0], img_width_model) + + self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model) + margin = int(marginal_of_patch_percent * img_height_model) + width_mid = img_width_model - 2 * margin + height_mid = img_height_model - 2 * margin + img = img / 255. 
+        #img = img.astype(np.float16)
+        img_h = img.shape[0]
+        img_w = img.shape[1]
+        prediction_true = np.zeros((img_h, img_w, 3))
+        mask_true = np.zeros((img_h, img_w))
+        nxf = img_w / float(width_mid)
+        nyf = img_h / float(height_mid)
+        nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
+        nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
+
+        list_i_s = []
+        list_j_s = []
+        list_x_u = []
+        list_x_d = []
+        list_y_u = []
+        list_y_d = []
+
+        batch_indexer = 0
+        img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3))
+        for i in range(nxf):
+            for j in range(nyf):
+                if i == 0:
+                    index_x_d = i * width_mid
+                    index_x_u = index_x_d + img_width_model
+                else:
+                    index_x_d = i * width_mid
+                    index_x_u = index_x_d + img_width_model
+                if j == 0:
+                    index_y_d = j * height_mid
+                    index_y_u = index_y_d + img_height_model
+                else:
+                    index_y_d = j * height_mid
+                    index_y_u = index_y_d + img_height_model
+                if index_x_u > img_w:
+                    index_x_u = img_w
+                    index_x_d = img_w - img_width_model
+                if index_y_u > img_h:
+                    index_y_u = img_h
+                    index_y_d = img_h - img_height_model
+
+                list_i_s.append(i)
+                list_j_s.append(j)
+                list_x_u.append(index_x_u)
+                list_x_d.append(index_x_d)
+                list_y_d.append(index_y_d)
+                list_y_u.append(index_y_u)
+
+                img_patch[batch_indexer,:,:,:] = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
+                batch_indexer += 1
+
+                if (batch_indexer == n_batch_inference or
+                    # last batch
+                    i == nxf - 1 and j == nyf - 1):
+                    self.logger.debug("predicting patches on %s", str(img_patch.shape))
+                    label_p_pred = model.predict(img_patch, verbose=0)
+                    seg = np.argmax(label_p_pred, axis=3)
+
+                    if thresholding_for_some_classes_in_light_version:
+                        seg_not_base = label_p_pred[:,:,:,4]
+                        seg_not_base[seg_not_base>0.03] = 1
+                        seg_not_base[seg_not_base<1] = 0
+
+                        seg_line = label_p_pred[:,:,:,3]
+                        seg_line[seg_line>0.1] = 1
+                        seg_line[seg_line<1] = 0
+
+                        seg_background = label_p_pred[:,:,:,0]
+                        seg_background[seg_background>0.25] = 1
+                        seg_background[seg_background<1] = 0
+
+                        seg[seg_not_base==1] = 4
+                        seg[seg_background==1] = 0
+                        seg[(seg_line==1) & (seg==0)] = 3
+                    if thresholding_for_artificial_class_in_light_version:
+                        seg_art = label_p_pred[:,:,:,2]
+
+                        seg_art[seg_art<0.2] = 0
+                        seg_art[seg_art>0] = 1
+
+                        ##seg[seg_art==1]=2
+
+                    indexer_inside_batch = 0
+                    for i_batch, j_batch in zip(list_i_s, list_j_s):
+                        seg_in = seg[indexer_inside_batch]
+
+                        if thresholding_for_artificial_class_in_light_version:
+                            seg_in_art = seg_art[indexer_inside_batch]
+
+                        index_y_u_in = list_y_u[indexer_inside_batch]
+                        index_y_d_in = list_y_d[indexer_inside_batch]
+
+                        index_x_u_in = list_x_u[indexer_inside_batch]
+                        index_x_d_in = list_x_d[indexer_inside_batch]
+
+                        if i_batch == 0 and j_batch == 0:
+                            prediction_true[index_y_d_in + 0:index_y_u_in - margin,
+                                            index_x_d_in + 0:index_x_u_in - margin] = \
+                                                seg_in[0:-margin or None,
+                                                       0:-margin or None,
+                                                       np.newaxis]
+                            if thresholding_for_artificial_class_in_light_version:
+                                prediction_true[index_y_d_in + 0:index_y_u_in - margin,
+                                                index_x_d_in + 0:index_x_u_in - margin, 1] = \
+                                                    seg_in_art[0:-margin or None,
+                                                               0:-margin or None]
+
+                        elif i_batch == nxf - 1 and j_batch == nyf - 1:
+                            prediction_true[index_y_d_in + margin:index_y_u_in - 0,
+                                            index_x_d_in + margin:index_x_u_in - 0] = \
+                                                seg_in[margin:,
+                                                       margin:,
+                                                       np.newaxis]
+                            if thresholding_for_artificial_class_in_light_version:
+                                prediction_true[index_y_d_in + margin:index_y_u_in - 0,
+                                                index_x_d_in + margin:index_x_u_in - 0, 1] = \
+                                                    seg_in_art[margin:,
+                                                               margin:]
+
+                        elif i_batch == 0 and j_batch == nyf - 1:
+                            prediction_true[index_y_d_in + margin:index_y_u_in - 0,
index_x_d_in + 0:index_x_u_in - margin] = \ + seg_in[margin:, + 0:-margin or None, + np.newaxis] + if thresholding_for_artificial_class_in_light_version: + prediction_true[index_y_d_in + margin:index_y_u_in - 0, + index_x_d_in + 0:index_x_u_in - margin, 1] = \ + seg_in_art[margin:, + 0:-margin or None] + + elif i_batch == nxf - 1 and j_batch == 0: + prediction_true[index_y_d_in + 0:index_y_u_in - margin, + index_x_d_in + margin:index_x_u_in - 0] = \ + seg_in[0:-margin or None, + margin:, + np.newaxis] + if thresholding_for_artificial_class_in_light_version: + prediction_true[index_y_d_in + 0:index_y_u_in - margin, + index_x_d_in + margin:index_x_u_in - 0, 1] = \ + seg_in_art[0:-margin or None, + margin:] + + elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1: + prediction_true[index_y_d_in + margin:index_y_u_in - margin, + index_x_d_in + 0:index_x_u_in - margin] = \ + seg_in[margin:-margin or None, + 0:-margin or None, + np.newaxis] + if thresholding_for_artificial_class_in_light_version: + prediction_true[index_y_d_in + margin:index_y_u_in - margin, + index_x_d_in + 0:index_x_u_in - margin, 1] = \ + seg_in_art[margin:-margin or None, + 0:-margin or None] + + elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1: + prediction_true[index_y_d_in + margin:index_y_u_in - margin, + index_x_d_in + margin:index_x_u_in - 0] = \ + seg_in[margin:-margin or None, + margin:, + np.newaxis] + if thresholding_for_artificial_class_in_light_version: + prediction_true[index_y_d_in + margin:index_y_u_in - margin, + index_x_d_in + margin:index_x_u_in - 0, 1] = \ + seg_in_art[margin:-margin or None, + margin:] + + elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0: + prediction_true[index_y_d_in + 0:index_y_u_in - margin, + index_x_d_in + margin:index_x_u_in - margin] = \ + seg_in[0:-margin or None, + margin:-margin or None, + np.newaxis] + if thresholding_for_artificial_class_in_light_version: + prediction_true[index_y_d_in + 0:index_y_u_in - margin, + index_x_d_in + margin:index_x_u_in - margin, 1] = \ + seg_in_art[0:-margin or None, + margin:-margin or None] + + elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1: + prediction_true[index_y_d_in + margin:index_y_u_in - 0, + index_x_d_in + margin:index_x_u_in - margin] = \ + seg_in[margin:, + margin:-margin or None, + np.newaxis] + if thresholding_for_artificial_class_in_light_version: + prediction_true[index_y_d_in + margin:index_y_u_in - 0, + index_x_d_in + margin:index_x_u_in - margin, 1] = \ + seg_in_art[margin:, + margin:-margin or None] + + else: + prediction_true[index_y_d_in + margin:index_y_u_in - margin, + index_x_d_in + margin:index_x_u_in - margin] = \ + seg_in[margin:-margin or None, + margin:-margin or None, + np.newaxis] + if thresholding_for_artificial_class_in_light_version: + prediction_true[index_y_d_in + margin:index_y_u_in - margin, + index_x_d_in + margin:index_x_u_in - margin, 1] = \ + seg_in_art[margin:-margin or None, + margin:-margin or None] + indexer_inside_batch += 1 + + + list_i_s = [] + list_j_s = [] + list_x_u = [] + list_x_d = [] + list_y_u = [] + list_y_d = [] + + batch_indexer = 0 + img_patch[:] = 0 + + prediction_true = prediction_true.astype(np.uint8) + + if thresholding_for_artificial_class_in_light_version: + kernel_min = np.ones((3, 3), np.uint8) + prediction_true[:,:,0][prediction_true[:,:,0]==2] = 0 + + skeleton_art = skeletonize(prediction_true[:,:,1]) + skeleton_art = skeleton_art*1 + + skeleton_art = skeleton_art.astype('uint8') + + skeleton_art = cv2.dilate(skeleton_art, 
kernel_min, iterations=1)
+
+            prediction_true[:,:,0][skeleton_art==1] = 2
+        #del model
+        gc.collect()
+        return prediction_true
+
+    def run_enhancement(self, light_version):
+        t_in = time.time()
+        self.logger.info("Resizing and enhancing image...")
+        is_image_enhanced, img_org, img_res, num_col_classifier, num_column_is_classified, img_bin = \
+            self.resize_and_enhance_image_with_column_classifier(light_version)
+
+        self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ')
+        return img_res, is_image_enhanced, num_col_classifier, num_column_is_classified
+
+    def run_single(self):
+        t0 = time.time()
+        img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = \
+            self.run_enhancement(light_version=False)
+        return img_res
+
+    def run(self, image_filename : Optional[str] = None, dir_in : Optional[str] = None, overwrite : bool = False):
+        """
+        Enhance a single image file or all images in a directory, writing each
+        result as a PNG file into the output directory.
+        """
+        self.logger.debug("enter run")
+        t0_tot = time.time()
+
+        if dir_in:
+            self.ls_imgs = os.listdir(dir_in)
+        elif image_filename:
+            self.ls_imgs = [image_filename]
+        else:
+            raise ValueError("run requires either a single image filename or a directory")
+
+        for img_filename in self.ls_imgs:
+            self.logger.info(img_filename)
+            t0 = time.time()
+
+            self.reset_file_name_dir(os.path.join(dir_in or "", img_filename))
+
+            if os.path.exists(self.output_filename):
+                if overwrite:
+                    self.logger.warning("will overwrite existing output file '%s'", self.output_filename)
+                else:
+                    self.logger.warning("will skip input for existing output file '%s'", self.output_filename)
+                    continue
+
+            image_enhanced = self.run_single()
+            img_enhanced_org_scale = resize_image(image_enhanced, self.h_org, self.w_org)
+
+            cv2.imwrite(self.output_filename, img_enhanced_org_scale)
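
For review, a minimal usage sketch (illustrative only; it assumes the console
script is installed as `eynollah`, so that the new subcommand is invoked as
`eynollah enhancement -i <image> -o <out_dir> -m <models_dir>`, and that
<models_dir> contains the three model subdirectories referenced in
`Enhancer.__init__`):

    from eynollah.image_enhancer import Enhancer

    # "/path/to/models" must contain eynollah-enhancement_20210425,
    # eynollah-column-classifier_20210425 and eynollah-page-extraction_20210425;
    # the output directory must already exist (see the click.Path declaration).
    enhancer = Enhancer("/path/to/models", dir_out="out")

    # Writes the enhanced image to out/page.png, resized back to the original
    # width and height of the input.
    enhancer.run(image_filename="page.tif", overwrite=True)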