mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-09-01 13:29:58 +02:00
image enhancer is integrated
This commit is contained in:
parent
928a548b70
commit
cc36694dfd
3 changed files with 830 additions and 229 deletions
|
@ -3,6 +3,7 @@ import click
|
||||||
from ocrd_utils import initLogging, getLevelName, getLogger
|
from ocrd_utils import initLogging, getLevelName, getLogger
|
||||||
from eynollah.eynollah import Eynollah, Eynollah_ocr
|
from eynollah.eynollah import Eynollah, Eynollah_ocr
|
||||||
from eynollah.sbb_binarize import SbbBinarizer
|
from eynollah.sbb_binarize import SbbBinarizer
|
||||||
|
from eynollah.image_enhancer import Enhancer
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
def main():
|
def main():
|
||||||
|
@ -70,6 +71,74 @@ def binarization(patches, model_dir, input_image, output_image, dir_in, dir_out)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@main.command()
@click.option(
    "--image",
    "-i",
    help="image filename",
    type=click.Path(exists=True, dir_okay=False),
)
@click.option(
    "--out",
    "-o",
    help="directory to write output xml data",
    type=click.Path(exists=True, file_okay=False),
    required=True,
)
@click.option(
    "--overwrite",
    "-O",
    help="overwrite (instead of skipping) if output xml exists",
    is_flag=True,
)
@click.option(
    "--dir_in",
    "-di",
    help="directory of images",
    type=click.Path(exists=True, file_okay=False),
)
@click.option(
    "--model",
    "-m",
    help="directory of models",
    type=click.Path(exists=True, file_okay=False),
    required=True,
)
@click.option(
    "--num_col_upper",
    "-ncu",
    help="upper limit of columns in document image",
)
@click.option(
    "--num_col_lower",
    "-ncl",
    help="lower limit of columns in document image",
)
@click.option(
    "--log_level",
    "-l",
    type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
    help="Override log level globally to this",
)
def enhancement(image, out, overwrite, dir_in, model, num_col_upper, num_col_lower, log_level):
    """Enhance a single image (-i) or the images of a directory (-di)."""
    initLogging()
    logger = getLogger('enhancement')
    if log_level:
        logger.setLevel(getLevelName(log_level))
    # Raise a usage error instead of asserting: assertions are stripped when
    # Python runs with -O, which would let the invalid call through.
    if not (image or dir_in):
        raise click.UsageError("Either a single image -i or a dir_in -di is required")
    enhancer_object = Enhancer(
        model,
        logger=logger,
        dir_out=out,
        num_col_upper=num_col_upper,
        num_col_lower=num_col_lower,
    )
    # A directory takes precedence over a single image when both are given.
    if dir_in:
        enhancer_object.run(dir_in=dir_in, overwrite=overwrite)
    else:
        enhancer_object.run(image_filename=image, overwrite=overwrite)
|
||||||
|
|
||||||
@main.command()
|
@main.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
|
|
|
@ -3612,25 +3612,12 @@ class Eynollah:
|
||||||
|
|
||||||
inference_bs = 3
|
inference_bs = 3
|
||||||
|
|
||||||
cv2.imwrite('textregions.png', text_regions_p*50)
|
|
||||||
cv2.imwrite('sep.png', (text_regions_p[:,:]==6)*255)
|
|
||||||
|
|
||||||
ver_kernel = np.ones((5, 1), dtype=np.uint8)
|
ver_kernel = np.ones((5, 1), dtype=np.uint8)
|
||||||
hor_kernel = np.ones((1, 5), dtype=np.uint8)
|
hor_kernel = np.ones((1, 5), dtype=np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#separators = (text_regions_p[:,:]==6)*1
|
|
||||||
#text_regions_p[text_regions_p[:,:]==6] = 0
|
|
||||||
#separators = separators.astype('uint8')
|
|
||||||
|
|
||||||
#separators = cv2.erode(separators , hor_kernel, iterations=1)
|
|
||||||
#text_regions_p[separators[:,:]==1] = 6
|
|
||||||
|
|
||||||
#cv2.imwrite('sep_new.png', (text_regions_p[:,:]==6)*255)
|
|
||||||
|
|
||||||
min_cont_size_to_be_dilated = 10
|
min_cont_size_to_be_dilated = 10
|
||||||
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
|
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
|
||||||
cx_conts, cy_conts, x_min_conts, x_max_conts, y_min_conts, y_max_conts, _ = find_new_features_of_contours(contours_only_text_parent)
|
cx_conts, cy_conts, x_min_conts, x_max_conts, y_min_conts, y_max_conts, _ = find_new_features_of_contours(contours_only_text_parent)
|
||||||
args_cont_located = np.array(range(len(contours_only_text_parent)))
|
args_cont_located = np.array(range(len(contours_only_text_parent)))
|
||||||
|
|
||||||
|
@ -3672,7 +3659,6 @@ class Eynollah:
|
||||||
text_regions_p_textregions_dilated = cv2.dilate(text_regions_p_textregions_dilated , ver_kernel, iterations=5)
|
text_regions_p_textregions_dilated = cv2.dilate(text_regions_p_textregions_dilated , ver_kernel, iterations=5)
|
||||||
text_regions_p_textregions_dilated[text_regions_p[:,:]>1] = 0
|
text_regions_p_textregions_dilated[text_regions_p[:,:]>1] = 0
|
||||||
|
|
||||||
cv2.imwrite('text_regions_p_textregions_dilated.png', text_regions_p_textregions_dilated*255)
|
|
||||||
|
|
||||||
contours_only_dilated, hir_on_text_dilated = return_contours_of_image(text_regions_p_textregions_dilated)
|
contours_only_dilated, hir_on_text_dilated = return_contours_of_image(text_regions_p_textregions_dilated)
|
||||||
contours_only_dilated = return_parent_contours(contours_only_dilated, hir_on_text_dilated)
|
contours_only_dilated = return_parent_contours(contours_only_dilated, hir_on_text_dilated)
|
||||||
|
@ -3723,21 +3709,20 @@ class Eynollah:
|
||||||
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,
|
img_header_and_sep[int(y_max_main[j]):int(y_max_main[j])+12,
|
||||||
int(x_min_main[j]):int(x_max_main[j])] = 1
|
int(x_min_main[j]):int(x_max_main[j])] = 1
|
||||||
co_text_all_org = contours_only_text_parent + contours_only_text_parent_h
|
co_text_all_org = contours_only_text_parent + contours_only_text_parent_h
|
||||||
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
|
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
|
||||||
co_text_all = contours_only_dilated + contours_only_text_parent_h
|
co_text_all = contours_only_dilated + contours_only_text_parent_h
|
||||||
else:
|
else:
|
||||||
co_text_all = contours_only_text_parent + contours_only_text_parent_h
|
co_text_all = contours_only_text_parent + contours_only_text_parent_h
|
||||||
else:
|
else:
|
||||||
co_text_all_org = contours_only_text_parent
|
co_text_all_org = contours_only_text_parent
|
||||||
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
|
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
|
||||||
co_text_all = contours_only_dilated
|
co_text_all = contours_only_dilated
|
||||||
else:
|
else:
|
||||||
co_text_all = contours_only_text_parent
|
co_text_all = contours_only_text_parent
|
||||||
|
|
||||||
if not len(co_text_all):
|
if not len(co_text_all):
|
||||||
return [], []
|
return [], []
|
||||||
print(len(co_text_all), "co_text_all")
|
|
||||||
print(len(co_text_all_org), "co_text_all_org")
|
|
||||||
labels_con = np.zeros((int(y_len /6.), int(x_len/6.), len(co_text_all)), dtype=bool)
|
labels_con = np.zeros((int(y_len /6.), int(x_len/6.), len(co_text_all)), dtype=bool)
|
||||||
co_text_all = [(i/6).astype(int) for i in co_text_all]
|
co_text_all = [(i/6).astype(int) for i in co_text_all]
|
||||||
for i in range(len(co_text_all)):
|
for i in range(len(co_text_all)):
|
||||||
|
@ -3805,7 +3790,7 @@ class Eynollah:
|
||||||
|
|
||||||
ordered = [i[0] for i in ordered]
|
ordered = [i[0] for i in ordered]
|
||||||
|
|
||||||
if len(contours_only_text_parent)>min_cont_size_to_be_dilated:
|
if len(contours_only_text_parent)>min_cont_size_to_be_dilated and self.light_version:
|
||||||
org_contours_indexes = []
|
org_contours_indexes = []
|
||||||
for ind in range(len(ordered)):
|
for ind in range(len(ordered)):
|
||||||
region_with_curr_order = ordered[ind]
|
region_with_curr_order = ordered[ind]
|
||||||
|
@ -3823,215 +3808,6 @@ class Eynollah:
|
||||||
else:
|
else:
|
||||||
region_ids = ['region_%04d' % i for i in range(len(co_text_all_org))]
|
region_ids = ['region_%04d' % i for i in range(len(co_text_all_org))]
|
||||||
return ordered, region_ids
|
return ordered, region_ids
|
||||||
|
|
||||||
|
|
||||||
####def return_start_and_end_of_common_text_of_textline_ocr(self, textline_image, ind_tot):
|
|
||||||
####width = np.shape(textline_image)[1]
|
|
||||||
####height = np.shape(textline_image)[0]
|
|
||||||
####common_window = int(0.2*width)
|
|
||||||
|
|
||||||
####width1 = int ( width/2. - common_window )
|
|
||||||
####width2 = int ( width/2. + common_window )
|
|
||||||
|
|
||||||
####img_sum = np.sum(textline_image[:,:,0], axis=0)
|
|
||||||
####sum_smoothed = gaussian_filter1d(img_sum, 3)
|
|
||||||
|
|
||||||
####peaks_real, _ = find_peaks(sum_smoothed, height=0)
|
|
||||||
####if len(peaks_real)>70:
|
|
||||||
|
|
||||||
####peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
|
|
||||||
|
|
||||||
####arg_sort = np.argsort(sum_smoothed[peaks_real])
|
|
||||||
####arg_sort4 =arg_sort[::-1][:4]
|
|
||||||
####peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
|
|
||||||
####argsort_sorted = np.argsort(peaks_sort_4)
|
|
||||||
|
|
||||||
####first_4_sorted = peaks_sort_4[argsort_sorted]
|
|
||||||
####y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
|
|
||||||
#####print(first_4_sorted,'first_4_sorted')
|
|
||||||
|
|
||||||
####arg_sortnew = np.argsort(y_4_sorted)
|
|
||||||
####peaks_final =np.sort( first_4_sorted[arg_sortnew][2:] )
|
|
||||||
|
|
||||||
#####plt.figure(ind_tot)
|
|
||||||
#####plt.imshow(textline_image)
|
|
||||||
#####plt.plot([peaks_final[0], peaks_final[0]], [0, height-1])
|
|
||||||
#####plt.plot([peaks_final[1], peaks_final[1]], [0, height-1])
|
|
||||||
#####plt.savefig('./'+str(ind_tot)+'.png')
|
|
||||||
|
|
||||||
####return peaks_final[0], peaks_final[1]
|
|
||||||
####else:
|
|
||||||
####pass
|
|
||||||
|
|
||||||
##def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image, ind_tot):
|
|
||||||
##width = np.shape(textline_image)[1]
|
|
||||||
##height = np.shape(textline_image)[0]
|
|
||||||
##common_window = int(0.06*width)
|
|
||||||
|
|
||||||
##width1 = int ( width/2. - common_window )
|
|
||||||
##width2 = int ( width/2. + common_window )
|
|
||||||
|
|
||||||
##img_sum = np.sum(textline_image[:,:,0], axis=0)
|
|
||||||
##sum_smoothed = gaussian_filter1d(img_sum, 3)
|
|
||||||
|
|
||||||
##peaks_real, _ = find_peaks(sum_smoothed, height=0)
|
|
||||||
##if len(peaks_real)>70:
|
|
||||||
###print(len(peaks_real), 'len(peaks_real)')
|
|
||||||
|
|
||||||
##peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
|
|
||||||
|
|
||||||
##arg_max = np.argmax(sum_smoothed[peaks_real])
|
|
||||||
##peaks_final = peaks_real[arg_max]
|
|
||||||
|
|
||||||
###plt.figure(ind_tot)
|
|
||||||
###plt.imshow(textline_image)
|
|
||||||
###plt.plot([peaks_final, peaks_final], [0, height-1])
|
|
||||||
####plt.plot([peaks_final[1], peaks_final[1]], [0, height-1])
|
|
||||||
###plt.savefig('./'+str(ind_tot)+'.png')
|
|
||||||
|
|
||||||
##return peaks_final
|
|
||||||
##else:
|
|
||||||
##return None
|
|
||||||
|
|
||||||
###def return_start_and_end_of_common_text_of_textline_ocr_new_splitted(
|
|
||||||
###self, peaks_real, sum_smoothed, start_split, end_split):
|
|
||||||
|
|
||||||
###peaks_real = peaks_real[(peaks_real<end_split) & (peaks_real>start_split)]
|
|
||||||
|
|
||||||
###arg_sort = np.argsort(sum_smoothed[peaks_real])
|
|
||||||
###arg_sort4 =arg_sort[::-1][:4]
|
|
||||||
###peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
|
|
||||||
###argsort_sorted = np.argsort(peaks_sort_4)
|
|
||||||
|
|
||||||
###first_4_sorted = peaks_sort_4[argsort_sorted]
|
|
||||||
###y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
|
|
||||||
####print(first_4_sorted,'first_4_sorted')
|
|
||||||
|
|
||||||
###arg_sortnew = np.argsort(y_4_sorted)
|
|
||||||
###peaks_final =np.sort( first_4_sorted[arg_sortnew][3:] )
|
|
||||||
###return peaks_final[0]
|
|
||||||
|
|
||||||
###def return_start_and_end_of_common_text_of_textline_ocr_new(self, textline_image, ind_tot):
|
|
||||||
###width = np.shape(textline_image)[1]
|
|
||||||
###height = np.shape(textline_image)[0]
|
|
||||||
###common_window = int(0.15*width)
|
|
||||||
|
|
||||||
###width1 = int ( width/2. - common_window )
|
|
||||||
###width2 = int ( width/2. + common_window )
|
|
||||||
###mid = int(width/2.)
|
|
||||||
|
|
||||||
###img_sum = np.sum(textline_image[:,:,0], axis=0)
|
|
||||||
###sum_smoothed = gaussian_filter1d(img_sum, 3)
|
|
||||||
|
|
||||||
###peaks_real, _ = find_peaks(sum_smoothed, height=0)
|
|
||||||
###if len(peaks_real)>70:
|
|
||||||
###peak_start = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(
|
|
||||||
###peaks_real, sum_smoothed, width1, mid+2)
|
|
||||||
###peak_end = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(
|
|
||||||
###peaks_real, sum_smoothed, mid-2, width2)
|
|
||||||
|
|
||||||
####plt.figure(ind_tot)
|
|
||||||
####plt.imshow(textline_image)
|
|
||||||
####plt.plot([peak_start, peak_start], [0, height-1])
|
|
||||||
####plt.plot([peak_end, peak_end], [0, height-1])
|
|
||||||
####plt.savefig('./'+str(ind_tot)+'.png')
|
|
||||||
|
|
||||||
###return peak_start, peak_end
|
|
||||||
###else:
|
|
||||||
###pass
|
|
||||||
|
|
||||||
##def return_ocr_of_textline_without_common_section(
|
|
||||||
##self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot):
|
|
||||||
|
|
||||||
##if h2w_ratio > 0.05:
|
|
||||||
##pixel_values = processor(textline_image, return_tensors="pt").pixel_values
|
|
||||||
##generated_ids = model_ocr.generate(pixel_values.to(device))
|
|
||||||
##generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
||||||
##else:
|
|
||||||
###width = np.shape(textline_image)[1]
|
|
||||||
###height = np.shape(textline_image)[0]
|
|
||||||
###common_window = int(0.3*width)
|
|
||||||
###width1 = int ( width/2. - common_window )
|
|
||||||
###width2 = int ( width/2. + common_window )
|
|
||||||
|
|
||||||
##split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section(
|
|
||||||
##textline_image, ind_tot)
|
|
||||||
##if split_point:
|
|
||||||
##image1 = textline_image[:, :split_point,:]# image.crop((0, 0, width2, height))
|
|
||||||
##image2 = textline_image[:, split_point:,:]#image.crop((width1, 0, width, height))
|
|
||||||
|
|
||||||
###pixel_values1 = processor(image1, return_tensors="pt").pixel_values
|
|
||||||
###pixel_values2 = processor(image2, return_tensors="pt").pixel_values
|
|
||||||
|
|
||||||
##pixel_values_merged = processor([image1,image2], return_tensors="pt").pixel_values
|
|
||||||
##generated_ids_merged = model_ocr.generate(pixel_values_merged.to(device))
|
|
||||||
##generated_text_merged = processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
|
|
||||||
|
|
||||||
###print(generated_text_merged,'generated_text_merged')
|
|
||||||
|
|
||||||
###generated_ids1 = model_ocr.generate(pixel_values1.to(device))
|
|
||||||
###generated_ids2 = model_ocr.generate(pixel_values2.to(device))
|
|
||||||
|
|
||||||
###generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
|
|
||||||
###generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
|
|
||||||
|
|
||||||
###generated_text = generated_text1 + ' ' + generated_text2
|
|
||||||
##generated_text = generated_text_merged[0] + ' ' + generated_text_merged[1]
|
|
||||||
|
|
||||||
###print(generated_text1,'generated_text1')
|
|
||||||
###print(generated_text2, 'generated_text2')
|
|
||||||
###print('########################################')
|
|
||||||
##else:
|
|
||||||
##pixel_values = processor(textline_image, return_tensors="pt").pixel_values
|
|
||||||
##generated_ids = model_ocr.generate(pixel_values.to(device))
|
|
||||||
##generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
||||||
|
|
||||||
###print(generated_text,'generated_text')
|
|
||||||
###print('########################################')
|
|
||||||
##return generated_text
|
|
||||||
|
|
||||||
###def return_ocr_of_textline(
|
|
||||||
###self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot):
|
|
||||||
|
|
||||||
###if h2w_ratio > 0.05:
|
|
||||||
###pixel_values = processor(textline_image, return_tensors="pt").pixel_values
|
|
||||||
###generated_ids = model_ocr.generate(pixel_values.to(device))
|
|
||||||
###generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
||||||
###else:
|
|
||||||
####width = np.shape(textline_image)[1]
|
|
||||||
####height = np.shape(textline_image)[0]
|
|
||||||
####common_window = int(0.3*width)
|
|
||||||
####width1 = int ( width/2. - common_window )
|
|
||||||
####width2 = int ( width/2. + common_window )
|
|
||||||
|
|
||||||
###try:
|
|
||||||
###width1, width2 = self.return_start_and_end_of_common_text_of_textline_ocr_new(textline_image, ind_tot)
|
|
||||||
|
|
||||||
###image1 = textline_image[:, :width2,:]# image.crop((0, 0, width2, height))
|
|
||||||
###image2 = textline_image[:, width1:,:]#image.crop((width1, 0, width, height))
|
|
||||||
|
|
||||||
###pixel_values1 = processor(image1, return_tensors="pt").pixel_values
|
|
||||||
###pixel_values2 = processor(image2, return_tensors="pt").pixel_values
|
|
||||||
|
|
||||||
###generated_ids1 = model_ocr.generate(pixel_values1.to(device))
|
|
||||||
###generated_ids2 = model_ocr.generate(pixel_values2.to(device))
|
|
||||||
|
|
||||||
###generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
|
|
||||||
###generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
|
|
||||||
####print(generated_text1,'generated_text1')
|
|
||||||
####print(generated_text2, 'generated_text2')
|
|
||||||
####print('########################################')
|
|
||||||
|
|
||||||
###match = sq(None, generated_text1, generated_text2).find_longest_match(
|
|
||||||
###0, len(generated_text1), 0, len(generated_text2))
|
|
||||||
###generated_text = generated_text1 + generated_text2[match.b+match.size:]
|
|
||||||
###except:
|
|
||||||
###pixel_values = processor(textline_image, return_tensors="pt").pixel_values
|
|
||||||
###generated_ids = model_ocr.generate(pixel_values.to(device))
|
|
||||||
###generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
||||||
|
|
||||||
###return generated_text
|
|
||||||
|
|
||||||
|
|
||||||
def return_list_of_contours_with_desired_order(self, ls_cons, sorted_indexes):
    """Return the items of ``ls_cons`` picked in the order given by ``sorted_indexes``."""
    reordered = []
    for source_index in sorted_indexes:
        reordered.append(ls_cons[source_index])
    return reordered
|
||||||
|
|
756
src/eynollah/image_enhancer.py
Normal file
756
src/eynollah/image_enhancer.py
Normal file
|
@ -0,0 +1,756 @@
|
||||||
|
"""
|
||||||
|
Image enhancer. The output can be written at the same scale as the input or at a newly predicted scale.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from logging import Logger
|
||||||
|
from difflib import SequenceMatcher as sq
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
import atexit
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
import gc
|
||||||
|
import copy
|
||||||
|
from loky import ProcessPoolExecutor
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from ocrd import OcrdPage
|
||||||
|
from ocrd_utils import getLogger, tf_disable_interactive_logs
|
||||||
|
import statistics
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
from .utils.resize import resize_image
|
||||||
|
from .utils import (
|
||||||
|
crop_image_inside_box
|
||||||
|
)
|
||||||
|
|
||||||
|
DPI_THRESHOLD = 298
|
||||||
|
KERNEL = np.ones((5, 5), np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
class Enhancer:
|
||||||
|
def __init__(
    self,
    dir_models : str,
    dir_out : Optional[str] = None,
    num_col_upper : Optional[int] = None,
    num_col_lower : Optional[int] = None,
    logger : Optional[Logger] = None,
):
    """Store the options, start the worker pool and load the three models.

    Args:
        dir_models: directory holding the enhancement, column-classifier and
            page-extraction model sub-directories named below.
        dir_out: output directory, stored as ``self.dir_out``.
        num_col_upper: optional column bound; converted with ``int()`` when truthy.
        num_col_lower: optional column bound; converted with ``int()`` when truthy.
        logger: logger to use; defaults to ``getLogger('enhancement')``.
    """
    # Imported here because the module only imports ``tensorflow.keras.models``
    # at the top; without this the name ``tf`` used below is undefined.
    import tensorflow as tf

    self.dir_out = dir_out
    self.input_binary = False
    self.light_version = False
    # Falsy values (None, 0, '') are stored unchanged.
    self.num_col_upper = int(num_col_upper) if num_col_upper else num_col_upper
    self.num_col_lower = int(num_col_lower) if num_col_lower else num_col_lower

    self.logger = logger if logger else getLogger('enhancement')
    # for parallelization of CPU-intensive tasks:
    self.executor = ProcessPoolExecutor(max_workers=cpu_count(), timeout=1200)
    atexit.register(self.executor.shutdown)
    self.dir_models = dir_models
    self.model_dir_of_enhancement = os.path.join(dir_models, "eynollah-enhancement_20210425")
    self.model_dir_of_col_classifier = os.path.join(dir_models, "eynollah-column-classifier_20210425")
    self.model_page_dir = os.path.join(dir_models, "eynollah-page-extraction_20210425")

    # Best effort: a failure to configure the GPUs must not abort start-up,
    # but the real cause is logged instead of being hidden.
    try:
        for device in tf.config.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(device, True)
    except Exception as err:
        self.logger.warning("could not enable GPU memory growth: %s", err)

    self.model_page = self.our_load_model(self.model_page_dir)
    self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
    self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
|
||||||
|
|
||||||
|
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
    """Load the input image and cache four variants of it in ``self._imgs``.

    The image is read with ``cv2.imread`` when ``image_filename`` is truthy,
    otherwise ``image_pil`` is converted with ``pil2cv``. The cache holds the
    image, its grayscale conversion and a ``uint8`` cast of each.
    ``self.dpi`` is set to 100 in light mode and 0 otherwise, unless an
    explicit ``dpi`` is passed.
    """
    # NOTE(review): pil2cv is not among the imports visible at the top of
    # this file - confirm it is in scope before relying on the image_pil path.
    if image_filename:
        color = cv2.imread(image_filename)
    else:
        color = pil2cv(image_pil)
    cache = {'img': color}
    self.dpi = 100 if self.light_version else 0
    cache['img_grayscale'] = cv2.cvtColor(color, cv2.COLOR_BGR2GRAY)
    for key in ('img', 'img_grayscale'):
        cache[key + '_uint8'] = cache[key].astype(np.uint8)
    self._imgs = cache
    # An explicit dpi overrides the mode-based default chosen above.
    if dpi is not None:
        self.dpi = dpi
|
||||||
|
|
||||||
|
def reset_file_name_dir(self, image_filename):
    """Cache ``image_filename`` and derive the matching output path.

    The output path is ``<dir_out>/<input stem>.png`` and is stored in
    ``self.output_filename``.
    """
    self.cache_images(image_filename=image_filename)
    stem = Path(image_filename).stem
    self.output_filename = os.path.join(self.dir_out, stem + '.png')
|
||||||
|
|
||||||
|
def imread(self, grayscale=False, uint8=True):
    """Return a copy of one cached image variant from ``self._imgs``.

    ``grayscale`` selects the grayscale variant and ``uint8`` the uint8 cast;
    the two flags combine into the cache key.
    """
    key_parts = ['img']
    if grayscale:
        key_parts.append('_grayscale')
    if uint8:
        key_parts.append('_uint8')
    cached = self._imgs[''.join(key_parts)]
    return cached.copy()
|
||||||
|
|
||||||
|
def isNaN(self, num):
    """Return the result of ``num != num``, which is true for a float NaN."""
    differs_from_itself = num != num
    return differs_from_itself
|
||||||
|
|
||||||
|
@staticmethod
def our_load_model(model_file):
    """Load a Keras model from ``model_file`` without compiling it.

    When ``model_file`` ends in ``.h5`` and a path with the same name minus
    that suffix exists, that path is loaded instead. If the plain load
    fails, the load is retried with the ``PatchEncoder`` and ``Patches``
    custom objects.

    Returns:
        The loaded model.
    """
    if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
        # prefer SavedModel over HDF5 format if it exists
        model_file = model_file[:-3]
    try:
        model = load_model(model_file, compile=False)
    except Exception:
        # Only ordinary errors trigger the retry; KeyboardInterrupt and
        # SystemExit propagate instead of starting a second load.
        # NOTE(review): PatchEncoder and Patches are not among the imports
        # visible at the top of this file - confirm they are in scope,
        # otherwise this fallback raises NameError.
        model = load_model(model_file, compile=False, custom_objects={
            "PatchEncoder": PatchEncoder, "Patches": Patches})
    return model
|
||||||
|
|
||||||
|
def predict_enhancement(self, img):
    """Run the enhancement model over ``img`` tile by tile and stitch the output.

    The tile size is taken from the output shape of the model's last layer.
    An image smaller than one tile along an axis is first enlarged to the
    tile size on that axis. Tiles overlap by a margin of 10% of the tile
    width; the margin of each prediction is dropped on every side that is
    not on the image border.

    Args:
        img: image array indexed as (row, column, channel).

    Returns:
        Integer array of shape (height, width, 3) holding the stitched
        predictions scaled by 255.
    """
    self.logger.debug("enter predict_enhancement")

    output_shape = self.model_enhancement.layers[-1].output_shape
    img_height_model = output_shape[1]
    img_width_model = output_shape[2]
    # cv2.resize takes the target size as (width, height): keep the width and
    # raise the height to the model height, then the same for the other axis.
    if img.shape[0] < img_height_model:
        img = cv2.resize(img, (img.shape[1], img_height_model), interpolation=cv2.INTER_NEAREST)
    if img.shape[1] < img_width_model:
        img = cv2.resize(img, (img_width_model, img.shape[0]), interpolation=cv2.INTER_NEAREST)
    margin = int(0.1 * img_width_model)
    width_mid = img_width_model - 2 * margin
    height_mid = img_height_model - 2 * margin
    img = img / 255.
    img_h, img_w = img.shape[:2]

    prediction_true = np.zeros((img_h, img_w, 3))
    # Number of tiles per axis when stepping by the tile size minus both margins.
    nxf = math.ceil(img_w / float(width_mid))
    nyf = math.ceil(img_h / float(height_mid))

    for i in range(nxf):
        for j in range(nyf):
            index_x_d = i * width_mid
            index_x_u = index_x_d + img_width_model
            index_y_d = j * height_mid
            index_y_u = index_y_d + img_height_model

            # Shift a tile that overshoots the image back inside it.
            if index_x_u > img_w:
                index_x_u = img_w
                index_x_d = img_w - img_width_model
            if index_y_u > img_h:
                index_y_u = img_h
                index_y_d = img_h - img_height_model

            img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
            label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
            seg = label_p_pred[0, :, :, :] * 255

            # Keep the full prediction on sides that touch the image border
            # and drop the overlap margin on all other sides.
            top = 0 if j == 0 else margin
            bottom = 0 if j == nyf - 1 else margin
            left = 0 if i == 0 else margin
            right = 0 if i == nxf - 1 else margin
            seg_h, seg_w = seg.shape[:2]
            prediction_true[index_y_d + top:index_y_u - bottom,
                            index_x_d + left:index_x_u - right] = \
                seg[top:seg_h - bottom,
                    left:seg_w - right]

    prediction_true = prediction_true.astype(int)
    return prediction_true
|
||||||
|
|
||||||
|
def calculate_width_height_by_columns(self, img, num_col, width_early, label_p_pred):
    """Choose a working width from the column count and resize ``img`` to it.

    For 1 to 5 columns the original width is kept while it lies inside a
    per-column range; outside that range a fixed width is used. For 6
    columns only narrow images are widened. The height follows the aspect
    ratio of ``img``.

    Returns:
        ``(img_new, num_column_is_classified)``. An unresized copy and
        ``False`` are returned when the classifier score for ``num_col`` is
        below 0.9 and the new width would shrink the image, or when the new
        height would reach 8000; otherwise the resized image and ``True``.
    """
    self.logger.debug("enter calculate_width_height_by_columns")

    # (column count, lower width bound, upper width bound, width used outside the bounds)
    width_rules = (
        (1, 1100, 2500, 2000),
        (2, 2000, 3500, 2400),
        (3, 2000, 4000, 3000),
        (4, 2500, 5000, 4000),
        (5, 3700, 7000, 5000),
    )
    img_w_new = width_early
    for cols, lower, upper, fixed_width in width_rules:
        if num_col == cols:
            if width_early < lower or width_early >= upper:
                img_w_new = fixed_width
            break
    else:
        if num_col == 6 and width_early < 4500:
            img_w_new = 6500 # 5400
    img_h_new = img_w_new * img.shape[0] // img.shape[1]

    score = label_p_pred[0][int(num_col - 1)]
    #elif label_p_pred[0][int(num_col - 1)] < 0.8 and img_h_new >= 8000:
    keep_original = (score < 0.9 and img_w_new < width_early) or img_h_new >= 8000
    if keep_original:
        return np.copy(img), False
    return resize_image(img, img_h_new, img_w_new), True
|
||||||
|
|
||||||
|
def early_page_for_num_of_column_classification(self,img_bin):
    """Crop the input image to the bounding box of the largest predicted page contour.

    Uses ``img_bin`` when ``self.input_binary`` is set, otherwise the cached
    image. The image is blurred, passed to the page model, and the largest
    contour of the dilated prediction mask gives the crop box. Without any
    contour the whole image is used.

    Returns:
        ``(cropped_page, page_coord)`` as returned by ``crop_image_inside_box``.
    """
    self.logger.debug("enter early_page_for_num_of_column_classification")
    img = np.copy(img_bin).astype(np.uint8) if self.input_binary else self.imread()
    img = cv2.GaussianBlur(img, (5, 5), 0)
    img_page_prediction = self.do_prediction(False, img, self.model_page)

    imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(imgray, 0, 255, 0)
    thresh = cv2.dilate(thresh, KERNEL, iterations=3)
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) > 0:
        areas = np.array([cv2.contourArea(contour) for contour in contours])
        largest = contours[np.argmax(areas)]
        box = cv2.boundingRect(largest)
    else:
        # No contour found: fall back to the full image extent.
        box = [0, 0, img.shape[1], img.shape[0]]
    cropped_page, page_coord = crop_image_inside_box(box, img)

    self.logger.debug("exit early_page_for_num_of_column_classification")
    return cropped_page, page_coord
|
||||||
|
|
||||||
|
def calculate_width_height_by_columns_1_2(self, img, num_col, width_early, label_p_pred):
    """
    Rescale a one- or two-column image to its target width.

    The target width is 1000 px for ``num_col == 1`` and 1300 px otherwise;
    the height follows from the aspect ratio of ``img``.

    The image is left untouched (a copy is returned) and reported as
    "not classified" when either

    * the classifier score ``label_p_pred[0][num_col - 1]`` is below 0.9 and
      the target width would shrink the image (``img_w_new < width_early``), or
    * the resulting height would reach 8000 px or more.

    Returns:
        ``(img_new, num_column_is_classified)`` - the (possibly resized)
        image and whether the column count was trusted for resizing.
    """
    # Log this method's own name; the message used to be a copy of the one
    # emitted by calculate_width_height_by_columns.
    self.logger.debug("enter calculate_width_height_by_columns_1_2")

    img_w_new = 1000 if num_col == 1 else 1300
    img_h_new = img_w_new * img.shape[0] // img.shape[1]

    low_confidence = label_p_pred[0][int(num_col - 1)] < 0.9
    if (low_confidence and img_w_new < width_early) or img_h_new >= 8000:
        return np.copy(img), False

    return resize_image(img, img_h_new, img_w_new), True
|
||||||
|
|
||||||
|
def resize_and_enhance_image_with_column_classifier(self, light_version):
    """
    Determine the number of columns, rescale the image accordingly and
    (unless ``light_version`` is set) run the enhancement model on it.

    Side effects: sets ``self.h_org`` / ``self.w_org`` (size of the image as
    read), ``self.page_coord`` and ``self.image_page_org_size``.

    The column count is taken from ``self.num_col_upper`` /
    ``self.num_col_lower`` when only one of them is set or both are equal;
    it is predicted by ``self.model_classifier`` when neither is set; and it
    is predicted and then clamped into the range when both are set and differ.
    Whenever the count is forced rather than predicted, the scores are
    replaced by all-ones so that later confidence checks pass.

    Returns:
        ``(is_image_enhanced, img, image_res, num_col,
        num_column_is_classified, img_bin)``
    """
    self.logger.debug("enter resize_and_enhance_image_with_column_classifier")
    # NOTE(review): DPI detection is disabled (hard-coded 0), so the message
    # below always reports 0 and the ``dpi < DPI_THRESHOLD`` test further down
    # is presumably always true - confirm against DPI_THRESHOLD.
    dpi = 0#self.dpi
    self.logger.info("Detected %s DPI", dpi)

    img = self.imread()
    # Remember the original size in BOTH modes: run() rescales the result to
    # (h_org, w_org), which previously was only set for non-binary input.
    self.h_org, self.w_org = img.shape[:2]
    if self.input_binary:
        prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
        # label 0 becomes 255, every other label becomes 0
        prediction_bin = 255 * (prediction_bin[:,:,0]==0)
        prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
        img = np.copy(prediction_bin)
        img_bin = prediction_bin
    else:
        img_bin = None

    width_early = img.shape[1]
    _, page_coord = self.early_page_for_num_of_column_classification(img_bin)

    self.image_page_org_size = img[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3], :]
    self.page_coord = page_coord

    def classify_columns(width):
        """Run the column classifier; return (scores, num_col, width_early)."""
        if self.input_binary:
            img_in = np.copy(img) / 255.0
            img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
            img_in = img_in.reshape(1, 448, 448, 3)
        else:
            img_1ch = self.imread(grayscale=True)
            width = img_1ch.shape[1]
            img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
            img_1ch = img_1ch / 255.0
            img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST)
            # replicate the grey page crop into three identical channels
            img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
            for channel in range(3):
                img_in[0, :, :, channel] = img_1ch[:, :]
        scores = self.model_classifier.predict(img_in, verbose=0)
        return scores, np.argmax(scores[0]) + 1, width

    if self.num_col_upper and not self.num_col_lower:
        num_col = self.num_col_upper
        label_p_pred = [np.ones(6)]
    elif self.num_col_lower and not self.num_col_upper:
        num_col = self.num_col_lower
        label_p_pred = [np.ones(6)]
    elif not self.num_col_upper and not self.num_col_lower:
        label_p_pred, num_col, width_early = classify_columns(width_early)
    elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
        label_p_pred, num_col, width_early = classify_columns(width_early)
        # clamp the prediction into the user-supplied range
        if num_col > self.num_col_upper:
            num_col = self.num_col_upper
            label_p_pred = [np.ones(6)]
        if num_col < self.num_col_lower:
            num_col = self.num_col_lower
            label_p_pred = [np.ones(6)]
    else:
        # both limits given and equal: the column count is fixed
        num_col = self.num_col_upper
        label_p_pred = [np.ones(6)]

    self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5))

    if dpi < DPI_THRESHOLD:
        if light_version and num_col in (1,2):
            img_new, num_column_is_classified = self.calculate_width_height_by_columns_1_2(
                img, num_col, width_early, label_p_pred)
        else:
            img_new, num_column_is_classified = self.calculate_width_height_by_columns(
                img, num_col, width_early, label_p_pred)
        if light_version:
            image_res = np.copy(img_new)
        else:
            image_res = self.predict_enhancement(img_new)
        is_image_enhanced = True
    else:
        num_column_is_classified = True
        image_res = np.copy(img)
        is_image_enhanced = False

    self.logger.debug("exit resize_and_enhance_image_with_column_classifier")
    return is_image_enhanced, img, image_res, num_col, num_column_is_classified, img_bin
|
||||||
|
def do_prediction(
        self, patches, img, model,
        n_batch_inference=1, marginal_of_patch_percent=0.1,
        thresholding_for_some_classes_in_light_version=False,
        thresholding_for_artificial_class_in_light_version=False, thresholding_for_fl_light_version=False, threshold_art_class_textline=0.1):
    """
    Run a segmentation ``model`` on ``img`` and return a uint8 label image
    with three channels.

    Args:
        patches: if false, the whole image is resized to the model size,
            predicted once, and the labels are resized back to the input
            size. If true, the image is tiled into overlapping model-sized
            patches which are predicted in batches and stitched together.
        img: image with 3 channels and values in 0..255.
        model: object exposing ``layers[-1].output_shape`` and
            ``predict(batch, verbose=0)`` returning per-class scores on the
            last axis.
        n_batch_inference: number of patches per ``predict`` call.
        marginal_of_patch_percent: fraction of the model height that is
            discarded on every inner patch border when stitching.
        thresholding_for_some_classes_in_light_version: patch mode only;
            overrides the argmax using fixed score thresholds on channels
            4, 0 and 3 (in that order).
        thresholding_for_artificial_class_in_light_version: thresholds
            channel 2 with ``threshold_art_class_textline``, skeletonises the
            result and writes it back as label 2.
        thresholding_for_fl_light_version: whole-image mode only; pixels
            whose channel-2 score is at least 0.2 get label 2.
        threshold_art_class_textline: score threshold for the option above.

    Returns:
        ``np.uint8`` array of shape (height, width, 3). In patch mode the
        size is that of the image after it has been enlarged to at least the
        model size, so it can differ from the input for small images.
    """
    self.logger.debug("enter do_prediction")
    img_height_model = model.layers[-1].output_shape[1]
    img_width_model = model.layers[-1].output_shape[2]

    if not patches:
        img_h_page = img.shape[0]
        img_w_page = img.shape[1]
        img = img / float(255.0)
        img = resize_image(img, img_height_model, img_width_model)

        label_p_pred = model.predict(img[np.newaxis], verbose=0)
        seg = np.argmax(label_p_pred, axis=3)[0]

        if thresholding_for_artificial_class_in_light_version:
            # binarise the channel-2 scores in place, then thin to a skeleton
            seg_art = label_p_pred[0,:,:,2]
            seg_art[seg_art<threshold_art_class_textline] = 0
            seg_art[seg_art>0] =1

            skeleton_art = skeletonize(seg_art)
            skeleton_art = skeleton_art*1
            seg[skeleton_art==1]=2

        if thresholding_for_fl_light_version:
            seg_header = label_p_pred[0,:,:,2]
            seg_header[seg_header<0.2] = 0
            seg_header[seg_header>0] =1
            seg[seg_header==1]=2

        seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
        prediction_true = resize_image(seg_color, img_h_page, img_w_page).astype(np.uint8)
        return prediction_true

    # Patch mode: make sure at least one full patch fits in each direction.
    if img.shape[0] < img_height_model:
        img = resize_image(img, img_height_model, img.shape[1])
    if img.shape[1] < img_width_model:
        img = resize_image(img, img.shape[0], img_width_model)

    self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
    margin = int(marginal_of_patch_percent * img_height_model)
    width_mid = img_width_model - 2 * margin
    height_mid = img_height_model - 2 * margin
    img = img / 255.
    img_h = img.shape[0]
    img_w = img.shape[1]
    prediction_true = np.zeros((img_h, img_w, 3))

    # number of tiles per axis = ceil(size / stride)
    nxf = img_w / float(width_mid)
    nyf = img_h / float(height_mid)
    nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
    nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)

    # (i, j, y_d, y_u, x_d, x_u) of every patch waiting in the current batch
    pending = []
    img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3))
    for i in range(nxf):
        for j in range(nyf):
            index_x_d = i * width_mid
            index_x_u = index_x_d + img_width_model
            index_y_d = j * height_mid
            index_y_u = index_y_d + img_height_model
            # shift the last tile back so that it ends on the image border
            if index_x_u > img_w:
                index_x_u = img_w
                index_x_d = img_w - img_width_model
            if index_y_u > img_h:
                index_y_u = img_h
                index_y_d = img_h - img_height_model

            img_patch[len(pending),:,:,:] = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
            pending.append((i, j, index_y_d, index_y_u, index_x_d, index_x_u))

            is_last_patch = i == nxf - 1 and j == nyf - 1
            if len(pending) != n_batch_inference and not is_last_patch:
                continue

            # A partial final batch is predicted with its unused slots zeroed.
            self.logger.debug("predicting patches on %s", str(img_patch.shape))
            label_p_pred = model.predict(img_patch, verbose=0)
            seg = np.argmax(label_p_pred, axis=3)

            if thresholding_for_some_classes_in_light_version:
                seg_not_base = label_p_pred[:,:,:,4]
                seg_not_base[seg_not_base>0.03] =1
                seg_not_base[seg_not_base<1] =0

                seg_line = label_p_pred[:,:,:,3]
                seg_line[seg_line>0.1] =1
                seg_line[seg_line<1] =0

                seg_background = label_p_pred[:,:,:,0]
                seg_background[seg_background>0.25] =1
                seg_background[seg_background<1] =0

                seg[seg_not_base==1]=4
                seg[seg_background==1]=0
                seg[(seg_line==1) & (seg==0)]=3
            if thresholding_for_artificial_class_in_light_version:
                seg_art = label_p_pred[:,:,:,2]
                seg_art[seg_art<threshold_art_class_textline] = 0
                seg_art[seg_art>0] =1

            for in_batch, (i_batch, j_batch, y_d, y_u, x_d, x_u) in enumerate(pending):
                # Drop `margin` pixels on every side that borders another
                # tile; sides on the image border are kept in full. Whenever
                # margin > 0 there are at least two tiles per axis (the image
                # is at least model-sized while the stride is smaller), so a
                # tile is never first and last at once and this matches the
                # former nine-way case distinction.
                top = 0 if j_batch == 0 else margin
                bottom = 0 if j_batch == nyf - 1 else margin
                left = 0 if i_batch == 0 else margin
                right = 0 if i_batch == nxf - 1 else margin
                rows = slice(top, -bottom or None)
                cols = slice(left, -right or None)

                prediction_true[y_d + top:y_u - bottom,
                                x_d + left:x_u - right] = \
                    seg[in_batch][rows, cols, np.newaxis]
                if thresholding_for_artificial_class_in_light_version:
                    # channel 1 carries the binarised channel-2 scores
                    prediction_true[y_d + top:y_u - bottom,
                                    x_d + left:x_u - right, 1] = \
                        seg_art[in_batch][rows, cols]

            pending = []
            img_patch[:] = 0

    prediction_true = prediction_true.astype(np.uint8)

    if thresholding_for_artificial_class_in_light_version:
        kernel_min = np.ones((3, 3), np.uint8)
        # replace argmax label 2 by the dilated skeleton of the thresholded mask
        prediction_true[:,:,0][prediction_true[:,:,0]==2] = 0

        skeleton_art = skeletonize(prediction_true[:,:,1])
        skeleton_art = skeleton_art*1
        skeleton_art = skeleton_art.astype('uint8')
        skeleton_art = cv2.dilate(skeleton_art, kernel_min, iterations=1)

        prediction_true[:,:,0][skeleton_art==1]=2
    gc.collect()
    return prediction_true
|
||||||
|
|
||||||
|
def run_enhancement(self, light_version):
    """
    Resize and enhance the current image and log whether it was enhanced.

    Delegates to ``resize_and_enhance_image_with_column_classifier`` and
    drops the two values of its result that are not needed here (the
    original image and the binarised image).

    Returns:
        ``(img_res, is_image_enhanced, num_col_classifier,
        num_column_is_classified)``
    """
    self.logger.info("Resizing and enhancing image...")
    is_image_enhanced, _, img_res, num_col_classifier, num_column_is_classified, _ = \
        self.resize_and_enhance_image_with_column_classifier(light_version)

    self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ')
    return img_res, is_image_enhanced, num_col_classifier, num_column_is_classified
|
||||||
|
|
||||||
|
|
||||||
|
def run_single(self):
    """
    Enhance the currently loaded image with the full (non-light) pipeline.

    Returns only the enhanced image; the remaining values reported by
    ``run_enhancement`` are discarded.
    """
    img_res, _, _, _ = self.run_enhancement(light_version=False)
    return img_res
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, image_filename : Optional[str] = None, dir_in : Optional[str] = None, overwrite : bool = False):
    """
    Enhance a single image, or every entry of a directory, and write each
    result to ``self.output_filename`` at the size of the original image.

    Args:
        image_filename: path of one image; used when ``dir_in`` is not given.
        dir_in: directory whose entries are all processed; takes precedence
            over ``image_filename``.
        overwrite: if false, inputs whose output file already exists are
            skipped; if true, the output is overwritten.

    Raises:
        ValueError: if neither ``image_filename`` nor ``dir_in`` is given.
    """
    self.logger.debug("enter run")

    if dir_in:
        self.ls_imgs = os.listdir(dir_in)
    elif image_filename:
        self.ls_imgs = [image_filename]
    else:
        raise ValueError("run requires either a single image filename or a directory")

    for img_filename in self.ls_imgs:
        self.logger.info(img_filename)

        # presumably this selects the input and derives self.output_filename
        # - confirm in reset_file_name_dir
        self.reset_file_name_dir(os.path.join(dir_in or "", img_filename))

        if os.path.exists(self.output_filename):
            if overwrite:
                self.logger.warning("will overwrite existing output file '%s'", self.output_filename)
            else:
                self.logger.warning("will skip input for existing output file '%s'", self.output_filename)
                continue

        image_enhanced = self.run_single()
        # bring the enhanced image back to the size of the original input
        img_enhanced_org_scale = resize_image(image_enhanced, self.h_org, self.w_org)

        # cv2.imwrite signals failure through its return value instead of
        # raising; report it rather than dropping the result silently.
        if not cv2.imwrite(self.output_filename, img_enhanced_org_scale):
            self.logger.error("could not write output file '%s'", self.output_filename)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue