mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
trocr: simplify, batch over entire page…
- batching over entire page instead of region-wise (underfilling batches) - avoid copied redundant code
This commit is contained in:
parent
d50bd7c650
commit
1d67e65f11
2 changed files with 51 additions and 156 deletions
|
|
@ -14,6 +14,7 @@ from cv2.typing import MatLike
|
|||
from xml.etree import ElementTree as ET
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
from ocrd_utils import polygon_from_points, xywh_from_polygon
|
||||
|
||||
|
||||
from .eynollah import Eynollah
|
||||
|
|
@ -31,6 +32,7 @@ from .utils.utils_ocr import (
|
|||
preprocess_and_resize_image_for_ocrcnn_model,
|
||||
return_textlines_split_if_needed,
|
||||
rotate_image_with_padding,
|
||||
batched,
|
||||
)
|
||||
|
||||
# TODO: refine typing
|
||||
|
|
@ -90,125 +92,55 @@ class Eynollah_ocr(Eynollah):
|
|||
) -> EynollahOcrResult:
|
||||
|
||||
total_bb_coordinates = []
|
||||
|
||||
|
||||
cropped_lines = []
|
||||
cropped_lines_region_indexer = []
|
||||
cropped_lines_meging_indexing = []
|
||||
|
||||
extracted_texts = []
|
||||
|
||||
indexer_text_region = 0
|
||||
indexer_b_s = 0
|
||||
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
|
||||
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
|
||||
cropped_lines_region_indexer.append(n_region)
|
||||
|
||||
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
|
||||
for child_textregion in nn:
|
||||
if child_textregion.tag.endswith("TextLine"):
|
||||
coords = line.find('{%s}Coords' % page_ns)
|
||||
if coords is None:
|
||||
self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id'])
|
||||
continue
|
||||
poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int)
|
||||
cont = poly[:, np.newaxis]
|
||||
xywh = xywh_from_polygon(poly)
|
||||
x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h']
|
||||
|
||||
for child_textlines in child_textregion:
|
||||
if child_textlines.tag.endswith("Coords"):
|
||||
cropped_lines_region_indexer.append(indexer_text_region)
|
||||
p_h=child_textlines.attrib['points'].split(' ')
|
||||
textline_coords = np.array( [ [int(x.split(',')[0]),
|
||||
int(x.split(',')[1]) ]
|
||||
for x in p_h] )
|
||||
x,y,w,h = cv2.boundingRect(textline_coords)
|
||||
total_bb_coordinates.append([x, y, w, h])
|
||||
|
||||
total_bb_coordinates.append([x,y,w,h])
|
||||
img_crop = img[y: y + h, x: x + w]
|
||||
mask_poly = np.zeros(img_crop.shape[:2], dtype=np.uint8)
|
||||
mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1)
|
||||
img_crop[mask_poly == 0] = 255 # FIXME: or median color?
|
||||
|
||||
h2w_ratio = h/float(w)
|
||||
|
||||
img_poly_on_img = np.copy(img)
|
||||
mask_poly = np.zeros(img.shape)
|
||||
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
|
||||
|
||||
mask_poly = mask_poly[y:y+h, x:x+w, :]
|
||||
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
|
||||
img_crop[mask_poly==0] = 255
|
||||
|
||||
self.logger.debug("processing %d lines for '%s'",
|
||||
len(cropped_lines), nn.attrib['id'])
|
||||
if h2w_ratio > 0.1:
|
||||
if h > 0.1 * w:
|
||||
cropped_lines.append(resize_image(img_crop,
|
||||
tr_ocr_input_height_and_width,
|
||||
tr_ocr_input_height_and_width) )
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
indexer_b_s+=1
|
||||
if indexer_b_s==self.b_s:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
||||
imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
generated_ids_merged,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
else:
|
||||
splited_images, _ = return_textlines_split_if_needed(img_crop, None)
|
||||
#print(splited_images)
|
||||
if splited_images:
|
||||
cropped_lines.append(resize_image(splited_images[0],
|
||||
tr_ocr_input_height_and_width,
|
||||
tr_ocr_input_height_and_width))
|
||||
cropped_lines_meging_indexing.append(1)
|
||||
indexer_b_s+=1
|
||||
|
||||
if indexer_b_s==self.b_s:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
||||
imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
generated_ids_merged,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
|
||||
cropped_lines.append(resize_image(splited_images[1],
|
||||
tr_ocr_input_height_and_width,
|
||||
tr_ocr_input_height_and_width))
|
||||
cropped_lines_meging_indexing.append(-1)
|
||||
indexer_b_s+=1
|
||||
|
||||
if indexer_b_s==self.b_s:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
||||
imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
generated_ids_merged,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
else:
|
||||
cropped_lines.append(img_crop)
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
indexer_b_s+=1
|
||||
|
||||
if indexer_b_s==self.b_s:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
self.logger.debug("processing %d lines for %d regions",
|
||||
len(cropped_lines), len(set(cropped_lines_region_indexer)))
|
||||
for imgs in batched(cropped_lines, self.b_s):
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
||||
imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
|
|
@ -217,58 +149,15 @@ class Eynollah_ocr(Eynollah):
|
|||
generated_ids_merged,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
indexer_text_region = indexer_text_region +1
|
||||
|
||||
if indexer_b_s!=0:
|
||||
imgs = cropped_lines[:]
|
||||
cropped_lines = []
|
||||
indexer_b_s = 0
|
||||
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
||||
imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
generated_ids_merged,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
####extracted_texts = []
|
||||
####n_iterations = math.ceil(len(cropped_lines) / self.b_s)
|
||||
|
||||
####for i in range(n_iterations):
|
||||
####if i==(n_iterations-1):
|
||||
####n_start = i*self.b_s
|
||||
####imgs = cropped_lines[n_start:]
|
||||
####else:
|
||||
####n_start = i*self.b_s
|
||||
####n_end = (i+1)*self.b_s
|
||||
####imgs = cropped_lines[n_start:n_end]
|
||||
####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
||||
####generated_ids_merged = self.model_ocr.generate(
|
||||
#### pixel_values_merged.to(self.device))
|
||||
####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
#### generated_ids_merged, skip_special_tokens=True)
|
||||
|
||||
####extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
del cropped_lines
|
||||
gc.collect()
|
||||
|
||||
extracted_texts_merged = [extracted_texts[ind]
|
||||
if cropped_lines_meging_indexing[ind]==0
|
||||
else extracted_texts[ind]+" "+extracted_texts[ind+1]
|
||||
if cropped_lines_meging_indexing[ind]==1
|
||||
else None
|
||||
for ind in range(len(cropped_lines_meging_indexing))]
|
||||
|
||||
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
|
||||
#print(extracted_texts_merged, len(extracted_texts_merged))
|
||||
if cropped_lines_meging_indexing[ind] == 0
|
||||
else extracted_texts[ind] + " " + extracted_texts[ind + 1]
|
||||
for ind in range(len(cropped_lines_meging_indexing))
|
||||
if cropped_lines_meging_indexing[ind] >= 0]
|
||||
|
||||
return EynollahOcrResult(
|
||||
extracted_texts_merged=extracted_texts_merged,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import math
|
||||
import copy
|
||||
from itertools import islice
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
|
@ -502,3 +503,8 @@ def return_rnn_cnn_ocr_of_given_textlines(image,
|
|||
ocr_textline_in_textregion.append(text_textline)
|
||||
ocr_all_textlines.append(ocr_textline_in_textregion)
|
||||
return ocr_all_textlines
|
||||
|
||||
def batched(iterable, n):
|
||||
iterator = iter(iterable)
|
||||
while batch := tuple(islice(iterator, n)):
|
||||
yield batch
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue