ocr engine first integration

pull/138/head^2
vahidrezanezhad 6 months ago
parent eac18c553d
commit 5144668834

@ -139,6 +139,12 @@ from qurator.eynollah.eynollah import Eynollah
is_flag=True,
help="if this parameter set to true, this tool would apply machine based reading order detection",
)
@click.option(
"--do_ocr",
"-ocr/-noocr",
is_flag=True,
help="if this parameter set to true, this tool will try to do ocr",
)
@click.option(
"--log-level",
"-l",
@ -167,6 +173,7 @@ def main(
headers_off,
light_version,
reading_order_machine_based,
do_ocr,
ignore_page_extraction,
log_level
):
@ -205,6 +212,7 @@ def main(
light_version=light_version,
ignore_page_extraction=ignore_page_extraction,
reading_order_machine_based=reading_order_machine_based,
do_ocr=do_ocr,
)
eynollah.run()
#pcgts = eynollah.run()

@ -17,6 +17,16 @@ import gc
from ocrd_utils import getLogger
import cv2
import numpy as np
from transformers import TrOCRProcessor
from PIL import Image
import torch
from difflib import SequenceMatcher as sq
from transformers import VisionEncoderDecoderModel
from numba import cuda
import copy
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
stderr = sys.stderr
sys.stderr = open(os.devnull, "w")
@ -166,6 +176,7 @@ class Eynollah:
light_version=False,
ignore_page_extraction=False,
reading_order_machine_based=False,
do_ocr=False,
override_dpi=None,
logger=None,
pcgts=None,
@ -199,6 +210,7 @@ class Eynollah:
self.headers_off = headers_off
self.light_version = light_version
self.ignore_page_extraction = ignore_page_extraction
self.ocr = do_ocr
self.pcgts = pcgts
if not dir_in:
self.plotter = None if not enable_plotting else EynollahPlotter(
@ -233,6 +245,9 @@ class Eynollah:
self.model_textline_dir = dir_models + "/eynollah-textline_light_20210425"
else:
self.model_textline_dir = dir_models + "/eynollah-textline_20210425"
if self.ocr:
self.model_ocr_dir = dir_models + "/checkpoint-166692_printed_trocr"
self.model_tables = dir_models + "/eynollah-tables_20210319"
self.models = {}
@ -251,6 +266,10 @@ class Eynollah:
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.model_reading_order_machine = self.our_load_model(self.model_reading_order_machine_dir)
if self.ocr:
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten")
self.ls_imgs = os.listdir(self.dir_in)
@ -3135,6 +3154,223 @@ class Eynollah:
return order_of_texts, id_of_texts
def return_start_and_end_of_common_text_of_textline_ocr(self,textline_image, ind_tot):
width = np.shape(textline_image)[1]
height = np.shape(textline_image)[0]
common_window = int(0.2*width)
width1 = int ( width/2. - common_window )
width2 = int ( width/2. + common_window )
img_sum = np.sum(textline_image[:,:,0], axis=0)
sum_smoothed = gaussian_filter1d(img_sum, 3)
peaks_real, _ = find_peaks(sum_smoothed, height=0)
if len(peaks_real)>70:
print(len(peaks_real), 'len(peaks_real)')
peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
arg_sort = np.argsort(sum_smoothed[peaks_real])
arg_sort4 =arg_sort[::-1][:4]
peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
argsort_sorted = np.argsort(peaks_sort_4)
first_4_sorted = peaks_sort_4[argsort_sorted]
y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
#print(first_4_sorted,'first_4_sorted')
arg_sortnew = np.argsort(y_4_sorted)
peaks_final =np.sort( first_4_sorted[arg_sortnew][2:] )
#plt.figure(ind_tot)
#plt.imshow(textline_image)
#plt.plot([peaks_final[0], peaks_final[0]], [0, height-1])
#plt.plot([peaks_final[1], peaks_final[1]], [0, height-1])
#plt.savefig('./'+str(ind_tot)+'.png')
return peaks_final[0], peaks_final[1]
else:
pass
def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self,textline_image, ind_tot):
width = np.shape(textline_image)[1]
height = np.shape(textline_image)[0]
common_window = int(0.06*width)
width1 = int ( width/2. - common_window )
width2 = int ( width/2. + common_window )
img_sum = np.sum(textline_image[:,:,0], axis=0)
sum_smoothed = gaussian_filter1d(img_sum, 3)
peaks_real, _ = find_peaks(sum_smoothed, height=0)
if len(peaks_real)>70:
#print(len(peaks_real), 'len(peaks_real)')
peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)]
arg_max = np.argmax(sum_smoothed[peaks_real])
peaks_final = peaks_real[arg_max]
#plt.figure(ind_tot)
#plt.imshow(textline_image)
#plt.plot([peaks_final, peaks_final], [0, height-1])
##plt.plot([peaks_final[1], peaks_final[1]], [0, height-1])
#plt.savefig('./'+str(ind_tot)+'.png')
return peaks_final
else:
return None
def return_start_and_end_of_common_text_of_textline_ocr_new_splitted(self,peaks_real, sum_smoothed, start_split, end_split):
peaks_real = peaks_real[(peaks_real<end_split) & (peaks_real>start_split)]
arg_sort = np.argsort(sum_smoothed[peaks_real])
arg_sort4 =arg_sort[::-1][:4]
peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
argsort_sorted = np.argsort(peaks_sort_4)
first_4_sorted = peaks_sort_4[argsort_sorted]
y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
#print(first_4_sorted,'first_4_sorted')
arg_sortnew = np.argsort(y_4_sorted)
peaks_final =np.sort( first_4_sorted[arg_sortnew][3:] )
return peaks_final[0]
def return_start_and_end_of_common_text_of_textline_ocr_new(self,textline_image, ind_tot):
width = np.shape(textline_image)[1]
height = np.shape(textline_image)[0]
common_window = int(0.15*width)
width1 = int ( width/2. - common_window )
width2 = int ( width/2. + common_window )
mid = int(width/2.)
img_sum = np.sum(textline_image[:,:,0], axis=0)
sum_smoothed = gaussian_filter1d(img_sum, 3)
peaks_real, _ = find_peaks(sum_smoothed, height=0)
if len(peaks_real)>70:
peak_start = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, width1, mid+2)
peak_end = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, mid-2, width2)
#plt.figure(ind_tot)
#plt.imshow(textline_image)
#plt.plot([peak_start, peak_start], [0, height-1])
#plt.plot([peak_end, peak_end], [0, height-1])
#plt.savefig('./'+str(ind_tot)+'.png')
return peak_start, peak_end
else:
pass
def return_ocr_of_textline_without_common_section(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot):
if h2w_ratio > 0.05:
pixel_values = processor(textline_image, return_tensors="pt").pixel_values
generated_ids = model_ocr.generate(pixel_values.to(device))
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
else:
#width = np.shape(textline_image)[1]
#height = np.shape(textline_image)[0]
#common_window = int(0.3*width)
#width1 = int ( width/2. - common_window )
#width2 = int ( width/2. + common_window )
split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image, ind_tot)
if split_point:
image1 = textline_image[:, :split_point,:]# image.crop((0, 0, width2, height))
image2 = textline_image[:, split_point:,:]#image.crop((width1, 0, width, height))
#pixel_values1 = processor(image1, return_tensors="pt").pixel_values
#pixel_values2 = processor(image2, return_tensors="pt").pixel_values
pixel_values_merged = processor([image1,image2], return_tensors="pt").pixel_values
generated_ids_merged = model_ocr.generate(pixel_values_merged.to(device))
generated_text_merged = processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
#print(generated_text_merged,'generated_text_merged')
#generated_ids1 = model_ocr.generate(pixel_values1.to(device))
#generated_ids2 = model_ocr.generate(pixel_values2.to(device))
#generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
#generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
#generated_text = generated_text1 + ' ' + generated_text2
generated_text = generated_text_merged[0] + ' ' + generated_text_merged[1]
#print(generated_text1,'generated_text1')
#print(generated_text2, 'generated_text2')
#print('########################################')
else:
pixel_values = processor(textline_image, return_tensors="pt").pixel_values
generated_ids = model_ocr.generate(pixel_values.to(device))
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
#print(generated_text,'generated_text')
#print('########################################')
return generated_text
def return_ocr_of_textline(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio,ind_tot):
if h2w_ratio > 0.05:
pixel_values = processor(textline_image, return_tensors="pt").pixel_values
generated_ids = model_ocr.generate(pixel_values.to(device))
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
else:
#width = np.shape(textline_image)[1]
#height = np.shape(textline_image)[0]
#common_window = int(0.3*width)
#width1 = int ( width/2. - common_window )
#width2 = int ( width/2. + common_window )
try:
width1, width2 = self.return_start_and_end_of_common_text_of_textline_ocr_new(textline_image, ind_tot)
image1 = textline_image[:, :width2,:]# image.crop((0, 0, width2, height))
image2 = textline_image[:, width1:,:]#image.crop((width1, 0, width, height))
pixel_values1 = processor(image1, return_tensors="pt").pixel_values
pixel_values2 = processor(image2, return_tensors="pt").pixel_values
generated_ids1 = model_ocr.generate(pixel_values1.to(device))
generated_ids2 = model_ocr.generate(pixel_values2.to(device))
generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
#print(generated_text1,'generated_text1')
#print(generated_text2, 'generated_text2')
#print('########################################')
match = sq(None, generated_text1, generated_text2).find_longest_match(0, len(generated_text1), 0, len(generated_text2))
generated_text = generated_text1 + generated_text2[match.b+match.size:]
except:
pixel_values = processor(textline_image, return_tensors="pt").pixel_values
generated_ids = model_ocr.generate(pixel_values.to(device))
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return generated_text
def return_textline_contour_with_added_box_coordinate(self, textline_contour, box_ind):
textline_contour[:,0] = textline_contour[:,0] + box_ind[2]
textline_contour[:,1] = textline_contour[:,1] + box_ind[0]
return textline_contour
def run(self):
"""
@ -3398,6 +3634,7 @@ class Eynollah:
if self.plotter:
self.plotter.write_images_into_directory(polygons_of_images, image_page)
t_order = time.time()
if self.full_layout:
if self.reading_order_machine_based:
@ -3425,11 +3662,67 @@ class Eynollah:
contours_only_text_parent_d_ordered = list(np.array(contours_only_text_parent_d_ordered, dtype=object)[index_by_text_par_con])
order_text_new, id_of_texts_tot = self.do_order_of_regions(contours_only_text_parent_d_ordered, contours_only_text_parent_h, boxes_d, textline_mask_tot_d)
if self.ocr:
device = cuda.get_current_device()
device.reset()
gc.collect()
model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
torch.cuda.empty_cache()
model_ocr.to(device)
ind_tot = 0
#cv2.imwrite('./img_out.png', image_page)
ocr_all_textlines = []
for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
ocr_textline_in_textregion = []
for indexing2, ind_poly in enumerate(ind_poly_first):
if not (self.textline_light or self.curved_line):
ind_poly = copy.deepcopy(ind_poly)
box_ind = all_box_coord[indexing]
#print(ind_poly,np.shape(ind_poly), 'ind_poly')
#print(box_ind)
ind_poly = self.return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
#print(ind_poly_copy)
ind_poly[ind_poly<0] = 0
x, y, w, h = cv2.boundingRect(ind_poly)
#print(ind_poly_copy, np.shape(ind_poly_copy))
#print(x, y, w, h, h/float(w),'ratio')
h2w_ratio = h/float(w)
mask_poly = np.zeros(image_page.shape)
img_poly_on_img = np.copy(image_page)
mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))
if self.textline_light:
mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1)
img_poly_on_img[:,:,0][mask_poly[:,:,0] ==0] = 255
img_poly_on_img[:,:,1][mask_poly[:,:,0] ==0] = 255
img_poly_on_img[:,:,2][mask_poly[:,:,0] ==0] = 255
img_croped = img_poly_on_img[y:y+h, x:x+w, :]
text_ocr = self.return_ocr_of_textline_without_common_section(img_croped, model_ocr, processor, device, w, h2w_ratio, ind_tot)
ocr_textline_in_textregion.append(text_ocr)
##cv2.imwrite(str(ind_tot)+'.png', img_croped)
ind_tot = ind_tot +1
ocr_all_textlines.append(ocr_textline_in_textregion)
else:
ocr_all_textlines = None
#print(ocr_all_textlines)
self.logger.info("detection of reading order took %.1fs", time.time() - t_order)
pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables)
pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines)
self.logger.info("Job done in %.1fs", time.time() - t0)
##return pcgts
self.writer.write_pagexml(pcgts)
#self.logger.info("Job done in %.1fs", time.time() - t0)
if self.dir_in:
self.logger.info("All jobs done in %.1fs", time.time() - t0_tot)

@ -2,7 +2,7 @@
# pylint: disable=import-error
from pathlib import Path
import os.path
import xml.etree.ElementTree as ET
from .utils.xml import create_page_xml, xml_reading_order
from .utils.counter import EynollahIdCounter
@ -12,6 +12,7 @@ from ocrd_models.ocrd_page import (
CoordsType,
PcGtsType,
TextLineType,
TextEquivType,
TextRegionType,
ImageRegionType,
TableRegionType,
@ -93,11 +94,13 @@ class EynollahXmlWriter():
points_co += ' '
coords.set_points(points_co[:-1])
def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter):
def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter, ocr_all_textlines_textregion):
self.logger.debug('enter serialize_lines_in_region')
for j in range(len(all_found_textline_polygons[region_idx])):
coords = CoordsType()
textline = TextLineType(id=counter.next_line_id, Coords=coords)
if ocr_all_textlines_textregion:
textline.set_TextEquiv( [ TextEquivType(Unicode=ocr_all_textlines_textregion[j]) ] )
text_region.add_TextLine(textline)
region_bboxes = all_box_coord[region_idx]
points_co = ''
@ -140,7 +143,7 @@ class EynollahXmlWriter():
with open(out_fname, 'w') as f:
f.write(to_xml(pcgts))
def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables):
def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines):
self.logger.debug('enter build_pagexml_no_full_layout')
# create the file structure
@ -159,7 +162,11 @@ class EynollahXmlWriter():
Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord)),
)
page.add_TextRegion(textregion)
self.serialize_lines_in_region(textregion, all_found_textline_polygons, mm, page_coord, all_box_coord, slopes, counter)
if ocr_all_textlines:
ocr_textlines = ocr_all_textlines[mm]
else:
ocr_textlines = None
self.serialize_lines_in_region(textregion, all_found_textline_polygons, mm, page_coord, all_box_coord, slopes, counter, ocr_textlines)
for mm in range(len(found_polygons_marginals)):
marginal = TextRegionType(id=counter.next_region_id, type_='marginalia',

Loading…
Cancel
Save