🔥 refactor eynollah ocr

.
This commit is contained in:
kba 2025-11-28 14:54:43 +01:00
parent 30f9c695dc
commit b161e33854
5 changed files with 769 additions and 865 deletions

View file

@ -88,7 +88,6 @@ def ocr_cli(
tr_ocr, tr_ocr,
do_not_mask_with_textline_contour, do_not_mask_with_textline_contour,
batch_size, batch_size,
dataset_abbrevation,
min_conf_value_of_textline_text, min_conf_value_of_textline_text,
): ):
""" """
@ -101,7 +100,6 @@ def ocr_cli(
tr_ocr=tr_ocr, tr_ocr=tr_ocr,
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour, do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
batch_size=batch_size, batch_size=batch_size,
pref_of_dataset=dataset_abbrevation,
min_conf_value_of_textline_text=min_conf_value_of_textline_text) min_conf_value_of_textline_text=min_conf_value_of_textline_text)
eynollah_ocr.run(overwrite=overwrite, eynollah_ocr.run(overwrite=overwrite,
dir_in=dir_in, dir_in=dir_in,

View file

@ -1,24 +1,22 @@
# FIXME: fix all of those... # FIXME: fix all of those...
# pyright: reportPossiblyUnboundVariable=false
# pyright: reportOptionalMemberAccess=false
# pyright: reportArgumentType=false
# pyright: reportCallIssue=false
# pyright: reportOptionalSubscript=false # pyright: reportOptionalSubscript=false
from logging import Logger, getLogger from logging import Logger, getLogger
from typing import Optional from typing import List, Optional
from pathlib import Path from pathlib import Path
import os import os
import gc import gc
import sys
import math import math
import time from dataclasses import dataclass
import cv2 import cv2
import xml.etree.ElementTree as ET from cv2.typing import MatLike
from PIL import Image, ImageDraw, ImageFont from xml.etree import ElementTree as ET
from PIL import Image, ImageDraw
import numpy as np import numpy as np
from eynollah.model_zoo import EynollahModelZoo from eynollah.model_zoo import EynollahModelZoo
from eynollah.utils.font import get_font
from eynollah.utils.xml import etree_namespace_for_element_tag
try: try:
import torch import torch
except ImportError: except ImportError:
@ -38,11 +36,13 @@ from .utils.utils_ocr import (
rotate_image_with_padding, rotate_image_with_padding,
) )
# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files # TODO: refine typing
if sys.version_info < (3, 10): @dataclass
import importlib_resources class EynollahOcrResult:
else: extracted_texts_merged: List
import importlib.resources as importlib_resources extracted_conf_value_merged: Optional[List]
cropped_lines_region_indexer: List
total_bb_coordinates:List
class Eynollah_ocr: class Eynollah_ocr:
def __init__( def __init__(
@ -76,6 +76,7 @@ class Eynollah_ocr:
@property @property
def device(self): def device(self):
assert torch
if torch.cuda.is_available(): if torch.cuda.is_available():
self.logger.info("Using GPU acceleration") self.logger.info("Using GPU acceleration")
return torch.device("cuda:0") return torch.device("cuda:0")
@ -83,59 +84,17 @@ class Eynollah_ocr:
self.logger.info("Using CPU processing") self.logger.info("Using CPU processing")
return torch.device("cpu") return torch.device("cpu")
def run(self, overwrite: bool = False, def run_trocr(
dir_in: Optional[str] = None, self,
# Prediction with RGB and binarized images for selected pages, should not be the default *,
dir_in_bin: Optional[str] = None, img: MatLike,
image_filename: Optional[str] = None, page_tree: ET.ElementTree,
dir_xmls: Optional[str] = None, page_ns,
dir_out_image_text: Optional[str] = None, tr_ocr_input_height_and_width,
dir_out: Optional[str] = None, ) -> EynollahOcrResult:
):
if dir_in:
ls_imgs = [os.path.join(dir_in, image_filename)
for image_filename in filter(is_image_filename,
os.listdir(dir_in))]
else:
assert image_filename
ls_imgs = [image_filename]
if self.tr_ocr:
tr_ocr_input_height_and_width = 384
for dir_img in ls_imgs:
file_name = Path(dir_img).stem
assert dir_xmls # FIXME: check the logic
dir_xml = os.path.join(dir_xmls, file_name+'.xml')
assert dir_out # FIXME: check the logic
out_file_ocr = os.path.join(dir_out, file_name+'.xml')
if os.path.exists(out_file_ocr):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", out_file_ocr)
else:
self.logger.warning("will skip input for existing output file '%s'", out_file_ocr)
continue
img = cv2.imread(dir_img)
if dir_out_image_text:
out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png')
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text)
total_bb_coordinates = [] total_bb_coordinates = []
##file_name = Path(dir_xmls).stem
tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8"))
root1=tree1.getroot()
alltags=[elem.tag for elem in root1.iter()]
link=alltags[0].split('}')[0]+'}'
name_space = alltags[0].split('}')[0]
name_space = name_space.split('{')[1]
region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')])
cropped_lines = [] cropped_lines = []
cropped_lines_region_indexer = [] cropped_lines_region_indexer = []
@ -146,7 +105,7 @@ class Eynollah_ocr:
indexer_text_region = 0 indexer_text_region = 0
indexer_b_s = 0 indexer_b_s = 0
for nn in root1.iter(region_tags): for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
for child_textregion in nn: for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"): if child_textregion.tag.endswith("TextLine"):
@ -159,7 +118,6 @@ class Eynollah_ocr:
for x in p_h] ) for x in p_h] )
x,y,w,h = cv2.boundingRect(textline_coords) x,y,w,h = cv2.boundingRect(textline_coords)
if dir_out_image_text:
total_bb_coordinates.append([x,y,w,h]) total_bb_coordinates.append([x,y,w,h])
h2w_ratio = h/float(w) h2w_ratio = h/float(w)
@ -301,185 +259,37 @@ class Eynollah_ocr:
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
#print(extracted_texts_merged, len(extracted_texts_merged)) #print(extracted_texts_merged, len(extracted_texts_merged))
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) return EynollahOcrResult(
extracted_texts_merged=extracted_texts_merged,
extracted_conf_value_merged=None,
cropped_lines_region_indexer=cropped_lines_region_indexer,
total_bb_coordinates=total_bb_coordinates,
)
if dir_out_image_text: def run_cnn(
self,
*,
img: MatLike,
img_bin: Optional[MatLike],
page_tree: ET.ElementTree,
page_ns,
image_width,
image_height,
) -> EynollahOcrResult:
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = importlib_resources.files(__package__) / "Charis-Regular.ttf"
with importlib_resources.as_file(font) as font:
font = ImageFont.truetype(font=font, size=40)
for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0]
y_bb = bb_ind[1]
w_bb = bb_ind[2]
h_bb = bb_ind[3]
font = fit_text_single_line(draw, extracted_texts_merged[indexer_text],
font.path, w_bb, int(h_bb*0.4) )
##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2)
text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally
text_y = y_bb + (h_bb - text_height) // 2 # Center vertically
# Draw the text
draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font)
image_text.save(out_image_with_text)
#print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer')
#######text_by_textregion = []
#######for ind in unique_cropped_lines_region_indexer:
#######ind = np.array(cropped_lines_region_indexer)==ind
#######extracted_texts_merged_un = np.array(extracted_texts_merged)[ind]
#######text_by_textregion.append(" ".join(extracted_texts_merged_un))
text_by_textregion = []
for ind in unique_cropped_lines_region_indexer:
ind = np.array(cropped_lines_region_indexer) == ind
extracted_texts_merged_un = np.array(extracted_texts_merged)[ind]
if len(extracted_texts_merged_un)>1:
text_by_textregion_ind = ""
next_glue = ""
for indt in range(len(extracted_texts_merged_un)):
if (extracted_texts_merged_un[indt].endswith('') or
extracted_texts_merged_un[indt].endswith('-') or
extracted_texts_merged_un[indt].endswith('¬')):
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1]
next_glue = ""
else:
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt]
next_glue = " "
text_by_textregion.append(text_by_textregion_ind)
else:
text_by_textregion.append(" ".join(extracted_texts_merged_un))
indexer = 0
indexer_textregion = 0
for nn in root1.iter(region_tags):
#id_textregion = nn.attrib['id']
#id_textregions.append(id_textregion)
#textregions_by_existing_ids.append(text_by_textregion[indexer_textregion])
is_textregion_text = False
for childtest in nn:
if childtest.tag.endswith("TextEquiv"):
is_textregion_text = True
if not is_textregion_text:
text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')
has_textline = False
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
is_textline_text = False
for childtest2 in child_textregion:
if childtest2.tag.endswith("TextEquiv"):
is_textline_text = True
if not is_textline_text:
text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
unicode_textline = ET.SubElement(text_subelement, 'Unicode')
unicode_textline.text = extracted_texts_merged[indexer]
else:
for childtest3 in child_textregion:
if childtest3.tag.endswith("TextEquiv"):
for child_uc in childtest3:
if child_uc.tag.endswith("Unicode"):
##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
child_uc.text = extracted_texts_merged[indexer]
indexer = indexer + 1
has_textline = True
if has_textline:
if is_textregion_text:
for child4 in nn:
if child4.tag.endswith("TextEquiv"):
for childtr_uc in child4:
if childtr_uc.tag.endswith("Unicode"):
childtr_uc.text = text_by_textregion[indexer_textregion]
else:
unicode_textregion.text = text_by_textregion[indexer_textregion]
indexer_textregion = indexer_textregion + 1
###sample_order = [(id_to_order[tid], text)
### for tid, text in zip(id_textregions, textregions_by_existing_ids)
### if tid in id_to_order]
##ordered_texts_sample = [text for _, text in sorted(sample_order)]
##tot_page_text = ' '.join(ordered_texts_sample)
##for page_element in root1.iter(link+'Page'):
##text_page = ET.SubElement(page_element, 'TextEquiv')
##unicode_textpage = ET.SubElement(text_page, 'Unicode')
##unicode_textpage.text = tot_page_text
ET.register_namespace("",name_space)
tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None)
else:
###max_len = 280#512#280#512
###padding_token = 1500#299#1500#299
image_width = 512#max_len * 4
image_height = 32
img_size=(image_width, image_height)
for dir_img in ls_imgs:
file_name = Path(dir_img).stem
dir_xml = os.path.join(dir_xmls, file_name+'.xml')
out_file_ocr = os.path.join(dir_out, file_name+'.xml')
if os.path.exists(out_file_ocr):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", out_file_ocr)
else:
self.logger.warning("will skip input for existing output file '%s'", out_file_ocr)
continue
img = cv2.imread(dir_img)
if dir_in_bin is not None:
cropped_lines_bin = []
img_bin = cv2.imread(os.path.join(dir_in_bin, file_name+'.png'))
if dir_out_image_text:
out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png')
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text)
total_bb_coordinates = [] total_bb_coordinates = []
tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8"))
root1=tree1.getroot()
alltags=[elem.tag for elem in root1.iter()]
link=alltags[0].split('}')[0]+'}'
name_space = alltags[0].split('}')[0]
name_space = name_space.split('{')[1]
region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')])
cropped_lines = [] cropped_lines = []
img_crop_bin = None
imgs_bin = None
imgs_bin_ver_flipped = None
cropped_lines_bin = []
cropped_lines_ver_index = [] cropped_lines_ver_index = []
cropped_lines_region_indexer = [] cropped_lines_region_indexer = []
cropped_lines_meging_indexing = [] cropped_lines_meging_indexing = []
tinl = time.time()
indexer_text_region = 0 indexer_text_region = 0
indexer_textlines = 0 for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
for nn in root1.iter(region_tags):
try: try:
type_textregion = nn.attrib['type'] type_textregion = nn.attrib['type']
except: except:
@ -502,13 +312,12 @@ class Eynollah_ocr:
if type_textregion=='drop-capital': if type_textregion=='drop-capital':
angle_degrees = 0 angle_degrees = 0
if dir_out_image_text:
total_bb_coordinates.append([x,y,w,h]) total_bb_coordinates.append([x,y,w,h])
w_scaled = w * image_height/float(h) w_scaled = w * image_height/float(h)
img_poly_on_img = np.copy(img) img_poly_on_img = np.copy(img)
if dir_in_bin is not None: if img_bin:
img_poly_on_img_bin = np.copy(img_bin) img_poly_on_img_bin = np.copy(img_bin)
img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :] img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :]
@ -528,7 +337,7 @@ class Eynollah_ocr:
better_des_slope = get_orientation_moments(textline_coords) better_des_slope = get_orientation_moments(textline_coords)
img_crop = rotate_image_with_padding(img_crop, better_des_slope) img_crop = rotate_image_with_padding(img_crop, better_des_slope)
if dir_in_bin is not None: if img_bin:
img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope) img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope)
mask_poly = rotate_image_with_padding(mask_poly, better_des_slope) mask_poly = rotate_image_with_padding(mask_poly, better_des_slope)
@ -542,13 +351,13 @@ class Eynollah_ocr:
if not self.do_not_mask_with_textline_contour: if not self.do_not_mask_with_textline_contour:
img_crop[mask_poly==0] = 255 img_crop[mask_poly==0] = 255
if dir_in_bin is not None: if img_bin:
img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :]
if not self.do_not_mask_with_textline_contour: if not self.do_not_mask_with_textline_contour:
img_crop_bin[mask_poly==0] = 255 img_crop_bin[mask_poly==0] = 255
if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90:
if dir_in_bin is not None: if img_bin:
img_crop, img_crop_bin = \ img_crop, img_crop_bin = \
break_curved_line_into_small_pieces_and_then_merge( break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly, img_crop_bin) img_crop, mask_poly, img_crop_bin)
@ -561,14 +370,14 @@ class Eynollah_ocr:
better_des_slope = 0 better_des_slope = 0
if not self.do_not_mask_with_textline_contour: if not self.do_not_mask_with_textline_contour:
img_crop[mask_poly==0] = 255 img_crop[mask_poly==0] = 255
if dir_in_bin is not None: if img_bin:
if not self.do_not_mask_with_textline_contour: if not self.do_not_mask_with_textline_contour:
img_crop_bin[mask_poly==0] = 255 img_crop_bin[mask_poly==0] = 255
if type_textregion=='drop-capital': if type_textregion=='drop-capital':
pass pass
else: else:
if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90: if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90:
if dir_in_bin is not None: if img_bin:
img_crop, img_crop_bin = \ img_crop, img_crop_bin = \
break_curved_line_into_small_pieces_and_then_merge( break_curved_line_into_small_pieces_and_then_merge(
img_crop, mask_poly, img_crop_bin) img_crop, mask_poly, img_crop_bin)
@ -587,13 +396,13 @@ class Eynollah_ocr:
cropped_lines_ver_index.append(0) cropped_lines_ver_index.append(0)
cropped_lines_meging_indexing.append(0) cropped_lines_meging_indexing.append(0)
if dir_in_bin is not None: if img_bin:
img_fin = preprocess_and_resize_image_for_ocrcnn_model( img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop_bin, image_height, image_width) img_crop_bin, image_height, image_width)
cropped_lines_bin.append(img_fin) cropped_lines_bin.append(img_fin)
else: else:
splited_images, splited_images_bin = return_textlines_split_if_needed( splited_images, splited_images_bin = return_textlines_split_if_needed(
img_crop, img_crop_bin if dir_in_bin is not None else None) img_crop, img_crop_bin if img_bin else None)
if splited_images: if splited_images:
img_fin = preprocess_and_resize_image_for_ocrcnn_model( img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images[0], image_height, image_width) splited_images[0], image_height, image_width)
@ -616,7 +425,7 @@ class Eynollah_ocr:
else: else:
cropped_lines_ver_index.append(0) cropped_lines_ver_index.append(0)
if dir_in_bin is not None: if img_bin:
img_fin = preprocess_and_resize_image_for_ocrcnn_model( img_fin = preprocess_and_resize_image_for_ocrcnn_model(
splited_images_bin[0], image_height, image_width) splited_images_bin[0], image_height, image_width)
cropped_lines_bin.append(img_fin) cropped_lines_bin.append(img_fin)
@ -635,7 +444,7 @@ class Eynollah_ocr:
else: else:
cropped_lines_ver_index.append(0) cropped_lines_ver_index.append(0)
if dir_in_bin is not None: if img_bin:
img_fin = preprocess_and_resize_image_for_ocrcnn_model( img_fin = preprocess_and_resize_image_for_ocrcnn_model(
img_crop_bin, image_height, image_width) img_crop_bin, image_height, image_width)
cropped_lines_bin.append(img_fin) cropped_lines_bin.append(img_fin)
@ -648,6 +457,7 @@ class Eynollah_ocr:
n_iterations = math.ceil(len(cropped_lines) / self.b_s) n_iterations = math.ceil(len(cropped_lines) / self.b_s)
# FIXME: copy pasta
for i in range(n_iterations): for i in range(n_iterations):
if i==(n_iterations-1): if i==(n_iterations-1):
n_start = i*self.b_s n_start = i*self.b_s
@ -667,7 +477,7 @@ class Eynollah_ocr:
else: else:
imgs_ver_flipped = None imgs_ver_flipped = None
if dir_in_bin is not None: if img_bin:
imgs_bin = cropped_lines_bin[n_start:] imgs_bin = cropped_lines_bin[n_start:]
imgs_bin = np.array(imgs_bin) imgs_bin = np.array(imgs_bin)
imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3) imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3)
@ -697,7 +507,7 @@ class Eynollah_ocr:
imgs_ver_flipped = None imgs_ver_flipped = None
if dir_in_bin is not None: if img_bin:
imgs_bin = cropped_lines_bin[n_start:n_end] imgs_bin = cropped_lines_bin[n_start:n_end]
imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3) imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3)
@ -743,7 +553,8 @@ class Eynollah_ocr:
indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher]
preds[indices_to_be_replaced,:,:] = \ preds[indices_to_be_replaced,:,:] = \
preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] preds_flipped[indices_where_flipped_conf_value_is_higher, :, :]
if dir_in_bin is not None:
if img_bin:
preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0) preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0)
if len(indices_ver)>0: if len(indices_ver)>0:
@ -797,7 +608,6 @@ class Eynollah_ocr:
extracted_texts.append("") extracted_texts.append("")
extracted_conf_value.append(0) extracted_conf_value.append(0)
del cropped_lines del cropped_lines
if dir_in_bin is not None:
del cropped_lines_bin del cropped_lines_bin
gc.collect() gc.collect()
@ -808,24 +618,46 @@ class Eynollah_ocr:
else None else None
for ind in range(len(cropped_lines_meging_indexing))] for ind in range(len(cropped_lines_meging_indexing))]
extracted_conf_value_merged = [extracted_conf_value[ind] extracted_conf_value_merged = [extracted_conf_value[ind] # type: ignore
if cropped_lines_meging_indexing[ind]==0 if cropped_lines_meging_indexing[ind]==0
else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2. else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2.
if cropped_lines_meging_indexing[ind]==1 if cropped_lines_meging_indexing[ind]==1
else None else None
for ind in range(len(cropped_lines_meging_indexing))] for ind in range(len(cropped_lines_meging_indexing))]
extracted_conf_value_merged = [extracted_conf_value_merged[ind_cfm] extracted_conf_value_merged: List[float] = [extracted_conf_value_merged[ind_cfm]
for ind_cfm in range(len(extracted_texts_merged)) for ind_cfm in range(len(extracted_texts_merged))
if extracted_texts_merged[ind_cfm] is not None] if extracted_texts_merged[ind_cfm] is not None]
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
if dir_out_image_text: extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = importlib_resources.files(__package__) / "Charis-Regular.ttf" return EynollahOcrResult(
with importlib_resources.as_file(font) as font: extracted_texts_merged=extracted_texts_merged,
font = ImageFont.truetype(font=font, size=40) extracted_conf_value_merged=extracted_conf_value_merged,
cropped_lines_region_indexer=cropped_lines_region_indexer,
total_bb_coordinates=total_bb_coordinates,
)
def write_ocr(
self,
*,
result: EynollahOcrResult,
page_tree: ET.ElementTree,
out_file_ocr,
page_ns,
img,
out_image_with_text,
):
cropped_lines_region_indexer = result.cropped_lines_region_indexer
total_bb_coordinates = result.total_bb_coordinates
extracted_texts_merged = result.extracted_texts_merged
extracted_conf_value_merged = result.extracted_conf_value_merged
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text)
font = get_font()
for indexer_text, bb_ind in enumerate(total_bb_coordinates): for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0] x_bb = bb_ind[0]
@ -868,25 +700,10 @@ class Eynollah_ocr:
text_by_textregion.append(text_by_textregion_ind) text_by_textregion.append(text_by_textregion_ind)
else: else:
text_by_textregion.append(" ".join(extracted_texts_merged_un)) text_by_textregion.append(" ".join(extracted_texts_merged_un))
#print(text_by_textregion, 'text_by_textregiontext_by_textregiontext_by_textregiontext_by_textregiontext_by_textregion')
###index_tot_regions = []
###tot_region_ref = []
###for jj in root1.iter(link+'RegionRefIndexed'):
###index_tot_regions.append(jj.attrib['index'])
###tot_region_ref.append(jj.attrib['regionRef'])
###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)}
#id_textregions = []
#textregions_by_existing_ids = []
indexer = 0 indexer = 0
indexer_textregion = 0 indexer_textregion = 0
for nn in root1.iter(region_tags): for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
#id_textregion = nn.attrib['id']
#id_textregions.append(id_textregion)
#textregions_by_existing_ids.append(text_by_textregion[indexer_textregion])
is_textregion_text = False is_textregion_text = False
for childtest in nn: for childtest in nn:
@ -910,6 +727,7 @@ class Eynollah_ocr:
if not is_textline_text: if not is_textline_text:
text_subelement = ET.SubElement(child_textregion, 'TextEquiv') text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
if extracted_conf_value_merged:
text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
unicode_textline = ET.SubElement(text_subelement, 'Unicode') unicode_textline = ET.SubElement(text_subelement, 'Unicode')
unicode_textline.text = extracted_texts_merged[indexer] unicode_textline.text = extracted_texts_merged[indexer]
@ -918,8 +736,8 @@ class Eynollah_ocr:
if childtest3.tag.endswith("TextEquiv"): if childtest3.tag.endswith("TextEquiv"):
for child_uc in childtest3: for child_uc in childtest3:
if child_uc.tag.endswith("Unicode"): if child_uc.tag.endswith("Unicode"):
childtest3.set('conf', if extracted_conf_value_merged:
f"{extracted_conf_value_merged[indexer]:.2f}") childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
child_uc.text = extracted_texts_merged[indexer] child_uc.text = extracted_texts_merged[indexer]
indexer = indexer + 1 indexer = indexer + 1
@ -935,18 +753,85 @@ class Eynollah_ocr:
unicode_textregion.text = text_by_textregion[indexer_textregion] unicode_textregion.text = text_by_textregion[indexer_textregion]
indexer_textregion = indexer_textregion + 1 indexer_textregion = indexer_textregion + 1
###sample_order = [(id_to_order[tid], text) ET.register_namespace("",page_ns)
### for tid, text in zip(id_textregions, textregions_by_existing_ids) page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None)
### if tid in id_to_order]
##ordered_texts_sample = [text for _, text in sorted(sample_order)] def run(
##tot_page_text = ' '.join(ordered_texts_sample) self,
*,
overwrite: bool = False,
dir_in: Optional[str] = None,
dir_in_bin: Optional[str] = None,
image_filename: Optional[str] = None,
dir_xmls: str,
dir_out_image_text: Optional[str] = None,
dir_out: str,
):
"""
Run OCR.
##for page_element in root1.iter(link+'Page'): Args:
##text_page = ET.SubElement(page_element, 'TextEquiv')
##unicode_textpage = ET.SubElement(text_page, 'Unicode')
##unicode_textpage.text = tot_page_text
ET.register_namespace("",name_space) dir_in_bin (str): Prediction with RGB and binarized images for selected pages, should not be the default
tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) """
#print("Job done in %.1fs", time.time() - t0) if dir_in:
ls_imgs = [os.path.join(dir_in, image_filename)
for image_filename in filter(is_image_filename,
os.listdir(dir_in))]
else:
assert image_filename
ls_imgs = [image_filename]
for img_filename in ls_imgs:
file_stem = Path(img_filename).stem
page_file_in = os.path.join(dir_xmls, file_stem+'.xml')
out_file_ocr = os.path.join(dir_out, file_stem+'.xml')
if os.path.exists(out_file_ocr):
if overwrite:
self.logger.warning("will overwrite existing output file '%s'", out_file_ocr)
else:
self.logger.warning("will skip input for existing output file '%s'", out_file_ocr)
return
img = cv2.imread(img_filename)
page_tree = ET.parse(page_file_in, parser = ET.XMLParser(encoding="utf-8"))
page_ns = etree_namespace_for_element_tag(page_tree.getroot().tag)
out_image_with_text = None
if dir_out_image_text:
out_image_with_text = os.path.join(dir_out_image_text, file_stem + '.png')
img_bin = None
if dir_in_bin:
img_bin = cv2.imread(os.path.join(dir_in_bin, file_stem+'.png'))
if self.tr_ocr:
result = self.run_trocr(
img=img,
page_tree=page_tree,
page_ns=page_ns,
tr_ocr_input_height_and_width = 384
)
else:
result = self.run_cnn(
img=img,
page_tree=page_tree,
page_ns=page_ns,
img_bin=img_bin,
image_width=512,
image_height=32,
)
self.write_ocr(
result=result,
img=img,
page_tree=page_tree,
page_ns=page_ns,
out_file_ocr=out_file_ocr,
out_image_with_text=out_image_with_text,
)

View file

@ -0,0 +1,16 @@
# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files
import sys
from PIL import ImageFont
if sys.version_info < (3, 10):
import importlib_resources
else:
import importlib.resources as importlib_resources
def get_font():
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
with importlib_resources.as_file(font) as font:
return ImageFont.truetype(font=font, size=40)

View file

@ -128,6 +128,7 @@ def return_textlines_split_if_needed(textline_image, textline_image_bin=None):
return [image1, image2], None return [image1, image2], None
else: else:
return None, None return None, None
def preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width): def preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width):
if img.shape[0]==0 or img.shape[1]==0: if img.shape[0]==0 or img.shape[1]==0:
img_fin = np.ones((image_height, image_width, 3)) img_fin = np.ones((image_height, image_width, 3))

View file

@ -88,3 +88,7 @@ def order_and_id_of_texts(found_polygons_text_region, found_polygons_text_region
order_of_texts.append(interest) order_of_texts.append(interest)
return order_of_texts, id_of_texts return order_of_texts, id_of_texts
def etree_namespace_for_element_tag(tag: str):
right = tag.find('}')
return tag[1:right]