diff --git a/src/eynollah/cli/cli_ocr.py b/src/eynollah/cli/cli_ocr.py index 9bb8620..eb94dcc 100644 --- a/src/eynollah/cli/cli_ocr.py +++ b/src/eynollah/cli/cli_ocr.py @@ -88,7 +88,6 @@ def ocr_cli( tr_ocr, do_not_mask_with_textline_contour, batch_size, - dataset_abbrevation, min_conf_value_of_textline_text, ): """ @@ -101,7 +100,6 @@ def ocr_cli( tr_ocr=tr_ocr, do_not_mask_with_textline_contour=do_not_mask_with_textline_contour, batch_size=batch_size, - pref_of_dataset=dataset_abbrevation, min_conf_value_of_textline_text=min_conf_value_of_textline_text) eynollah_ocr.run(overwrite=overwrite, dir_in=dir_in, diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 52dcca9..3c918e5 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -1,24 +1,22 @@ # FIXME: fix all of those... -# pyright: reportPossiblyUnboundVariable=false -# pyright: reportOptionalMemberAccess=false -# pyright: reportArgumentType=false -# pyright: reportCallIssue=false # pyright: reportOptionalSubscript=false from logging import Logger, getLogger -from typing import Optional +from typing import List, Optional from pathlib import Path import os import gc -import sys import math -import time +from dataclasses import dataclass import cv2 -import xml.etree.ElementTree as ET -from PIL import Image, ImageDraw, ImageFont +from cv2.typing import MatLike +from xml.etree import ElementTree as ET +from PIL import Image, ImageDraw import numpy as np from eynollah.model_zoo import EynollahModelZoo +from eynollah.utils.font import get_font +from eynollah.utils.xml import etree_namespace_for_element_tag try: import torch except ImportError: @@ -38,11 +36,13 @@ from .utils.utils_ocr import ( rotate_image_with_padding, ) -# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files -if sys.version_info < (3, 10): - import importlib_resources -else: - import importlib.resources as importlib_resources +# TODO: refine typing +@dataclass +class EynollahOcrResult: + extracted_texts_merged: List + extracted_conf_value_merged: Optional[List] + cropped_lines_region_indexer: List + total_bb_coordinates:List class Eynollah_ocr: def __init__( @@ -76,6 +76,7 @@ class Eynollah_ocr: @property def device(self): + assert torch if torch.cuda.is_available(): self.logger.info("Using GPU acceleration") return torch.device("cuda:0") @@ -83,870 +84,754 @@ class Eynollah_ocr: self.logger.info("Using CPU processing") return torch.device("cpu") - def run(self, overwrite: bool = False, - dir_in: Optional[str] = None, - # Prediction with RGB and binarized images for selected pages, should not be the default - dir_in_bin: Optional[str] = None, - image_filename: Optional[str] = None, - dir_xmls: Optional[str] = None, - dir_out_image_text: Optional[str] = None, - dir_out: Optional[str] = None, + def run_trocr( + self, + *, + img: MatLike, + page_tree: ET.ElementTree, + page_ns, + tr_ocr_input_height_and_width, + ) -> EynollahOcrResult: + + total_bb_coordinates = [] + + + cropped_lines = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + extracted_texts = [] + + indexer_text_region = 0 + indexer_b_s = 0 + + for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( 
[ [int(x.split(',')[0]), + int(x.split(',')[1]) ] + for x in p_h] ) + x,y,w,h = cv2.boundingRect(textline_coords) + + total_bb_coordinates.append([x,y,w,h]) + + h2w_ratio = h/float(w) + + img_poly_on_img = np.copy(img) + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + img_crop[mask_poly==0] = 255 + + self.logger.debug("processing %d lines for '%s'", + len(cropped_lines), nn.attrib['id']) + if h2w_ratio > 0.1: + cropped_lines.append(resize_image(img_crop, + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width) ) + cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + else: + splited_images, _ = return_textlines_split_if_needed(img_crop, None) + #print(splited_images) + if splited_images: + cropped_lines.append(resize_image(splited_images[0], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + + cropped_lines.append(resize_image(splited_images[1], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(-1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + else: + cropped_lines.append(img_crop) + cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + + + indexer_text_region = indexer_text_region +1 + + if indexer_b_s!=0: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + 
generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + ####extracted_texts = [] + ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) + + ####for i in range(n_iterations): + ####if i==(n_iterations-1): + ####n_start = i*self.b_s + ####imgs = cropped_lines[n_start:] + ####else: + ####n_start = i*self.b_s + ####n_end = (i+1)*self.b_s + ####imgs = cropped_lines[n_start:n_end] + ####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + ####generated_ids_merged = self.model_ocr.generate( + #### pixel_values_merged.to(self.device)) + ####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + #### generated_ids_merged, skip_special_tokens=True) + + ####extracted_texts = extracted_texts + generated_text_merged + + del cropped_lines + gc.collect() + + extracted_texts_merged = [extracted_texts[ind] + if cropped_lines_meging_indexing[ind]==0 + else extracted_texts[ind]+" "+extracted_texts[ind+1] + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] + #print(extracted_texts_merged, len(extracted_texts_merged)) + + return EynollahOcrResult( + extracted_texts_merged=extracted_texts_merged, + extracted_conf_value_merged=None, + cropped_lines_region_indexer=cropped_lines_region_indexer, + total_bb_coordinates=total_bb_coordinates, + ) + + def run_cnn( + self, + *, + img: MatLike, + img_bin: Optional[MatLike], + page_tree: ET.ElementTree, + page_ns, + image_width, + image_height, + ) -> EynollahOcrResult: + + total_bb_coordinates = [] + + cropped_lines = [] + img_crop_bin = None + imgs_bin = None + imgs_bin_ver_flipped = None + cropped_lines_bin = [] + cropped_lines_ver_index = [] + cropped_lines_region_indexer = [] + cropped_lines_meging_indexing = [] + + indexer_text_region = 0 + for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): + try: + type_textregion = nn.attrib['type'] + except: + type_textregion = 'paragraph' + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + for child_textlines in child_textregion: + if child_textlines.tag.endswith("Coords"): + cropped_lines_region_indexer.append(indexer_text_region) + p_h=child_textlines.attrib['points'].split(' ') + textline_coords = np.array( [ [int(x.split(',')[0]), + int(x.split(',')[1]) ] + for x in p_h] ) + + x,y,w,h = cv2.boundingRect(textline_coords) + + angle_radians = math.atan2(h, w) + # Convert to degrees + angle_degrees = math.degrees(angle_radians) + if type_textregion=='drop-capital': + angle_degrees = 0 + + total_bb_coordinates.append([x,y,w,h]) + + w_scaled = w * image_height/float(h) + + img_poly_on_img = np.copy(img) + if img_bin: + img_poly_on_img_bin = np.copy(img_bin) + img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :] + + mask_poly = np.zeros(img.shape) + mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) + + + mask_poly = mask_poly[y:y+h, x:x+w, :] + img_crop = img_poly_on_img[y:y+h, x:x+w, :] + + # print(file_name, angle_degrees, w*h, + # mask_poly[:,:,0].sum(), + # mask_poly[:,:,0].sum() /float(w*h) , + # 'didi') + + if angle_degrees > 3: + better_des_slope = get_orientation_moments(textline_coords) + + img_crop = 
rotate_image_with_padding(img_crop, better_des_slope) + if img_bin: + img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope) + + mask_poly = rotate_image_with_padding(mask_poly, better_des_slope) + mask_poly = mask_poly.astype('uint8') + + #new bounding box + x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0]) + + mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :] + img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :] + + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + if img_bin: + img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] + if not self.do_not_mask_with_textline_contour: + img_crop_bin[mask_poly==0] = 255 + + if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: + if img_bin: + img_crop, img_crop_bin = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly, img_crop_bin) + else: + img_crop, _ = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly) + + else: + better_des_slope = 0 + if not self.do_not_mask_with_textline_contour: + img_crop[mask_poly==0] = 255 + if img_bin: + if not self.do_not_mask_with_textline_contour: + img_crop_bin[mask_poly==0] = 255 + if type_textregion=='drop-capital': + pass + else: + if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90: + if img_bin: + img_crop, img_crop_bin = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly, img_crop_bin) + else: + img_crop, _ = \ + break_curved_line_into_small_pieces_and_then_merge( + img_crop, mask_poly) + + if w_scaled < 750:#1.5*image_width: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop, image_height, image_width) + cropped_lines.append(img_fin) + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + cropped_lines_meging_indexing.append(0) + if img_bin: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop_bin, image_height, image_width) + cropped_lines_bin.append(img_fin) + else: + splited_images, splited_images_bin = return_textlines_split_if_needed( + img_crop, img_crop_bin if img_bin else None) + if splited_images: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images[0], image_height, image_width) + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(1) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images[1], image_height, image_width) + + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(-1) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + if img_bin: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images_bin[0], image_height, image_width) + cropped_lines_bin.append(img_fin) + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + splited_images_bin[1], image_height, image_width) + cropped_lines_bin.append(img_fin) + + else: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop, image_height, image_width) + cropped_lines.append(img_fin) + cropped_lines_meging_indexing.append(0) + + if abs(better_des_slope) > 45: + cropped_lines_ver_index.append(1) + else: + cropped_lines_ver_index.append(0) + + if img_bin: + img_fin = preprocess_and_resize_image_for_ocrcnn_model( + img_crop_bin, image_height, image_width) + cropped_lines_bin.append(img_fin) + + + 
indexer_text_region = indexer_text_region +1 + + extracted_texts = [] + extracted_conf_value = [] + + n_iterations = math.ceil(len(cropped_lines) / self.b_s) + + # FIXME: copy pasta + for i in range(n_iterations): + if i==(n_iterations-1): + n_start = i*self.b_s + imgs = cropped_lines[n_start:] + imgs = np.array(imgs) + imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3) + + ver_imgs = np.array( cropped_lines_ver_index[n_start:] ) + indices_ver = np.where(ver_imgs == 1)[0] + + #print(indices_ver, 'indices_ver') + if len(indices_ver)>0: + imgs_ver_flipped = imgs[indices_ver, : ,: ,:] + imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + + else: + imgs_ver_flipped = None + + if img_bin: + imgs_bin = cropped_lines_bin[n_start:] + imgs_bin = np.array(imgs_bin) + imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3) + + if len(indices_ver)>0: + imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] + imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + + else: + imgs_bin_ver_flipped = None + else: + n_start = i*self.b_s + n_end = (i+1)*self.b_s + imgs = cropped_lines[n_start:n_end] + imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3) + + ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] ) + indices_ver = np.where(ver_imgs == 1)[0] + #print(indices_ver, 'indices_ver') + + if len(indices_ver)>0: + imgs_ver_flipped = imgs[indices_ver, : ,: ,:] + imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + else: + imgs_ver_flipped = None + + + if img_bin: + imgs_bin = cropped_lines_bin[n_start:n_end] + imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3) + + + if len(indices_ver)>0: + imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] + imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] + #print(imgs_ver_flipped, 'imgs_ver_flipped') + else: + imgs_bin_ver_flipped = None + + + self.logger.debug("processing next %d lines", len(imgs)) + preds = self.model_zoo.get('ocr').predict(imgs, verbose=0) + + if len(indices_ver)>0: + preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0) + preds_max_fliped = np.max(preds_flipped, axis=2 ) + preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) + pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character + masked_means_flipped = \ + np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) + masked_means_flipped[np.isnan(masked_means_flipped)] = 0 + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + masked_means[np.isnan(masked_means)] = 0 + + masked_means_ver = masked_means[indices_ver] + #print(masked_means_ver, 'pred_max_not_unk') + + indices_where_flipped_conf_value_is_higher = \ + np.where(masked_means_flipped > masked_means_ver)[0] + + #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') + if len(indices_where_flipped_conf_value_is_higher)>0: + indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] + preds[indices_to_be_replaced,:,:] = \ + preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] 
+ + if img_bin: + preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0) + + if len(indices_ver)>0: + preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0) + preds_max_fliped = np.max(preds_flipped, axis=2 ) + preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) + pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character + masked_means_flipped = \ + np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) + masked_means_flipped[np.isnan(masked_means_flipped)] = 0 + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + masked_means[np.isnan(masked_means)] = 0 + + masked_means_ver = masked_means[indices_ver] + #print(masked_means_ver, 'pred_max_not_unk') + + indices_where_flipped_conf_value_is_higher = \ + np.where(masked_means_flipped > masked_means_ver)[0] + + #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') + if len(indices_where_flipped_conf_value_is_higher)>0: + indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] + preds_bin[indices_to_be_replaced,:,:] = \ + preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] + + preds = (preds + preds_bin) / 2. + + pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char')) + + preds_max = np.max(preds, axis=2 ) + preds_max_args = np.argmax(preds, axis=2 ) + pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character + masked_means = \ + np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ + np.sum(pred_max_not_unk_mask_bool, axis=1) + + for ib in range(imgs.shape[0]): + pred_texts_ib = pred_texts[ib].replace("[UNK]", "") + if masked_means[ib] >= self.min_conf_value_of_textline_text: + extracted_texts.append(pred_texts_ib) + extracted_conf_value.append(masked_means[ib]) + else: + extracted_texts.append("") + extracted_conf_value.append(0) + del cropped_lines + del cropped_lines_bin + gc.collect() + + extracted_texts_merged = [extracted_texts[ind] + if cropped_lines_meging_indexing[ind]==0 + else extracted_texts[ind]+" "+extracted_texts[ind+1] + if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_conf_value_merged = [extracted_conf_value[ind] # type: ignore + if cropped_lines_meging_indexing[ind]==0 + else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2. 
+ if cropped_lines_meging_indexing[ind]==1 + else None + for ind in range(len(cropped_lines_meging_indexing))] + + extracted_conf_value_merged: List[float] = [extracted_conf_value_merged[ind_cfm] + for ind_cfm in range(len(extracted_texts_merged)) + if extracted_texts_merged[ind_cfm] is not None] + + extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] + + return EynollahOcrResult( + extracted_texts_merged=extracted_texts_merged, + extracted_conf_value_merged=extracted_conf_value_merged, + cropped_lines_region_indexer=cropped_lines_region_indexer, + total_bb_coordinates=total_bb_coordinates, + ) + + def write_ocr( + self, + *, + result: EynollahOcrResult, + page_tree: ET.ElementTree, + out_file_ocr, + page_ns, + img, + out_image_with_text, ): + cropped_lines_region_indexer = result.cropped_lines_region_indexer + total_bb_coordinates = result.total_bb_coordinates + extracted_texts_merged = result.extracted_texts_merged + extracted_conf_value_merged = result.extracted_conf_value_merged + + unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) + if out_image_with_text: + image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") + draw = ImageDraw.Draw(image_text) + font = get_font() + + for indexer_text, bb_ind in enumerate(total_bb_coordinates): + x_bb = bb_ind[0] + y_bb = bb_ind[1] + w_bb = bb_ind[2] + h_bb = bb_ind[3] + + font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], + font.path, w_bb, int(h_bb*0.4) ) + + ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) + + text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + + text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally + text_y = y_bb + (h_bb - text_height) // 2 # Center vertically + + # Draw the text + draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) + image_text.save(out_image_with_text) + + text_by_textregion = [] + for ind in unique_cropped_lines_region_indexer: + ind = np.array(cropped_lines_region_indexer)==ind + extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] + if len(extracted_texts_merged_un)>1: + text_by_textregion_ind = "" + next_glue = "" + for indt in range(len(extracted_texts_merged_un)): + if (extracted_texts_merged_un[indt].endswith('⸗') or + extracted_texts_merged_un[indt].endswith('-') or + extracted_texts_merged_un[indt].endswith('¬')): + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] + next_glue = "" + else: + text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] + next_glue = " " + text_by_textregion.append(text_by_textregion_ind) + else: + text_by_textregion.append(" ".join(extracted_texts_merged_un)) + + indexer = 0 + indexer_textregion = 0 + for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): + + is_textregion_text = False + for childtest in nn: + if childtest.tag.endswith("TextEquiv"): + is_textregion_text = True + + if not is_textregion_text: + text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') + unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') + + + has_textline = False + for child_textregion in nn: + if child_textregion.tag.endswith("TextLine"): + + is_textline_text = False + for childtest2 in child_textregion: + if childtest2.tag.endswith("TextEquiv"): + is_textline_text = True + + + if not is_textline_text: + 
text_subelement = ET.SubElement(child_textregion, 'TextEquiv') + if extracted_conf_value_merged: + text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + unicode_textline = ET.SubElement(text_subelement, 'Unicode') + unicode_textline.text = extracted_texts_merged[indexer] + else: + for childtest3 in child_textregion: + if childtest3.tag.endswith("TextEquiv"): + for child_uc in childtest3: + if child_uc.tag.endswith("Unicode"): + if extracted_conf_value_merged: + childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + child_uc.text = extracted_texts_merged[indexer] + + indexer = indexer + 1 + has_textline = True + if has_textline: + if is_textregion_text: + for child4 in nn: + if child4.tag.endswith("TextEquiv"): + for childtr_uc in child4: + if childtr_uc.tag.endswith("Unicode"): + childtr_uc.text = text_by_textregion[indexer_textregion] + else: + unicode_textregion.text = text_by_textregion[indexer_textregion] + indexer_textregion = indexer_textregion + 1 + + ET.register_namespace("",page_ns) + page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None) + + def run( + self, + *, + overwrite: bool = False, + dir_in: Optional[str] = None, + dir_in_bin: Optional[str] = None, + image_filename: Optional[str] = None, + dir_xmls: str, + dir_out_image_text: Optional[str] = None, + dir_out: str, + ): + """ + Run OCR. + + Args: + + dir_in_bin (str): Prediction with RGB and binarized images for selected pages, should not be the default + """ if dir_in: ls_imgs = [os.path.join(dir_in, image_filename) - for image_filename in filter(is_image_filename, + for image_filename in filter(is_image_filename, os.listdir(dir_in))] else: assert image_filename ls_imgs = [image_filename] - if self.tr_ocr: - tr_ocr_input_height_and_width = 384 - for dir_img in ls_imgs: - file_name = Path(dir_img).stem - assert dir_xmls # FIXME: check the logic - dir_xml = os.path.join(dir_xmls, file_name+'.xml') - assert dir_out # FIXME: check the logic - out_file_ocr = os.path.join(dir_out, file_name+'.xml') - - if os.path.exists(out_file_ocr): - if overwrite: - self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) - else: - self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) - continue - - img = cv2.imread(dir_img) - - if dir_out_image_text: - out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') - image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") - draw = ImageDraw.Draw(image_text) - total_bb_coordinates = [] - - ##file_name = Path(dir_xmls).stem - tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' - - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] - - region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) - - - - cropped_lines = [] - cropped_lines_region_indexer = [] - cropped_lines_meging_indexing = [] - - extracted_texts = [] - - indexer_text_region = 0 - indexer_b_s = 0 - - for nn in root1.iter(region_tags): - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - 
x,y,w,h = cv2.boundingRect(textline_coords) - - if dir_out_image_text: - total_bb_coordinates.append([x,y,w,h]) - - h2w_ratio = h/float(w) - - img_poly_on_img = np.copy(img) - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - img_crop[mask_poly==0] = 255 - - self.logger.debug("processing %d lines for '%s'", - len(cropped_lines), nn.attrib['id']) - if h2w_ratio > 0.1: - cropped_lines.append(resize_image(img_crop, - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width) ) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - splited_images, _ = return_textlines_split_if_needed(img_crop, None) - #print(splited_images) - if splited_images: - cropped_lines.append(resize_image(splited_images[0], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - cropped_lines.append(resize_image(splited_images[1], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(-1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - cropped_lines.append(img_crop) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - - indexer_text_region = indexer_text_region +1 - - if indexer_b_s!=0: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = 
self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - ####extracted_texts = [] - ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) - - ####for i in range(n_iterations): - ####if i==(n_iterations-1): - ####n_start = i*self.b_s - ####imgs = cropped_lines[n_start:] - ####else: - ####n_start = i*self.b_s - ####n_end = (i+1)*self.b_s - ####imgs = cropped_lines[n_start:n_end] - ####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - ####generated_ids_merged = self.model_ocr.generate( - #### pixel_values_merged.to(self.device)) - ####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - #### generated_ids_merged, skip_special_tokens=True) - - ####extracted_texts = extracted_texts + generated_text_merged - - del cropped_lines - gc.collect() - - extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - #print(extracted_texts_merged, len(extracted_texts_merged)) - - unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) - - if dir_out_image_text: - - #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! - font = importlib_resources.files(__package__) / "Charis-Regular.ttf" - with importlib_resources.as_file(font) as font: - font = ImageFont.truetype(font=font, size=40) - - for indexer_text, bb_ind in enumerate(total_bb_coordinates): - - - x_bb = bb_ind[0] - y_bb = bb_ind[1] - w_bb = bb_ind[2] - h_bb = bb_ind[3] - - font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], - font.path, w_bb, int(h_bb*0.4) ) - - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] - - text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally - text_y = y_bb + (h_bb - text_height) // 2 # Center vertically - - # Draw the text - draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) - image_text.save(out_image_with_text) - - #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer') - #######text_by_textregion = [] - #######for ind in unique_cropped_lines_region_indexer: - #######ind = np.array(cropped_lines_region_indexer)==ind - #######extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - #######text_by_textregion.append(" ".join(extracted_texts_merged_un)) - - text_by_textregion = [] - for ind in unique_cropped_lines_region_indexer: - ind = np.array(cropped_lines_region_indexer) == ind - extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - if len(extracted_texts_merged_un)>1: - text_by_textregion_ind = "" - next_glue = "" - for indt in range(len(extracted_texts_merged_un)): - if (extracted_texts_merged_un[indt].endswith('⸗') or - extracted_texts_merged_un[indt].endswith('-') or - extracted_texts_merged_un[indt].endswith('¬')): - text_by_textregion_ind += next_glue + 
extracted_texts_merged_un[indt][:-1] - next_glue = "" - else: - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] - next_glue = " " - text_by_textregion.append(text_by_textregion_ind) - else: - text_by_textregion.append(" ".join(extracted_texts_merged_un)) - - - indexer = 0 - indexer_textregion = 0 - for nn in root1.iter(region_tags): - #id_textregion = nn.attrib['id'] - #id_textregions.append(id_textregion) - #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) - - is_textregion_text = False - for childtest in nn: - if childtest.tag.endswith("TextEquiv"): - is_textregion_text = True - - if not is_textregion_text: - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') - - - has_textline = False - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - is_textline_text = False - for childtest2 in child_textregion: - if childtest2.tag.endswith("TextEquiv"): - is_textline_text = True - - - if not is_textline_text: - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - ##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] - else: - for childtest3 in child_textregion: - if childtest3.tag.endswith("TextEquiv"): - for child_uc in childtest3: - if child_uc.tag.endswith("Unicode"): - ##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - child_uc.text = extracted_texts_merged[indexer] - - indexer = indexer + 1 - has_textline = True - if has_textline: - if is_textregion_text: - for child4 in nn: - if child4.tag.endswith("TextEquiv"): - for childtr_uc in child4: - if childtr_uc.tag.endswith("Unicode"): - childtr_uc.text = text_by_textregion[indexer_textregion] - else: - unicode_textregion.text = text_by_textregion[indexer_textregion] - indexer_textregion = indexer_textregion + 1 - - ###sample_order = [(id_to_order[tid], text) - ### for tid, text in zip(id_textregions, textregions_by_existing_ids) - ### if tid in id_to_order] - - ##ordered_texts_sample = [text for _, text in sorted(sample_order)] - ##tot_page_text = ' '.join(ordered_texts_sample) - - ##for page_element in root1.iter(link+'Page'): - ##text_page = ET.SubElement(page_element, 'TextEquiv') - ##unicode_textpage = ET.SubElement(text_page, 'Unicode') - ##unicode_textpage.text = tot_page_text - - ET.register_namespace("",name_space) - tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) - else: - ###max_len = 280#512#280#512 - ###padding_token = 1500#299#1500#299 - image_width = 512#max_len * 4 - image_height = 32 - - - img_size=(image_width, image_height) + for img_filename in ls_imgs: + file_stem = Path(img_filename).stem + page_file_in = os.path.join(dir_xmls, file_stem+'.xml') + out_file_ocr = os.path.join(dir_out, file_stem+'.xml') - for dir_img in ls_imgs: - file_name = Path(dir_img).stem - dir_xml = os.path.join(dir_xmls, file_name+'.xml') - out_file_ocr = os.path.join(dir_out, file_name+'.xml') + if os.path.exists(out_file_ocr): + if overwrite: + self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) + else: + self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) + return - if os.path.exists(out_file_ocr): - if overwrite: - self.logger.warning("will overwrite existing output file '%s'", out_file_ocr) - else: - 
self.logger.warning("will skip input for existing output file '%s'", out_file_ocr) - continue - - img = cv2.imread(dir_img) - if dir_in_bin is not None: - cropped_lines_bin = [] - img_bin = cv2.imread(os.path.join(dir_in_bin, file_name+'.png')) - - if dir_out_image_text: - out_image_with_text = os.path.join(dir_out_image_text, file_name+'.png') - image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") - draw = ImageDraw.Draw(image_text) - total_bb_coordinates = [] + img = cv2.imread(img_filename) - tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8")) - root1=tree1.getroot() - alltags=[elem.tag for elem in root1.iter()] - link=alltags[0].split('}')[0]+'}' + page_tree = ET.parse(page_file_in, parser = ET.XMLParser(encoding="utf-8")) + page_ns = etree_namespace_for_element_tag(page_tree.getroot().tag) - name_space = alltags[0].split('}')[0] - name_space = name_space.split('{')[1] + out_image_with_text = None + if dir_out_image_text: + out_image_with_text = os.path.join(dir_out_image_text, file_stem + '.png') - region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')]) - - cropped_lines = [] - cropped_lines_ver_index = [] - cropped_lines_region_indexer = [] - cropped_lines_meging_indexing = [] - - tinl = time.time() - indexer_text_region = 0 - indexer_textlines = 0 - for nn in root1.iter(region_tags): - try: - type_textregion = nn.attrib['type'] - except: - type_textregion = 'paragraph' - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - - x,y,w,h = cv2.boundingRect(textline_coords) - - angle_radians = math.atan2(h, w) - # Convert to degrees - angle_degrees = math.degrees(angle_radians) - if type_textregion=='drop-capital': - angle_degrees = 0 - - if dir_out_image_text: - total_bb_coordinates.append([x,y,w,h]) - - w_scaled = w * image_height/float(h) - - img_poly_on_img = np.copy(img) - if dir_in_bin is not None: - img_poly_on_img_bin = np.copy(img_bin) - img_crop_bin = img_poly_on_img_bin[y:y+h, x:x+w, :] - - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - - # print(file_name, angle_degrees, w*h, - # mask_poly[:,:,0].sum(), - # mask_poly[:,:,0].sum() /float(w*h) , - # 'didi') - - if angle_degrees > 3: - better_des_slope = get_orientation_moments(textline_coords) - - img_crop = rotate_image_with_padding(img_crop, better_des_slope) - if dir_in_bin is not None: - img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope) - - mask_poly = rotate_image_with_padding(mask_poly, better_des_slope) - mask_poly = mask_poly.astype('uint8') - - #new bounding box - x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0]) - - mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :] - img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :] - - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - if dir_in_bin is not None: - img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] - if not self.do_not_mask_with_textline_contour: - img_crop_bin[mask_poly==0] = 255 - - if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: - if dir_in_bin is not 
None: - img_crop, img_crop_bin = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly, img_crop_bin) - else: - img_crop, _ = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly) - - else: - better_des_slope = 0 - if not self.do_not_mask_with_textline_contour: - img_crop[mask_poly==0] = 255 - if dir_in_bin is not None: - if not self.do_not_mask_with_textline_contour: - img_crop_bin[mask_poly==0] = 255 - if type_textregion=='drop-capital': - pass - else: - if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90: - if dir_in_bin is not None: - img_crop, img_crop_bin = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly, img_crop_bin) - else: - img_crop, _ = \ - break_curved_line_into_small_pieces_and_then_merge( - img_crop, mask_poly) - - if w_scaled < 750:#1.5*image_width: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop, image_height, image_width) - cropped_lines.append(img_fin) - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - cropped_lines_meging_indexing.append(0) - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop_bin, image_height, image_width) - cropped_lines_bin.append(img_fin) - else: - splited_images, splited_images_bin = return_textlines_split_if_needed( - img_crop, img_crop_bin if dir_in_bin is not None else None) - if splited_images: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images[0], image_height, image_width) - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(1) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images[1], image_height, image_width) - - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(-1) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images_bin[0], image_height, image_width) - cropped_lines_bin.append(img_fin) - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - splited_images_bin[1], image_height, image_width) - cropped_lines_bin.append(img_fin) - - else: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop, image_height, image_width) - cropped_lines.append(img_fin) - cropped_lines_meging_indexing.append(0) - - if abs(better_des_slope) > 45: - cropped_lines_ver_index.append(1) - else: - cropped_lines_ver_index.append(0) - - if dir_in_bin is not None: - img_fin = preprocess_and_resize_image_for_ocrcnn_model( - img_crop_bin, image_height, image_width) - cropped_lines_bin.append(img_fin) - + img_bin = None + if dir_in_bin: + img_bin = cv2.imread(os.path.join(dir_in_bin, file_stem+'.png')) - indexer_text_region = indexer_text_region +1 - - extracted_texts = [] - extracted_conf_value = [] - n_iterations = math.ceil(len(cropped_lines) / self.b_s) + if self.tr_ocr: + result = self.run_trocr( + img=img, + page_tree=page_tree, + page_ns=page_ns, - for i in range(n_iterations): - if i==(n_iterations-1): - n_start = i*self.b_s - imgs = cropped_lines[n_start:] - imgs = np.array(imgs) - imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3) - - ver_imgs = np.array( cropped_lines_ver_index[n_start:] ) - indices_ver = np.where(ver_imgs == 1)[0] - - 
#print(indices_ver, 'indices_ver') - if len(indices_ver)>0: - imgs_ver_flipped = imgs[indices_ver, : ,: ,:] - imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - - else: - imgs_ver_flipped = None - - if dir_in_bin is not None: - imgs_bin = cropped_lines_bin[n_start:] - imgs_bin = np.array(imgs_bin) - imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3) - - if len(indices_ver)>0: - imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] - imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - - else: - imgs_bin_ver_flipped = None - else: - n_start = i*self.b_s - n_end = (i+1)*self.b_s - imgs = cropped_lines[n_start:n_end] - imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3) - - ver_imgs = np.array( cropped_lines_ver_index[n_start:n_end] ) - indices_ver = np.where(ver_imgs == 1)[0] - #print(indices_ver, 'indices_ver') - - if len(indices_ver)>0: - imgs_ver_flipped = imgs[indices_ver, : ,: ,:] - imgs_ver_flipped = imgs_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - else: - imgs_ver_flipped = None + tr_ocr_input_height_and_width = 384 + ) + else: + result = self.run_cnn( + img=img, + page_tree=page_tree, + page_ns=page_ns, - - if dir_in_bin is not None: - imgs_bin = cropped_lines_bin[n_start:n_end] - imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3) - - - if len(indices_ver)>0: - imgs_bin_ver_flipped = imgs_bin[indices_ver, : ,: ,:] - imgs_bin_ver_flipped = imgs_bin_ver_flipped[:,::-1,::-1,:] - #print(imgs_ver_flipped, 'imgs_ver_flipped') - else: - imgs_bin_ver_flipped = None - + img_bin=img_bin, + image_width=512, + image_height=32, + ) - self.logger.debug("processing next %d lines", len(imgs)) - preds = self.model_zoo.get('ocr').predict(imgs, verbose=0) - - if len(indices_ver)>0: - preds_flipped = self.model_zoo.get('ocr').predict(imgs_ver_flipped, verbose=0) - preds_max_fliped = np.max(preds_flipped, axis=2 ) - preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) - pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character - masked_means_flipped = \ - np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) - masked_means_flipped[np.isnan(masked_means_flipped)] = 0 - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - masked_means[np.isnan(masked_means)] = 0 - - masked_means_ver = masked_means[indices_ver] - #print(masked_means_ver, 'pred_max_not_unk') - - indices_where_flipped_conf_value_is_higher = \ - np.where(masked_means_flipped > masked_means_ver)[0] - - #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') - if len(indices_where_flipped_conf_value_is_higher)>0: - indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] - preds[indices_to_be_replaced,:,:] = \ - preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] - if dir_in_bin is not None: - preds_bin = self.model_zoo.get('ocr').predict(imgs_bin, verbose=0) - - if len(indices_ver)>0: - preds_flipped = self.model_zoo.get('ocr').predict(imgs_bin_ver_flipped, verbose=0) - preds_max_fliped = np.max(preds_flipped, axis=2 ) - 
preds_max_args_flipped = np.argmax(preds_flipped, axis=2 ) - pred_max_not_unk_mask_bool_flipped = preds_max_args_flipped[:,:]!=self.end_character - masked_means_flipped = \ - np.sum(preds_max_fliped * pred_max_not_unk_mask_bool_flipped, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool_flipped, axis=1) - masked_means_flipped[np.isnan(masked_means_flipped)] = 0 - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - masked_means[np.isnan(masked_means)] = 0 - - masked_means_ver = masked_means[indices_ver] - #print(masked_means_ver, 'pred_max_not_unk') - - indices_where_flipped_conf_value_is_higher = \ - np.where(masked_means_flipped > masked_means_ver)[0] - - #print(indices_where_flipped_conf_value_is_higher, 'indices_where_flipped_conf_value_is_higher') - if len(indices_where_flipped_conf_value_is_higher)>0: - indices_to_be_replaced = indices_ver[indices_where_flipped_conf_value_is_higher] - preds_bin[indices_to_be_replaced,:,:] = \ - preds_flipped[indices_where_flipped_conf_value_is_higher, :, :] - - preds = (preds + preds_bin) / 2. - - pred_texts = decode_batch_predictions(preds, self.model_zoo.get('num_to_char')) - - preds_max = np.max(preds, axis=2 ) - preds_max_args = np.argmax(preds, axis=2 ) - pred_max_not_unk_mask_bool = preds_max_args[:,:]!=self.end_character - masked_means = \ - np.sum(preds_max * pred_max_not_unk_mask_bool, axis=1) / \ - np.sum(pred_max_not_unk_mask_bool, axis=1) - - for ib in range(imgs.shape[0]): - pred_texts_ib = pred_texts[ib].replace("[UNK]", "") - if masked_means[ib] >= self.min_conf_value_of_textline_text: - extracted_texts.append(pred_texts_ib) - extracted_conf_value.append(masked_means[ib]) - else: - extracted_texts.append("") - extracted_conf_value.append(0) - del cropped_lines - if dir_in_bin is not None: - del cropped_lines_bin - gc.collect() - - extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_conf_value_merged = [extracted_conf_value[ind] - if cropped_lines_meging_indexing[ind]==0 - else (extracted_conf_value[ind]+extracted_conf_value[ind+1])/2. - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_conf_value_merged = [extracted_conf_value_merged[ind_cfm] - for ind_cfm in range(len(extracted_texts_merged)) - if extracted_texts_merged[ind_cfm] is not None] - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) - - if dir_out_image_text: - #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! 
- font = importlib_resources.files(__package__) / "Charis-Regular.ttf" - with importlib_resources.as_file(font) as font: - font = ImageFont.truetype(font=font, size=40) - - for indexer_text, bb_ind in enumerate(total_bb_coordinates): - x_bb = bb_ind[0] - y_bb = bb_ind[1] - w_bb = bb_ind[2] - h_bb = bb_ind[3] - - font = fit_text_single_line(draw, extracted_texts_merged[indexer_text], - font.path, w_bb, int(h_bb*0.4) ) - - ##draw.rectangle([x_bb, y_bb, x_bb + w_bb, y_bb + h_bb], outline="red", width=2) - - text_bbox = draw.textbbox((0, 0), extracted_texts_merged[indexer_text], font=font) - text_width = text_bbox[2] - text_bbox[0] - text_height = text_bbox[3] - text_bbox[1] - - text_x = x_bb + (w_bb - text_width) // 2 # Center horizontally - text_y = y_bb + (h_bb - text_height) // 2 # Center vertically - - # Draw the text - draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font) - image_text.save(out_image_with_text) - - text_by_textregion = [] - for ind in unique_cropped_lines_region_indexer: - ind = np.array(cropped_lines_region_indexer)==ind - extracted_texts_merged_un = np.array(extracted_texts_merged)[ind] - if len(extracted_texts_merged_un)>1: - text_by_textregion_ind = "" - next_glue = "" - for indt in range(len(extracted_texts_merged_un)): - if (extracted_texts_merged_un[indt].endswith('⸗') or - extracted_texts_merged_un[indt].endswith('-') or - extracted_texts_merged_un[indt].endswith('¬')): - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1] - next_glue = "" - else: - text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt] - next_glue = " " - text_by_textregion.append(text_by_textregion_ind) - else: - text_by_textregion.append(" ".join(extracted_texts_merged_un)) - #print(text_by_textregion, 'text_by_textregiontext_by_textregiontext_by_textregiontext_by_textregiontext_by_textregion') - - ###index_tot_regions = [] - ###tot_region_ref = [] - - ###for jj in root1.iter(link+'RegionRefIndexed'): - ###index_tot_regions.append(jj.attrib['index']) - ###tot_region_ref.append(jj.attrib['regionRef']) - - ###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)} - - #id_textregions = [] - #textregions_by_existing_ids = [] - indexer = 0 - indexer_textregion = 0 - for nn in root1.iter(region_tags): - #id_textregion = nn.attrib['id'] - #id_textregions.append(id_textregion) - #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) - - is_textregion_text = False - for childtest in nn: - if childtest.tag.endswith("TextEquiv"): - is_textregion_text = True - - if not is_textregion_text: - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') - - - has_textline = False - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - is_textline_text = False - for childtest2 in child_textregion: - if childtest2.tag.endswith("TextEquiv"): - is_textline_text = True - - - if not is_textline_text: - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] - else: - for childtest3 in child_textregion: - if childtest3.tag.endswith("TextEquiv"): - for child_uc in childtest3: - if child_uc.tag.endswith("Unicode"): - childtest3.set('conf', - f"{extracted_conf_value_merged[indexer]:.2f}") - 
child_uc.text = extracted_texts_merged[indexer] - - indexer = indexer + 1 - has_textline = True - if has_textline: - if is_textregion_text: - for child4 in nn: - if child4.tag.endswith("TextEquiv"): - for childtr_uc in child4: - if childtr_uc.tag.endswith("Unicode"): - childtr_uc.text = text_by_textregion[indexer_textregion] - else: - unicode_textregion.text = text_by_textregion[indexer_textregion] - indexer_textregion = indexer_textregion + 1 - - ###sample_order = [(id_to_order[tid], text) - ### for tid, text in zip(id_textregions, textregions_by_existing_ids) - ### if tid in id_to_order] - - ##ordered_texts_sample = [text for _, text in sorted(sample_order)] - ##tot_page_text = ' '.join(ordered_texts_sample) - - ##for page_element in root1.iter(link+'Page'): - ##text_page = ET.SubElement(page_element, 'TextEquiv') - ##unicode_textpage = ET.SubElement(text_page, 'Unicode') - ##unicode_textpage.text = tot_page_text - - ET.register_namespace("",name_space) - tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf-8",default_namespace=None) - #print("Job done in %.1fs", time.time() - t0) + self.write_ocr( + result=result, + img=img, + page_tree=page_tree, + page_ns=page_ns, + out_file_ocr=out_file_ocr, + out_image_with_text=out_image_with_text, + ) diff --git a/src/eynollah/utils/font.py b/src/eynollah/utils/font.py new file mode 100644 index 0000000..939933e --- /dev/null +++ b/src/eynollah/utils/font.py @@ -0,0 +1,16 @@ + +# cannot use importlib.resources until we move to 3.9+ forimportlib.resources.files +import sys +from PIL import ImageFont + +if sys.version_info < (3, 10): + import importlib_resources +else: + import importlib.resources as importlib_resources + + +def get_font(): + #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! + font = importlib_resources.files(__package__) / "../Charis-Regular.ttf" + with importlib_resources.as_file(font) as font: + return ImageFont.truetype(font=font, size=40) diff --git a/src/eynollah/utils/utils_ocr.py b/src/eynollah/utils/utils_ocr.py index 2ba328b..b738e29 100644 --- a/src/eynollah/utils/utils_ocr.py +++ b/src/eynollah/utils/utils_ocr.py @@ -128,6 +128,7 @@ def return_textlines_split_if_needed(textline_image, textline_image_bin=None): return [image1, image2], None else: return None, None + def preprocess_and_resize_image_for_ocrcnn_model(img, image_height, image_width): if img.shape[0]==0 or img.shape[1]==0: img_fin = np.ones((image_height, image_width, 3)) diff --git a/src/eynollah/utils/xml.py b/src/eynollah/utils/xml.py index 88d1df8..ded098e 100644 --- a/src/eynollah/utils/xml.py +++ b/src/eynollah/utils/xml.py @@ -88,3 +88,7 @@ def order_and_id_of_texts(found_polygons_text_region, found_polygons_text_region order_of_texts.append(interest) return order_of_texts, id_of_texts + +def etree_namespace_for_element_tag(tag: str): + right = tag.find('}') + return tag[1:right]
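
Note (not part of the patch): a minimal, hedged sketch of how the new etree_namespace_for_element_tag helper is intended to be used when walking a PAGE-XML tree, mirroring the refactored run()/run_cnn()/run_trocr() call sites above. The file name "sample_page.xml" is illustrative only, not something the patch introduces.

    from xml.etree import ElementTree as ET

    def etree_namespace_for_element_tag(tag: str) -> str:
        # "{http://schema.primaresearch.org/PAGE/...}PcGts" -> namespace URI without braces
        return tag[1:tag.find('}')]

    # Parse a PAGE-XML file and iterate its TextRegion/TextLine elements,
    # the same pattern the new run_cnn()/run_trocr() methods use.
    page_tree = ET.parse("sample_page.xml", parser=ET.XMLParser(encoding="utf-8"))
    page_ns = etree_namespace_for_element_tag(page_tree.getroot().tag)

    for region in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
        for line in region:
            if line.tag.endswith("TextLine"):
                print(region.attrib.get('id'), line.attrib.get('id'))

In the previous code the namespace was recovered by collecting every tag from root1.iter() and string-splitting the first one; the helper centralizes that logic so each caller only needs the root element's tag.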