ocr engine first integration

pull/138/head^2
vahidrezanezhad 5 months ago
parent eac18c553d
commit 5144668834

--- a/qurator/eynollah/cli.py
+++ b/qurator/eynollah/cli.py
@@ -139,6 +139,12 @@ from qurator.eynollah.eynollah import Eynollah
     is_flag=True,
     help="if this parameter set to true, this tool would apply machine based reading order detection",
 )
+@click.option(
+    "--do_ocr",
+    "-ocr/-noocr",
+    is_flag=True,
+    help="if this parameter is set to true, this tool will try to perform OCR",
+)
 @click.option(
     "--log-level",
     "-l",
@@ -167,6 +173,7 @@ def main(
     headers_off,
     light_version,
     reading_order_machine_based,
+    do_ocr,
     ignore_page_extraction,
     log_level
 ):
@@ -205,6 +212,7 @@ def main(
         light_version=light_version,
         ignore_page_extraction=ignore_page_extraction,
         reading_order_machine_based=reading_order_machine_based,
+        do_ocr=do_ocr,
     )
     eynollah.run()
     #pcgts = eynollah.run()
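The new `do_ocr` flag is threaded from the CLI into the `Eynollah` constructor. A sketch of the equivalent programmatic use, showing only the keyword arguments that actually appear in this diff; everything else about the real constructor (input image, output directory, ...) is omitted here and the path is a placeholder:

```python
# Sketch only: enabling the new OCR pass from Python instead of the CLI.
# dir_models must contain checkpoint-166692_printed_trocr when do_ocr=True.
from qurator.eynollah.eynollah import Eynollah

eynollah = Eynollah(
    dir_models="models",  # placeholder path, not part of the commit
    do_ocr=True,          # new in this commit: TrOCR-based text recognition
)
eynollah.run()
```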

--- a/qurator/eynollah/eynollah.py
+++ b/qurator/eynollah/eynollah.py
@@ -17,6 +17,16 @@ import gc
 from ocrd_utils import getLogger
 import cv2
 import numpy as np
+from transformers import TrOCRProcessor
+from PIL import Image
+import torch
+from difflib import SequenceMatcher as sq
+from transformers import VisionEncoderDecoderModel
+from numba import cuda
+import copy
+from scipy.signal import find_peaks
+from scipy.ndimage import gaussian_filter1d
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 stderr = sys.stderr
 sys.stderr = open(os.devnull, "w")
@@ -166,6 +176,7 @@ class Eynollah:
         light_version=False,
         ignore_page_extraction=False,
         reading_order_machine_based=False,
+        do_ocr=False,
         override_dpi=None,
         logger=None,
         pcgts=None,
@@ -199,6 +210,7 @@ class Eynollah:
         self.headers_off = headers_off
         self.light_version = light_version
         self.ignore_page_extraction = ignore_page_extraction
+        self.ocr = do_ocr
         self.pcgts = pcgts
         if not dir_in:
             self.plotter = None if not enable_plotting else EynollahPlotter(
@@ -233,6 +245,9 @@ class Eynollah:
             self.model_textline_dir = dir_models + "/eynollah-textline_light_20210425"
         else:
             self.model_textline_dir = dir_models + "/eynollah-textline_20210425"
+        if self.ocr:
+            self.model_ocr_dir = dir_models + "/checkpoint-166692_printed_trocr"

         self.model_tables = dir_models + "/eynollah-tables_20210319"

         self.models = {}
@@ -251,6 +266,10 @@ class Eynollah:
             self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
             self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
             self.model_reading_order_machine = self.our_load_model(self.model_reading_order_machine_dir)
+            if self.ocr:
+                self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
+                self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+                self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")#("microsoft/trocr-base-printed")

             self.ls_imgs = os.listdir(self.dir_in)
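The objects loaded above follow the standard transformers vision-encoder-decoder pattern for TrOCR. A self-contained sketch of the same load-and-infer flow (the public model name is the one used in the commit; the textline image path is a placeholder):

```python
# Minimal TrOCR inference, mirroring the API calls the commit relies on.
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

image = Image.open("textline.png").convert("RGB")  # one cropped textline (placeholder)
pixel_values = processor(image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values.to(device))
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(text)
```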
@@ -3135,6 +3154,223 @@ class Eynollah:

         return order_of_texts, id_of_texts

+    def return_start_and_end_of_common_text_of_textline_ocr(self, textline_image, ind_tot):
+        width = np.shape(textline_image)[1]
+        height = np.shape(textline_image)[0]
+        common_window = int(0.2 * width)
+
+        width1 = int(width / 2. - common_window)
+        width2 = int(width / 2. + common_window)
+
+        img_sum = np.sum(textline_image[:, :, 0], axis=0)
+        sum_smoothed = gaussian_filter1d(img_sum, 3)
+
+        peaks_real, _ = find_peaks(sum_smoothed, height=0)
+        if len(peaks_real) > 70:
+            print(len(peaks_real), 'len(peaks_real)')
+            peaks_real = peaks_real[(peaks_real < width2) & (peaks_real > width1)]
+
+            arg_sort = np.argsort(sum_smoothed[peaks_real])
+            arg_sort4 = arg_sort[::-1][:4]
+            peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
+            argsort_sorted = np.argsort(peaks_sort_4)
+
+            first_4_sorted = peaks_sort_4[argsort_sorted]
+            y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
+            #print(first_4_sorted, 'first_4_sorted')
+
+            arg_sortnew = np.argsort(y_4_sorted)
+            peaks_final = np.sort(first_4_sorted[arg_sortnew][2:])
+
+            #plt.figure(ind_tot)
+            #plt.imshow(textline_image)
+            #plt.plot([peaks_final[0], peaks_final[0]], [0, height - 1])
+            #plt.plot([peaks_final[1], peaks_final[1]], [0, height - 1])
+            #plt.savefig('./' + str(ind_tot) + '.png')
+
+            return peaks_final[0], peaks_final[1]
+        else:
+            pass
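This helper is a vertical projection profile: per-column pixel sums of one channel are smoothed with a Gaussian, `find_peaks` proposes candidate columns, and only peaks inside a central window (20% of the width on either side of the middle) survive as split points. A stripped-down sketch of that core idea:

```python
# Projection-profile split-point search, condensed from the method above.
# Thresholds mirror the commit; the input image is up to the caller.
import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks

def candidate_split_points(textline_image, window_frac=0.2):
    width = textline_image.shape[1]
    window = int(window_frac * width)
    lo, hi = width // 2 - window, width // 2 + window

    profile = gaussian_filter1d(textline_image[:, :, 0].sum(axis=0), 3)
    peaks, _ = find_peaks(profile, height=0)
    # keep only peaks near the horizontal center, i.e. plausible cut columns
    return peaks[(peaks > lo) & (peaks < hi)]
```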
+    def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image, ind_tot):
+        width = np.shape(textline_image)[1]
+        height = np.shape(textline_image)[0]
+        common_window = int(0.06 * width)
+
+        width1 = int(width / 2. - common_window)
+        width2 = int(width / 2. + common_window)
+
+        img_sum = np.sum(textline_image[:, :, 0], axis=0)
+        sum_smoothed = gaussian_filter1d(img_sum, 3)
+
+        peaks_real, _ = find_peaks(sum_smoothed, height=0)
+        if len(peaks_real) > 70:
+            #print(len(peaks_real), 'len(peaks_real)')
+            peaks_real = peaks_real[(peaks_real < width2) & (peaks_real > width1)]
+
+            arg_max = np.argmax(sum_smoothed[peaks_real])
+            peaks_final = peaks_real[arg_max]
+
+            #plt.figure(ind_tot)
+            #plt.imshow(textline_image)
+            #plt.plot([peaks_final, peaks_final], [0, height - 1])
+            #plt.savefig('./' + str(ind_tot) + '.png')
+
+            return peaks_final
+        else:
+            return None
+
+    def return_start_and_end_of_common_text_of_textline_ocr_new_splitted(self, peaks_real, sum_smoothed, start_split, end_split):
+        peaks_real = peaks_real[(peaks_real < end_split) & (peaks_real > start_split)]
+
+        arg_sort = np.argsort(sum_smoothed[peaks_real])
+        arg_sort4 = arg_sort[::-1][:4]
+        peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
+        argsort_sorted = np.argsort(peaks_sort_4)
+
+        first_4_sorted = peaks_sort_4[argsort_sorted]
+        y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
+        #print(first_4_sorted, 'first_4_sorted')
+
+        arg_sortnew = np.argsort(y_4_sorted)
+        peaks_final = np.sort(first_4_sorted[arg_sortnew][3:])
+        return peaks_final[0]
+
+    def return_start_and_end_of_common_text_of_textline_ocr_new(self, textline_image, ind_tot):
+        width = np.shape(textline_image)[1]
+        height = np.shape(textline_image)[0]
+        common_window = int(0.15 * width)
+
+        width1 = int(width / 2. - common_window)
+        width2 = int(width / 2. + common_window)
+        mid = int(width / 2.)
+
+        img_sum = np.sum(textline_image[:, :, 0], axis=0)
+        sum_smoothed = gaussian_filter1d(img_sum, 3)
+
+        peaks_real, _ = find_peaks(sum_smoothed, height=0)
+        if len(peaks_real) > 70:
+            peak_start = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, width1, mid + 2)
+            peak_end = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, mid - 2, width2)
+
+            #plt.figure(ind_tot)
+            #plt.imshow(textline_image)
+            #plt.plot([peak_start, peak_start], [0, height - 1])
+            #plt.plot([peak_end, peak_end], [0, height - 1])
+            #plt.savefig('./' + str(ind_tot) + '.png')
+
+            return peak_start, peak_end
+        else:
+            pass
+
+    def return_ocr_of_textline_without_common_section(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio, ind_tot):
+        if h2w_ratio > 0.05:
+            pixel_values = processor(textline_image, return_tensors="pt").pixel_values
+            generated_ids = model_ocr.generate(pixel_values.to(device))
+            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        else:
+            #width = np.shape(textline_image)[1]
+            #height = np.shape(textline_image)[0]
+            #common_window = int(0.3 * width)
+            #width1 = int(width / 2. - common_window)
+            #width2 = int(width / 2. + common_window)
+            split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image, ind_tot)
+            if split_point:
+                image1 = textline_image[:, :split_point, :]  # image.crop((0, 0, width2, height))
+                image2 = textline_image[:, split_point:, :]  # image.crop((width1, 0, width, height))
+
+                #pixel_values1 = processor(image1, return_tensors="pt").pixel_values
+                #pixel_values2 = processor(image2, return_tensors="pt").pixel_values
+                pixel_values_merged = processor([image1, image2], return_tensors="pt").pixel_values
+                generated_ids_merged = model_ocr.generate(pixel_values_merged.to(device))
+                generated_text_merged = processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
+                #print(generated_text_merged, 'generated_text_merged')
+
+                #generated_ids1 = model_ocr.generate(pixel_values1.to(device))
+                #generated_ids2 = model_ocr.generate(pixel_values2.to(device))
+                #generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
+                #generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
+                #generated_text = generated_text1 + ' ' + generated_text2
+
+                generated_text = generated_text_merged[0] + ' ' + generated_text_merged[1]
+                #print(generated_text1, 'generated_text1')
+                #print(generated_text2, 'generated_text2')
+                #print('########################################')
+            else:
+                pixel_values = processor(textline_image, return_tensors="pt").pixel_values
+                generated_ids = model_ocr.generate(pixel_values.to(device))
+                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        #print(generated_text, 'generated_text')
+        #print('########################################')
+        return generated_text
+
+    def return_ocr_of_textline(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio, ind_tot):
+        if h2w_ratio > 0.05:
+            pixel_values = processor(textline_image, return_tensors="pt").pixel_values
+            generated_ids = model_ocr.generate(pixel_values.to(device))
+            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        else:
+            #width = np.shape(textline_image)[1]
+            #height = np.shape(textline_image)[0]
+            #common_window = int(0.3 * width)
+            #width1 = int(width / 2. - common_window)
+            #width2 = int(width / 2. + common_window)
+            try:
+                width1, width2 = self.return_start_and_end_of_common_text_of_textline_ocr_new(textline_image, ind_tot)
+
+                image1 = textline_image[:, :width2, :]  # image.crop((0, 0, width2, height))
+                image2 = textline_image[:, width1:, :]  # image.crop((width1, 0, width, height))
+
+                pixel_values1 = processor(image1, return_tensors="pt").pixel_values
+                pixel_values2 = processor(image2, return_tensors="pt").pixel_values
+
+                generated_ids1 = model_ocr.generate(pixel_values1.to(device))
+                generated_ids2 = model_ocr.generate(pixel_values2.to(device))
+
+                generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
+                generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
+                #print(generated_text1, 'generated_text1')
+                #print(generated_text2, 'generated_text2')
+                #print('########################################')
+
+                match = sq(None, generated_text1, generated_text2).find_longest_match(0, len(generated_text1), 0, len(generated_text2))
+                generated_text = generated_text1 + generated_text2[match.b + match.size:]
+            except:
+                pixel_values = processor(textline_image, return_tensors="pt").pixel_values
+                generated_ids = model_ocr.generate(pixel_values.to(device))
+                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        return generated_text
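The stitching step uses `difflib.SequenceMatcher`: the longest common substring of the two half-line transcriptions is taken as the duplicated overlap, and only what follows it in the second half is appended. A worked example of exactly that merge:

```python
# Overlap merge as used in return_ocr_of_textline; strings are made up.
from difflib import SequenceMatcher as sq

left = "the quick brown fox jumps"
right = "fox jumps over the lazy dog"

match = sq(None, left, right).find_longest_match(0, len(left), 0, len(right))
merged = left + right[match.b + match.size:]
print(merged)  # the quick brown fox jumps over the lazy dog
```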
+    def return_textline_contour_with_added_box_coordinate(self, textline_contour, box_ind):
+        textline_contour[:, 0] = textline_contour[:, 0] + box_ind[2]
+        textline_contour[:, 1] = textline_contour[:, 1] + box_ind[0]
+        return textline_contour
+
     def run(self):
         """
@@ -3398,6 +3634,7 @@
             if self.plotter:
                 self.plotter.write_images_into_directory(polygons_of_images, image_page)
             t_order = time.time()
+
             if self.full_layout:
                 if self.reading_order_machine_based:
@@ -3425,11 +3662,67 @@
                     contours_only_text_parent_d_ordered = list(np.array(contours_only_text_parent_d_ordered, dtype=object)[index_by_text_par_con])
                     order_text_new, id_of_texts_tot = self.do_order_of_regions(contours_only_text_parent_d_ordered, contours_only_text_parent_h, boxes_d, textline_mask_tot_d)

+                if self.ocr:
+                    device = cuda.get_current_device()
+                    device.reset()
+                    gc.collect()
+
+                    model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
+                    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+                    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
+                    torch.cuda.empty_cache()
+                    model_ocr.to(device)
+
+                    ind_tot = 0
+                    #cv2.imwrite('./img_out.png', image_page)
+                    ocr_all_textlines = []
+                    for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
+                        ocr_textline_in_textregion = []
+                        for indexing2, ind_poly in enumerate(ind_poly_first):
+                            if not (self.textline_light or self.curved_line):
+                                ind_poly = copy.deepcopy(ind_poly)
+                                box_ind = all_box_coord[indexing]
+                                #print(ind_poly, np.shape(ind_poly), 'ind_poly')
+                                #print(box_ind)
+                                ind_poly = self.return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
+                                #print(ind_poly_copy)
+                            ind_poly[ind_poly < 0] = 0
+                            x, y, w, h = cv2.boundingRect(ind_poly)
+                            #print(ind_poly_copy, np.shape(ind_poly_copy))
+                            #print(x, y, w, h, h / float(w), 'ratio')
+                            h2w_ratio = h / float(w)
+
+                            mask_poly = np.zeros(image_page.shape)
+                            img_poly_on_img = np.copy(image_page)
+
+                            mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))
+                            if self.textline_light:
+                                mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1)
+
+                            img_poly_on_img[:, :, 0][mask_poly[:, :, 0] == 0] = 255
+                            img_poly_on_img[:, :, 1][mask_poly[:, :, 0] == 0] = 255
+                            img_poly_on_img[:, :, 2][mask_poly[:, :, 0] == 0] = 255
+
+                            img_croped = img_poly_on_img[y:y+h, x:x+w, :]
+                            text_ocr = self.return_ocr_of_textline_without_common_section(img_croped, model_ocr, processor, device, w, h2w_ratio, ind_tot)
+                            ocr_textline_in_textregion.append(text_ocr)
+
+                            ##cv2.imwrite(str(ind_tot)+'.png', img_croped)
+                            ind_tot = ind_tot + 1
+                        ocr_all_textlines.append(ocr_textline_in_textregion)
+                else:
+                    ocr_all_textlines = None
+                #print(ocr_all_textlines)
+
                 self.logger.info("detection of reading order took %.1fs", time.time() - t_order)
-                pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables)
+                pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines)
                 self.logger.info("Job done in %.1fs", time.time() - t0)
                 ##return pcgts
                 self.writer.write_pagexml(pcgts)
                 #self.logger.info("Job done in %.1fs", time.time() - t0)
             if self.dir_in:
                 self.logger.info("All jobs done in %.1fs", time.time() - t0_tot)

--- a/qurator/eynollah/writer.py
+++ b/qurator/eynollah/writer.py
@@ -2,7 +2,7 @@
 # pylint: disable=import-error
 from pathlib import Path
 import os.path
+import xml.etree.ElementTree as ET
 from .utils.xml import create_page_xml, xml_reading_order
 from .utils.counter import EynollahIdCounter
@@ -12,6 +12,7 @@ from ocrd_models.ocrd_page import (
     CoordsType,
     PcGtsType,
     TextLineType,
+    TextEquivType,
     TextRegionType,
     ImageRegionType,
     TableRegionType,
@@ -93,11 +94,13 @@ class EynollahXmlWriter():
             points_co += ' '
         coords.set_points(points_co[:-1])

-    def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter):
+    def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter, ocr_all_textlines_textregion):
         self.logger.debug('enter serialize_lines_in_region')
         for j in range(len(all_found_textline_polygons[region_idx])):
             coords = CoordsType()
             textline = TextLineType(id=counter.next_line_id, Coords=coords)
+            if ocr_all_textlines_textregion:
+                textline.set_TextEquiv([TextEquivType(Unicode=ocr_all_textlines_textregion[j])])
             text_region.add_TextLine(textline)
             region_bboxes = all_box_coord[region_idx]
             points_co = ''
@@ -140,7 +143,7 @@ class EynollahXmlWriter():
         with open(out_fname, 'w') as f:
             f.write(to_xml(pcgts))

-    def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables):
+    def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines):
         self.logger.debug('enter build_pagexml_no_full_layout')

         # create the file structure
@@ -159,7 +162,11 @@ class EynollahXmlWriter():
                 Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord)),
             )
             page.add_TextRegion(textregion)
-            self.serialize_lines_in_region(textregion, all_found_textline_polygons, mm, page_coord, all_box_coord, slopes, counter)
+            if ocr_all_textlines:
+                ocr_textlines = ocr_all_textlines[mm]
+            else:
+                ocr_textlines = None
+            self.serialize_lines_in_region(textregion, all_found_textline_polygons, mm, page_coord, all_box_coord, slopes, counter, ocr_textlines)

         for mm in range(len(found_polygons_marginals)):
             marginal = TextRegionType(id=counter.next_region_id, type_='marginalia',
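For readers unfamiliar with `ocrd_models`, `TextEquiv` is the PAGE-XML element that carries recognized text on a line; the pattern `serialize_lines_in_region` now applies reduces to this sketch (ID and coordinates illustrative):

```python
# Attaching recognized text to a PAGE-XML textline, as the writer now does.
from ocrd_models.ocrd_page import CoordsType, TextEquivType, TextLineType

coords = CoordsType(points="0,0 100,0 100,20 0,20")
textline = TextLineType(id="line_0001", Coords=coords)
textline.set_TextEquiv([TextEquivType(Unicode="recognized text")])
```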
