From 1d67e65f11ad5266ba27262d38b4c49a7a864714 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Thu, 21 May 2026 15:48:21 +0200 Subject: [PATCH] =?UTF-8?q?trocr:=20simplify,=20batch=20over=20entire=20pa?= =?UTF-8?q?ge=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - batching over entire page instead of region-wise (underfilling batches) - avoid copied redundant code --- src/eynollah/eynollah_ocr.py | 201 +++++++------------------------- src/eynollah/utils/utils_ocr.py | 6 + 2 files changed, 51 insertions(+), 156 deletions(-) diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 4371453..747d2f5 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -14,6 +14,7 @@ from cv2.typing import MatLike from xml.etree import ElementTree as ET from PIL import Image, ImageDraw import numpy as np +from ocrd_utils import polygon_from_points, xywh_from_polygon from .eynollah import Eynollah @@ -31,6 +32,7 @@ from .utils.utils_ocr import ( preprocess_and_resize_image_for_ocrcnn_model, return_textlines_split_if_needed, rotate_image_with_padding, + batched, ) # TODO: refine typing @@ -90,143 +92,55 @@ class Eynollah_ocr(Eynollah): ) -> EynollahOcrResult: total_bb_coordinates = [] - - cropped_lines = [] cropped_lines_region_indexer = [] cropped_lines_meging_indexing = [] - extracted_texts = [] - indexer_text_region = 0 - indexer_b_s = 0 - - for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - x,y,w,h = cv2.boundingRect(textline_coords) - - total_bb_coordinates.append([x,y,w,h]) - - h2w_ratio = h/float(w) - - img_poly_on_img = np.copy(img) - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - img_crop[mask_poly==0] = 255 - - self.logger.debug("processing %d lines for '%s'", - len(cropped_lines), nn.attrib['id']) - if h2w_ratio > 0.1: - cropped_lines.append(resize_image(img_crop, - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width) ) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')( - imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, - skip_special_tokens=True, - clean_up_tokenization_spaces=False) - - extracted_texts = extracted_texts + generated_text_merged - - else: - splited_images, _ = return_textlines_split_if_needed(img_crop, None) - #print(splited_images) - if splited_images: - cropped_lines.append(resize_image(splited_images[0], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')( - imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, - skip_special_tokens=True, - clean_up_tokenization_spaces=False) - - extracted_texts = extracted_texts + generated_text_merged - - - cropped_lines.append(resize_image(splited_images[1], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(-1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')( - imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, - skip_special_tokens=True, - clean_up_tokenization_spaces=False) - - extracted_texts = extracted_texts + generated_text_merged - - else: - cropped_lines.append(img_crop) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 + for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)): + for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)): + cropped_lines_region_indexer.append(n_region) - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 + coords = line.find('{%s}Coords' % page_ns) + if coords is None: + self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id']) + continue + poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int) + cont = poly[:, np.newaxis] + xywh = xywh_from_polygon(poly) + x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h'] - pixel_values_merged = self.model_zoo.get('trocr_processor')( - imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, - skip_special_tokens=True, - clean_up_tokenization_spaces=False) + total_bb_coordinates.append([x, y, w, h]) - extracted_texts = extracted_texts + generated_text_merged + img_crop = img[y: y + h, x: x + w] + mask_poly = np.zeros(img_crop.shape[:2], dtype=np.uint8) + mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1) + img_crop[mask_poly == 0] = 255 # FIXME: or median color? - indexer_text_region = indexer_text_region +1 + if h > 0.1 * w: + cropped_lines.append(resize_image(img_crop, + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width) ) + cropped_lines_meging_indexing.append(0) + else: + splited_images, _ = return_textlines_split_if_needed(img_crop, None) + if splited_images: + cropped_lines.append(resize_image(splited_images[0], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(1) + cropped_lines.append(resize_image(splited_images[1], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(-1) + else: + cropped_lines.append(img_crop) + cropped_lines_meging_indexing.append(0) - if indexer_b_s!=0: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - + + self.logger.debug("processing %d lines for %d regions", + len(cropped_lines), len(set(cropped_lines_region_indexer))) + for imgs in batched(cropped_lines, self.b_s): pixel_values_merged = self.model_zoo.get('trocr_processor')( imgs, return_tensors="pt").pixel_values generated_ids_merged = self.model_zoo.get('ocr').generate( @@ -235,40 +149,15 @@ class Eynollah_ocr(Eynollah): generated_ids_merged, skip_special_tokens=True, clean_up_tokenization_spaces=False) - extracted_texts = extracted_texts + generated_text_merged - - ####extracted_texts = [] - ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) - - ####for i in range(n_iterations): - ####if i==(n_iterations-1): - ####n_start = i*self.b_s - ####imgs = cropped_lines[n_start:] - ####else: - ####n_start = i*self.b_s - ####n_end = (i+1)*self.b_s - ####imgs = cropped_lines[n_start:n_end] - ####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - ####generated_ids_merged = self.model_ocr.generate( - #### pixel_values_merged.to(self.device)) - ####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - #### generated_ids_merged, skip_special_tokens=True) - - ####extracted_texts = extracted_texts + generated_text_merged - del cropped_lines gc.collect() extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - #print(extracted_texts_merged, len(extracted_texts_merged)) + if cropped_lines_meging_indexing[ind] == 0 + else extracted_texts[ind] + " " + extracted_texts[ind + 1] + for ind in range(len(cropped_lines_meging_indexing)) + if cropped_lines_meging_indexing[ind] >= 0] return EynollahOcrResult( extracted_texts_merged=extracted_texts_merged, diff --git a/src/eynollah/utils/utils_ocr.py b/src/eynollah/utils/utils_ocr.py index 93d1137..6914fee 100644 --- a/src/eynollah/utils/utils_ocr.py +++ b/src/eynollah/utils/utils_ocr.py @@ -1,5 +1,6 @@ import math import copy +from itertools import islice import numpy as np import cv2 @@ -502,3 +503,8 @@ def return_rnn_cnn_ocr_of_given_textlines(image, ocr_textline_in_textregion.append(text_textline) ocr_all_textlines.append(ocr_textline_in_textregion) return ocr_all_textlines + +def batched(iterable, n): + iterator = iter(iterable) + while batch := tuple(islice(iterator, n)): + yield batch