trocr: simplify, batch over entire page…

- batching over entire page instead of region-wise
  (underfilling batches)
- avoid copied redundant code
This commit is contained in:
Robert Sachunsky 2026-05-21 15:48:21 +02:00
parent d50bd7c650
commit 1d67e65f11
2 changed files with 51 additions and 156 deletions

View file

@ -14,6 +14,7 @@ from cv2.typing import MatLike
from xml.etree import ElementTree as ET
from PIL import Image, ImageDraw
import numpy as np
from ocrd_utils import polygon_from_points, xywh_from_polygon
from .eynollah import Eynollah
@ -31,6 +32,7 @@ from .utils.utils_ocr import (
preprocess_and_resize_image_for_ocrcnn_model,
return_textlines_split_if_needed,
rotate_image_with_padding,
batched,
)
# TODO: refine typing
@ -90,125 +92,55 @@ class Eynollah_ocr(Eynollah):
) -> EynollahOcrResult:
total_bb_coordinates = []
cropped_lines = []
cropped_lines_region_indexer = []
cropped_lines_meging_indexing = []
extracted_texts = []
indexer_text_region = 0
indexer_b_s = 0
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
cropped_lines_region_indexer.append(n_region)
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
for child_textregion in nn:
if child_textregion.tag.endswith("TextLine"):
for child_textlines in child_textregion:
if child_textlines.tag.endswith("Coords"):
cropped_lines_region_indexer.append(indexer_text_region)
p_h=child_textlines.attrib['points'].split(' ')
textline_coords = np.array( [ [int(x.split(',')[0]),
int(x.split(',')[1]) ]
for x in p_h] )
x,y,w,h = cv2.boundingRect(textline_coords)
coords = line.find('{%s}Coords' % page_ns)
if coords is None:
self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id'])
continue
poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int)
cont = poly[:, np.newaxis]
xywh = xywh_from_polygon(poly)
x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h']
total_bb_coordinates.append([x, y, w, h])
h2w_ratio = h/float(w)
img_crop = img[y: y + h, x: x + w]
mask_poly = np.zeros(img_crop.shape[:2], dtype=np.uint8)
mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1)
img_crop[mask_poly == 0] = 255 # FIXME: or median color?
img_poly_on_img = np.copy(img)
mask_poly = np.zeros(img.shape)
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
mask_poly = mask_poly[y:y+h, x:x+w, :]
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
img_crop[mask_poly==0] = 255
self.logger.debug("processing %d lines for '%s'",
len(cropped_lines), nn.attrib['id'])
if h2w_ratio > 0.1:
if h > 0.1 * w:
cropped_lines.append(resize_image(img_crop,
tr_ocr_input_height_and_width,
tr_ocr_input_height_and_width) )
cropped_lines_meging_indexing.append(0)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged
else:
splited_images, _ = return_textlines_split_if_needed(img_crop, None)
#print(splited_images)
if splited_images:
cropped_lines.append(resize_image(splited_images[0],
tr_ocr_input_height_and_width,
tr_ocr_input_height_and_width))
cropped_lines_meging_indexing.append(1)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged
cropped_lines.append(resize_image(splited_images[1],
tr_ocr_input_height_and_width,
tr_ocr_input_height_and_width))
cropped_lines_meging_indexing.append(-1)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged
else:
cropped_lines.append(img_crop)
cropped_lines_meging_indexing.append(0)
indexer_b_s+=1
if indexer_b_s==self.b_s:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
self.logger.debug("processing %d lines for %d regions",
len(cropped_lines), len(set(cropped_lines_region_indexer)))
for imgs in batched(cropped_lines, self.b_s):
pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
@ -217,58 +149,15 @@ class Eynollah_ocr(Eynollah):
generated_ids_merged,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged
indexer_text_region = indexer_text_region +1
if indexer_b_s!=0:
imgs = cropped_lines[:]
cropped_lines = []
indexer_b_s = 0
pixel_values_merged = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device))
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
generated_ids_merged,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged
####extracted_texts = []
####n_iterations = math.ceil(len(cropped_lines) / self.b_s)
####for i in range(n_iterations):
####if i==(n_iterations-1):
####n_start = i*self.b_s
####imgs = cropped_lines[n_start:]
####else:
####n_start = i*self.b_s
####n_end = (i+1)*self.b_s
####imgs = cropped_lines[n_start:n_end]
####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
####generated_ids_merged = self.model_ocr.generate(
#### pixel_values_merged.to(self.device))
####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
#### generated_ids_merged, skip_special_tokens=True)
####extracted_texts = extracted_texts + generated_text_merged
del cropped_lines
gc.collect()
extracted_texts_merged = [extracted_texts[ind]
if cropped_lines_meging_indexing[ind] == 0
else extracted_texts[ind] + " " + extracted_texts[ind + 1]
if cropped_lines_meging_indexing[ind]==1
else None
for ind in range(len(cropped_lines_meging_indexing))]
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
#print(extracted_texts_merged, len(extracted_texts_merged))
for ind in range(len(cropped_lines_meging_indexing))
if cropped_lines_meging_indexing[ind] >= 0]
return EynollahOcrResult(
extracted_texts_merged=extracted_texts_merged,

View file

@ -1,5 +1,6 @@
import math
import copy
from itertools import islice
import numpy as np
import cv2
@ -502,3 +503,8 @@ def return_rnn_cnn_ocr_of_given_textlines(image,
ocr_textline_in_textregion.append(text_textline)
ocr_all_textlines.append(ocr_textline_in_textregion)
return ocr_all_textlines
def batched(iterable, n):
iterator = iter(iterable)
while batch := tuple(islice(iterator, n)):
yield batch