trocr: simplify, batch over entire page…

- batching over entire page instead of region-wise
  (underfilling batches)
- avoid copied redundant code
This commit is contained in:
Robert Sachunsky 2026-05-21 15:48:21 +02:00
parent d50bd7c650
commit 1d67e65f11
2 changed files with 51 additions and 156 deletions

View file

@ -14,6 +14,7 @@ from cv2.typing import MatLike
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
import numpy as np import numpy as np
from ocrd_utils import polygon_from_points, xywh_from_polygon
from .eynollah import Eynollah from .eynollah import Eynollah
@ -31,6 +32,7 @@ from .utils.utils_ocr import (
preprocess_and_resize_image_for_ocrcnn_model, preprocess_and_resize_image_for_ocrcnn_model,
return_textlines_split_if_needed, return_textlines_split_if_needed,
rotate_image_with_padding, rotate_image_with_padding,
batched,
) )
# TODO: refine typing # TODO: refine typing
@ -90,143 +92,55 @@ class Eynollah_ocr(Eynollah):
) -> EynollahOcrResult: ) -> EynollahOcrResult:
total_bb_coordinates = [] total_bb_coordinates = []
cropped_lines = [] cropped_lines = []
cropped_lines_region_indexer = [] cropped_lines_region_indexer = []
cropped_lines_meging_indexing = [] cropped_lines_meging_indexing = []
extracted_texts = [] extracted_texts = []
indexer_text_region = 0 for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
indexer_b_s = 0 for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
cropped_lines_region_indexer.append(n_region)
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): coords = line.find('{%s}Coords' % page_ns)
for child_textregion in nn: if coords is None:
if child_textregion.tag.endswith("TextLine"): self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id'])
continue
poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int)
cont = poly[:, np.newaxis]
xywh = xywh_from_polygon(poly)
x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h']
for child_textlines in child_textregion: total_bb_coordinates.append([x, y, w, h])
if child_textlines.tag.endswith("Coords"):
cropped_lines_region_indexer.append(indexer_text_region)
p_h=child_textlines.attrib['points'].split(' ')
textline_coords = np.array( [ [int(x.split(',')[0]),
int(x.split(',')[1]) ]
for x in p_h] )
x,y,w,h = cv2.boundingRect(textline_coords)
total_bb_coordinates.append([x,y,w,h]) img_crop = img[y: y + h, x: x + w]
mask_poly = np.zeros(img_crop.shape[:2], dtype=np.uint8)
mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1)
img_crop[mask_poly == 0] = 255 # FIXME: or median color?
h2w_ratio = h/float(w) if h > 0.1 * w:
cropped_lines.append(resize_image(img_crop,
img_poly_on_img = np.copy(img) tr_ocr_input_height_and_width,
mask_poly = np.zeros(img.shape) tr_ocr_input_height_and_width) )
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) cropped_lines_meging_indexing.append(0)
else:
mask_poly = mask_poly[y:y+h, x:x+w, :] splited_images, _ = return_textlines_split_if_needed(img_crop, None)
img_crop = img_poly_on_img[y:y+h, x:x+w, :] if splited_images:
img_crop[mask_poly==0] = 255 cropped_lines.append(resize_image(splited_images[0],
tr_ocr_input_height_and_width,
self.logger.debug("processing %d lines for '%s'", tr_ocr_input_height_and_width))
len(cropped_lines), nn.attrib['id']) cropped_lines_meging_indexing.append(1)
if h2w_ratio > 0.1: cropped_lines.append(resize_image(splited_images[1],
cropped_lines.append(resize_image(img_crop, tr_ocr_input_height_and_width,
tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
tr_ocr_input_height_and_width) ) cropped_lines_meging_indexing.append(-1)
cropped_lines_meging_indexing.append(0) else:
indexer_b_s+=1 cropped_lines.append(img_crop)
if indexer_b_s==self.b_s: cropped_lines_meging_indexing.append(0)
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged
else:
splited_images, _ = return_textlines_split_if_needed(img_crop, None)
#print(splited_images)
if splited_images:
cropped_lines.append(resize_image(splited_images[0],
tr_ocr_input_height_and_width,
tr_ocr_input_height_and_width))
cropped_lines_meging_indexing.append(1)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged
cropped_lines.append(resize_image(splited_images[1], self.logger.debug("processing %d lines for %d regions",
tr_ocr_input_height_and_width, len(cropped_lines), len(set(cropped_lines_region_indexer)))
tr_ocr_input_height_and_width)) for imgs in batched(cropped_lines, self.b_s):
cropped_lines_meging_indexing.append(-1)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged
else:
cropped_lines.append(img_crop)
cropped_lines_meging_indexing.append(0)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged
indexer_text_region = indexer_text_region +1
if indexer_b_s!=0:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')( pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate( generated_ids_merged = self.model_zoo.get('ocr').generate(
@ -235,40 +149,15 @@ class Eynollah_ocr(Eynollah):
generated_ids_merged, generated_ids_merged,
skip_special_tokens=True, skip_special_tokens=True,
clean_up_tokenization_spaces=False) clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged extracted_texts = extracted_texts + generated_text_merged
####extracted_texts = []
####n_iterations = math.ceil(len(cropped_lines) / self.b_s)
####for i in range(n_iterations):
####if i==(n_iterations-1):
####n_start = i*self.b_s
####imgs = cropped_lines[n_start:]
####else:
####n_start = i*self.b_s
####n_end = (i+1)*self.b_s
####imgs = cropped_lines[n_start:n_end]
####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
####generated_ids_merged = self.model_ocr.generate(
#### pixel_values_merged.to(self.device))
####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
#### generated_ids_merged, skip_special_tokens=True)
####extracted_texts = extracted_texts + generated_text_merged
del cropped_lines del cropped_lines
gc.collect() gc.collect()
extracted_texts_merged = [extracted_texts[ind] extracted_texts_merged = [extracted_texts[ind]
if cropped_lines_meging_indexing[ind]==0 if cropped_lines_meging_indexing[ind] == 0
else extracted_texts[ind]+" "+extracted_texts[ind+1] else extracted_texts[ind] + " " + extracted_texts[ind + 1]
if cropped_lines_meging_indexing[ind]==1 for ind in range(len(cropped_lines_meging_indexing))
else None if cropped_lines_meging_indexing[ind] >= 0]
for ind in range(len(cropped_lines_meging_indexing))]
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
#print(extracted_texts_merged, len(extracted_texts_merged))
return EynollahOcrResult( return EynollahOcrResult(
extracted_texts_merged=extracted_texts_merged, extracted_texts_merged=extracted_texts_merged,

View file

@ -1,5 +1,6 @@
import math import math
import copy import copy
from itertools import islice
import numpy as np import numpy as np
import cv2 import cv2
@ -502,3 +503,8 @@ def return_rnn_cnn_ocr_of_given_textlines(image,
ocr_textline_in_textregion.append(text_textline) ocr_textline_in_textregion.append(text_textline)
ocr_all_textlines.append(ocr_textline_in_textregion) ocr_all_textlines.append(ocr_textline_in_textregion)
return ocr_all_textlines return ocr_all_textlines
def batched(iterable, n):
    """Yield successive tuples of up to ``n`` items from ``iterable``.

    Stand-in for ``itertools.batched`` (only available from Python 3.12):
    every batch holds exactly ``n`` items, except possibly the last one,
    which holds the remainder. An empty iterable yields nothing.

    Args:
        iterable: any iterable; it is consumed lazily, one batch at a time.
        n: maximum number of items per batch; must be at least 1.

    Yields:
        tuple: the next batch, with items in input order.

    Raises:
        ValueError: if ``n`` is smaller than 1 (raised when iteration
            starts, because this is a generator function).
    """
    if n < 1:
        # islice(iterator, 0) is always empty, so without this guard the
        # loop below would stop immediately and silently drop every item.
        raise ValueError("n must be at least one")
    iterator = iter(iterable)
    # An empty tuple is falsy, which ends the loop once the input is exhausted.
    while batch := tuple(islice(iterator, n)):
        yield batch