@@ -17,6 +17,16 @@ import gc
from ocrd_utils import getLogger
import cv2
import numpy as np
from transformers import TrOCRProcessor
from PIL import Image
import torch
from difflib import SequenceMatcher as sq
from transformers import VisionEncoderDecoderModel
from numba import cuda
import copy
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
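
# Silence TensorFlow's C++ logging and swallow stderr during start-up; the
# original stream is kept in `stderr` so it can be restored later.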
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
stderr = sys.stderr
sys.stderr = open(os.devnull, "w")
@@ -166,6 +176,7 @@ class Eynollah:
        light_version=False,
        ignore_page_extraction=False,
        reading_order_machine_based=False,
        do_ocr=False,
        override_dpi=None,
        logger=None,
        pcgts=None,
@@ -199,6 +210,7 @@ class Eynollah:
        self.headers_off = headers_off
        self.light_version = light_version
        self.ignore_page_extraction = ignore_page_extraction
        self.ocr = do_ocr
        self.pcgts = pcgts
        if not dir_in:
            self.plotter = None if not enable_plotting else EynollahPlotter(
@@ -233,6 +245,9 @@ class Eynollah:
            self.model_textline_dir = dir_models + "/eynollah-textline_light_20210425"
        else:
            self.model_textline_dir = dir_models + "/eynollah-textline_20210425"
        if self.ocr:
            self.model_ocr_dir = dir_models + "/checkpoint-166692_printed_trocr"

        self.model_tables = dir_models + "/eynollah-tables_20210319"

        self.models = {}
@@ -251,6 +266,10 @@ class Eynollah:
            self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
            self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
            self.model_reading_order_machine = self.our_load_model(self.model_reading_order_machine_dir)
            if self.ocr:
                self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
                self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
                self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")  # alternative: "microsoft/trocr-base-printed"

            self.ls_imgs = os.listdir(self.dir_in)
@@ -3135,6 +3154,223 @@ class Eynollah:
        return order_of_texts, id_of_texts
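
    # The helpers below support OCR of very wide text lines: a line image is
    # split around its centre at a whitespace column, the halves are recognized
    # separately, and the results are stitched back together.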
    def return_start_and_end_of_common_text_of_textline_ocr(self, textline_image, ind_tot):
        width = np.shape(textline_image)[1]
        height = np.shape(textline_image)[0]
        common_window = int(0.2 * width)

        width1 = int(width / 2. - common_window)
        width2 = int(width / 2. + common_window)

        img_sum = np.sum(textline_image[:, :, 0], axis=0)
        sum_smoothed = gaussian_filter1d(img_sum, 3)

        peaks_real, _ = find_peaks(sum_smoothed, height=0)

        if len(peaks_real) > 70:
            #print(len(peaks_real), 'len(peaks_real)')
            peaks_real = peaks_real[(peaks_real < width2) & (peaks_real > width1)]

            arg_sort = np.argsort(sum_smoothed[peaks_real])
            arg_sort4 = arg_sort[::-1][:4]
            peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
            argsort_sorted = np.argsort(peaks_sort_4)

            first_4_sorted = peaks_sort_4[argsort_sorted]
            y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
            #print(first_4_sorted,'first_4_sorted')

            arg_sortnew = np.argsort(y_4_sorted)
            peaks_final = np.sort(first_4_sorted[arg_sortnew][2:])

            #plt.figure(ind_tot)
            #plt.imshow(textline_image)
            #plt.plot([peaks_final[0], peaks_final[0]], [0, height-1])
            #plt.plot([peaks_final[1], peaks_final[1]], [0, height-1])
            #plt.savefig('./'+str(ind_tot)+'.png')

            return peaks_final[0], peaks_final[1]
        else:
            return None
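
    # Single-split variant: returns the strongest whitespace peak inside a
    # narrow window around the line's centre, or None if the line has too few
    # projection peaks to split reliably.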
    def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image, ind_tot):
        width = np.shape(textline_image)[1]
        height = np.shape(textline_image)[0]
        common_window = int(0.06 * width)

        width1 = int(width / 2. - common_window)
        width2 = int(width / 2. + common_window)

        img_sum = np.sum(textline_image[:, :, 0], axis=0)
        sum_smoothed = gaussian_filter1d(img_sum, 3)

        peaks_real, _ = find_peaks(sum_smoothed, height=0)

        if len(peaks_real) > 70:
            #print(len(peaks_real), 'len(peaks_real)')
            peaks_real = peaks_real[(peaks_real < width2) & (peaks_real > width1)]

            arg_max = np.argmax(sum_smoothed[peaks_real])
            peaks_final = peaks_real[arg_max]

            #plt.figure(ind_tot)
            #plt.imshow(textline_image)
            #plt.plot([peaks_final, peaks_final], [0, height-1])
            #plt.savefig('./'+str(ind_tot)+'.png')

            return peaks_final
        else:
            return None
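
    # Window-restricted peak picking for the overlap-based split: among the
    # four strongest peaks inside (start_split, end_split), return the column
    # of the single strongest one.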
    def return_start_and_end_of_common_text_of_textline_ocr_new_splitted(self, peaks_real, sum_smoothed, start_split, end_split):
        peaks_real = peaks_real[(peaks_real < end_split) & (peaks_real > start_split)]

        arg_sort = np.argsort(sum_smoothed[peaks_real])
        arg_sort4 = arg_sort[::-1][:4]
        peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
        argsort_sorted = np.argsort(peaks_sort_4)

        first_4_sorted = peaks_sort_4[argsort_sorted]
        y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
        #print(first_4_sorted,'first_4_sorted')

        arg_sortnew = np.argsort(y_4_sorted)
        peaks_final = np.sort(first_4_sorted[arg_sortnew][3:])
        return peaks_final[0]
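
    # Overlap-based split: pick a split start just left of the centre and a
    # split end just right of it, so the two crops share a common stretch of
    # text that can later be matched and deduplicated.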
    def return_start_and_end_of_common_text_of_textline_ocr_new(self, textline_image, ind_tot):
        width = np.shape(textline_image)[1]
        height = np.shape(textline_image)[0]
        common_window = int(0.15 * width)

        width1 = int(width / 2. - common_window)
        width2 = int(width / 2. + common_window)
        mid = int(width / 2.)

        img_sum = np.sum(textline_image[:, :, 0], axis=0)
        sum_smoothed = gaussian_filter1d(img_sum, 3)

        peaks_real, _ = find_peaks(sum_smoothed, height=0)

        if len(peaks_real) > 70:
            peak_start = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, width1, mid + 2)
            peak_end = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, mid - 2, width2)

            #plt.figure(ind_tot)
            #plt.imshow(textline_image)
            #plt.plot([peak_start, peak_start], [0, height-1])
            #plt.plot([peak_end, peak_end], [0, height-1])
            #plt.savefig('./'+str(ind_tot)+'.png')

            return peak_start, peak_end
        else:
            return None
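
    # OCR for one text line image. Lines with h2w_ratio > 0.05 are narrow
    # enough for TrOCR directly; wider lines are cut at a whitespace column
    # and the two halves are decoded in a single batch.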
    def return_ocr_of_textline_without_common_section(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio, ind_tot):
        if h2w_ratio > 0.05:
            pixel_values = processor(textline_image, return_tensors="pt").pixel_values
            generated_ids = model_ocr.generate(pixel_values.to(device))
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        else:
            split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image, ind_tot)
            if split_point is not None:
                image1 = textline_image[:, :split_point, :]
                image2 = textline_image[:, split_point:, :]

                pixel_values_merged = processor([image1, image2], return_tensors="pt").pixel_values
                generated_ids_merged = model_ocr.generate(pixel_values_merged.to(device))
                generated_text_merged = processor.batch_decode(generated_ids_merged, skip_special_tokens=True)

                generated_text = generated_text_merged[0] + ' ' + generated_text_merged[1]
            else:
                pixel_values = processor(textline_image, return_tensors="pt").pixel_values
                generated_ids = model_ocr.generate(pixel_values.to(device))
                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return generated_text
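
    # Overlap variant: both crops include the centre region, and difflib's
    # longest common match drops the duplicated overlap when joining, e.g.
    # "the quick bro" + "ick brown fox" -> "the quick brown fox".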
    def return_ocr_of_textline(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio, ind_tot):
        if h2w_ratio > 0.05:
            pixel_values = processor(textline_image, return_tensors="pt").pixel_values
            generated_ids = model_ocr.generate(pixel_values.to(device))
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        else:
            try:
                width1, width2 = self.return_start_and_end_of_common_text_of_textline_ocr_new(textline_image, ind_tot)

                image1 = textline_image[:, :width2, :]
                image2 = textline_image[:, width1:, :]

                pixel_values1 = processor(image1, return_tensors="pt").pixel_values
                pixel_values2 = processor(image2, return_tensors="pt").pixel_values

                generated_ids1 = model_ocr.generate(pixel_values1.to(device))
                generated_ids2 = model_ocr.generate(pixel_values2.to(device))

                generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
                generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]

                match = sq(None, generated_text1, generated_text2).find_longest_match(0, len(generated_text1), 0, len(generated_text2))
                generated_text = generated_text1 + generated_text2[match.b + match.size:]
            except Exception:
                pixel_values = processor(textline_image, return_tensors="pt").pixel_values
                generated_ids = model_ocr.generate(pixel_values.to(device))
                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return generated_text
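
    # Shift a text-line contour from box-local coordinates into page
    # coordinates by adding the region's bounding-box offset.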
    def return_textline_contour_with_added_box_coordinate(self, textline_contour, box_ind):
        textline_contour[:, 0] = textline_contour[:, 0] + box_ind[2]
        textline_contour[:, 1] = textline_contour[:, 1] + box_ind[0]
        return textline_contour

    def run(self):
        """
@@ -3398,6 +3634,7 @@ class Eynollah:
        if self.plotter:
            self.plotter.write_images_into_directory(polygons_of_images, image_page)
        t_order = time.time()

        if self.full_layout:

            if self.reading_order_machine_based:
@@ -3425,11 +3662,67 @@ class Eynollah:
            contours_only_text_parent_d_ordered = list(np.array(contours_only_text_parent_d_ordered, dtype=object)[index_by_text_par_con])
            order_text_new, id_of_texts_tot = self.do_order_of_regions(contours_only_text_parent_d_ordered, contours_only_text_parent_h, boxes_d, textline_mask_tot_d)

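        # Optional OCR pass: reset the GPU, load a TrOCR checkpoint, and
        # recognize every detected text line so the text can be attached to
        # the PAGE-XML output.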
        if self.ocr:
            device = cuda.get_current_device()
            device.reset()
            gc.collect()
            model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
            torch.cuda.empty_cache()
            model_ocr.to(device)

            ind_tot = 0
            #cv2.imwrite('./img_out.png', image_page)
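
            # For each text line: shift its contour into page coordinates if
            # necessary, whiten everything outside the polygon, crop its
            # bounding box, and run TrOCR on the crop.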
            ocr_all_textlines = []
            for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
                ocr_textline_in_textregion = []
                for indexing2, ind_poly in enumerate(ind_poly_first):
                    if not (self.textline_light or self.curved_line):
                        ind_poly = copy.deepcopy(ind_poly)
                        box_ind = all_box_coord[indexing]
                        ind_poly = self.return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
                        ind_poly[ind_poly < 0] = 0
                    x, y, w, h = cv2.boundingRect(ind_poly)
                    h2w_ratio = h / float(w)
                    mask_poly = np.zeros(image_page.shape)
                    img_poly_on_img = np.copy(image_page)

                    mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))

                    if self.textline_light:
                        mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1)

                    img_poly_on_img[:, :, 0][mask_poly[:, :, 0] == 0] = 255
                    img_poly_on_img[:, :, 1][mask_poly[:, :, 0] == 0] = 255
                    img_poly_on_img[:, :, 2][mask_poly[:, :, 0] == 0] = 255

                    img_cropped = img_poly_on_img[y:y + h, x:x + w, :]
                    text_ocr = self.return_ocr_of_textline_without_common_section(img_cropped, model_ocr, processor, device, w, h2w_ratio, ind_tot)
                    ocr_textline_in_textregion.append(text_ocr)
                    ##cv2.imwrite(str(ind_tot)+'.png', img_cropped)
                    ind_tot = ind_tot + 1
                ocr_all_textlines.append(ocr_textline_in_textregion)
        else:
            ocr_all_textlines = None

self.logger.info("detection of reading order took %.1fs", time.time() - t_order)
|
|
|
|
|
pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables)
|
|
|
|
|
pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines)
|
|
|
|
|
self.logger.info("Job done in %.1fs", time.time() - t0)
|
|
|
|
|
##return pcgts
|
|
|
|
|
self.writer.write_pagexml(pcgts)
|
|
|
|
|
#self.logger.info("Job done in %.1fs", time.time() - t0)
|
|
|
|
|
|
|
|
|
|
        if self.dir_in:
            self.logger.info("All jobs done in %.1fs", time.time() - t0_tot)