mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
trocr: simplify, batch over entire page…
- batching over entire page instead of region-wise (underfilling batches) - avoid copied redundant code
This commit is contained in:
parent
d50bd7c650
commit
1d67e65f11
2 changed files with 51 additions and 156 deletions
|
|
@ -14,6 +14,7 @@ from cv2.typing import MatLike
|
||||||
from xml.etree import ElementTree as ET
|
from xml.etree import ElementTree as ET
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from ocrd_utils import polygon_from_points, xywh_from_polygon
|
||||||
|
|
||||||
|
|
||||||
from .eynollah import Eynollah
|
from .eynollah import Eynollah
|
||||||
|
|
@ -31,6 +32,7 @@ from .utils.utils_ocr import (
|
||||||
preprocess_and_resize_image_for_ocrcnn_model,
|
preprocess_and_resize_image_for_ocrcnn_model,
|
||||||
return_textlines_split_if_needed,
|
return_textlines_split_if_needed,
|
||||||
rotate_image_with_padding,
|
rotate_image_with_padding,
|
||||||
|
batched,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: refine typing
|
# TODO: refine typing
|
||||||
|
|
@ -90,143 +92,55 @@ class Eynollah_ocr(Eynollah):
|
||||||
) -> EynollahOcrResult:
|
) -> EynollahOcrResult:
|
||||||
|
|
||||||
total_bb_coordinates = []
|
total_bb_coordinates = []
|
||||||
|
|
||||||
|
|
||||||
cropped_lines = []
|
cropped_lines = []
|
||||||
cropped_lines_region_indexer = []
|
cropped_lines_region_indexer = []
|
||||||
cropped_lines_meging_indexing = []
|
cropped_lines_meging_indexing = []
|
||||||
|
|
||||||
extracted_texts = []
|
extracted_texts = []
|
||||||
|
|
||||||
indexer_text_region = 0
|
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
|
||||||
indexer_b_s = 0
|
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
|
||||||
|
cropped_lines_region_indexer.append(n_region)
|
||||||
|
|
||||||
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
|
coords = line.find('{%s}Coords' % page_ns)
|
||||||
for child_textregion in nn:
|
if coords is None:
|
||||||
if child_textregion.tag.endswith("TextLine"):
|
self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id'])
|
||||||
|
continue
|
||||||
|
poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int)
|
||||||
|
cont = poly[:, np.newaxis]
|
||||||
|
xywh = xywh_from_polygon(poly)
|
||||||
|
x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h']
|
||||||
|
|
||||||
for child_textlines in child_textregion:
|
total_bb_coordinates.append([x, y, w, h])
|
||||||
if child_textlines.tag.endswith("Coords"):
|
|
||||||
cropped_lines_region_indexer.append(indexer_text_region)
|
|
||||||
p_h=child_textlines.attrib['points'].split(' ')
|
|
||||||
textline_coords = np.array( [ [int(x.split(',')[0]),
|
|
||||||
int(x.split(',')[1]) ]
|
|
||||||
for x in p_h] )
|
|
||||||
x,y,w,h = cv2.boundingRect(textline_coords)
|
|
||||||
|
|
||||||
total_bb_coordinates.append([x,y,w,h])
|
img_crop = img[y: y + h, x: x + w]
|
||||||
|
mask_poly = np.zeros(img_crop.shape[:2], dtype=np.uint8)
|
||||||
|
mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1)
|
||||||
|
img_crop[mask_poly == 0] = 255 # FIXME: or median color?
|
||||||
|
|
||||||
h2w_ratio = h/float(w)
|
if h > 0.1 * w:
|
||||||
|
cropped_lines.append(resize_image(img_crop,
|
||||||
img_poly_on_img = np.copy(img)
|
tr_ocr_input_height_and_width,
|
||||||
mask_poly = np.zeros(img.shape)
|
tr_ocr_input_height_and_width) )
|
||||||
mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1))
|
cropped_lines_meging_indexing.append(0)
|
||||||
|
else:
|
||||||
mask_poly = mask_poly[y:y+h, x:x+w, :]
|
splited_images, _ = return_textlines_split_if_needed(img_crop, None)
|
||||||
img_crop = img_poly_on_img[y:y+h, x:x+w, :]
|
if splited_images:
|
||||||
img_crop[mask_poly==0] = 255
|
cropped_lines.append(resize_image(splited_images[0],
|
||||||
|
tr_ocr_input_height_and_width,
|
||||||
self.logger.debug("processing %d lines for '%s'",
|
tr_ocr_input_height_and_width))
|
||||||
len(cropped_lines), nn.attrib['id'])
|
cropped_lines_meging_indexing.append(1)
|
||||||
if h2w_ratio > 0.1:
|
cropped_lines.append(resize_image(splited_images[1],
|
||||||
cropped_lines.append(resize_image(img_crop,
|
tr_ocr_input_height_and_width,
|
||||||
tr_ocr_input_height_and_width,
|
tr_ocr_input_height_and_width))
|
||||||
tr_ocr_input_height_and_width) )
|
cropped_lines_meging_indexing.append(-1)
|
||||||
cropped_lines_meging_indexing.append(0)
|
else:
|
||||||
indexer_b_s+=1
|
cropped_lines.append(img_crop)
|
||||||
if indexer_b_s==self.b_s:
|
cropped_lines_meging_indexing.append(0)
|
||||||
imgs = cropped_lines[:]
|
|
||||||
cropped_lines = []
|
|
||||||
indexer_b_s = 0
|
|
||||||
|
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
|
||||||
imgs, return_tensors="pt").pixel_values
|
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
|
||||||
pixel_values_merged.to(self.device))
|
|
||||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
generated_ids_merged,
|
|
||||||
skip_special_tokens=True,
|
|
||||||
clean_up_tokenization_spaces=False)
|
|
||||||
|
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
else:
|
|
||||||
splited_images, _ = return_textlines_split_if_needed(img_crop, None)
|
|
||||||
#print(splited_images)
|
|
||||||
if splited_images:
|
|
||||||
cropped_lines.append(resize_image(splited_images[0],
|
|
||||||
tr_ocr_input_height_and_width,
|
|
||||||
tr_ocr_input_height_and_width))
|
|
||||||
cropped_lines_meging_indexing.append(1)
|
|
||||||
indexer_b_s+=1
|
|
||||||
|
|
||||||
if indexer_b_s==self.b_s:
|
|
||||||
imgs = cropped_lines[:]
|
|
||||||
cropped_lines = []
|
|
||||||
indexer_b_s = 0
|
|
||||||
|
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
|
||||||
imgs, return_tensors="pt").pixel_values
|
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
|
||||||
pixel_values_merged.to(self.device))
|
|
||||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
generated_ids_merged,
|
|
||||||
skip_special_tokens=True,
|
|
||||||
clean_up_tokenization_spaces=False)
|
|
||||||
|
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
|
|
||||||
cropped_lines.append(resize_image(splited_images[1],
|
self.logger.debug("processing %d lines for %d regions",
|
||||||
tr_ocr_input_height_and_width,
|
len(cropped_lines), len(set(cropped_lines_region_indexer)))
|
||||||
tr_ocr_input_height_and_width))
|
for imgs in batched(cropped_lines, self.b_s):
|
||||||
cropped_lines_meging_indexing.append(-1)
|
|
||||||
indexer_b_s+=1
|
|
||||||
|
|
||||||
if indexer_b_s==self.b_s:
|
|
||||||
imgs = cropped_lines[:]
|
|
||||||
cropped_lines = []
|
|
||||||
indexer_b_s = 0
|
|
||||||
|
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
|
||||||
imgs, return_tensors="pt").pixel_values
|
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
|
||||||
pixel_values_merged.to(self.device))
|
|
||||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
generated_ids_merged,
|
|
||||||
skip_special_tokens=True,
|
|
||||||
clean_up_tokenization_spaces=False)
|
|
||||||
|
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
else:
|
|
||||||
cropped_lines.append(img_crop)
|
|
||||||
cropped_lines_meging_indexing.append(0)
|
|
||||||
indexer_b_s+=1
|
|
||||||
|
|
||||||
if indexer_b_s==self.b_s:
|
|
||||||
imgs = cropped_lines[:]
|
|
||||||
cropped_lines = []
|
|
||||||
indexer_b_s = 0
|
|
||||||
|
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
|
||||||
imgs, return_tensors="pt").pixel_values
|
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
|
||||||
pixel_values_merged.to(self.device))
|
|
||||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
generated_ids_merged,
|
|
||||||
skip_special_tokens=True,
|
|
||||||
clean_up_tokenization_spaces=False)
|
|
||||||
|
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
indexer_text_region = indexer_text_region +1
|
|
||||||
|
|
||||||
if indexer_b_s!=0:
|
|
||||||
imgs = cropped_lines[:]
|
|
||||||
cropped_lines = []
|
|
||||||
indexer_b_s = 0
|
|
||||||
|
|
||||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
||||||
imgs, return_tensors="pt").pixel_values
|
imgs, return_tensors="pt").pixel_values
|
||||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||||
|
|
@ -235,40 +149,15 @@ class Eynollah_ocr(Eynollah):
|
||||||
generated_ids_merged,
|
generated_ids_merged,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
clean_up_tokenization_spaces=False)
|
clean_up_tokenization_spaces=False)
|
||||||
|
|
||||||
extracted_texts = extracted_texts + generated_text_merged
|
extracted_texts = extracted_texts + generated_text_merged
|
||||||
|
|
||||||
####extracted_texts = []
|
|
||||||
####n_iterations = math.ceil(len(cropped_lines) / self.b_s)
|
|
||||||
|
|
||||||
####for i in range(n_iterations):
|
|
||||||
####if i==(n_iterations-1):
|
|
||||||
####n_start = i*self.b_s
|
|
||||||
####imgs = cropped_lines[n_start:]
|
|
||||||
####else:
|
|
||||||
####n_start = i*self.b_s
|
|
||||||
####n_end = (i+1)*self.b_s
|
|
||||||
####imgs = cropped_lines[n_start:n_end]
|
|
||||||
####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values
|
|
||||||
####generated_ids_merged = self.model_ocr.generate(
|
|
||||||
#### pixel_values_merged.to(self.device))
|
|
||||||
####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
|
||||||
#### generated_ids_merged, skip_special_tokens=True)
|
|
||||||
|
|
||||||
####extracted_texts = extracted_texts + generated_text_merged
|
|
||||||
|
|
||||||
del cropped_lines
|
del cropped_lines
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
extracted_texts_merged = [extracted_texts[ind]
|
extracted_texts_merged = [extracted_texts[ind]
|
||||||
if cropped_lines_meging_indexing[ind]==0
|
if cropped_lines_meging_indexing[ind] == 0
|
||||||
else extracted_texts[ind]+" "+extracted_texts[ind+1]
|
else extracted_texts[ind] + " " + extracted_texts[ind + 1]
|
||||||
if cropped_lines_meging_indexing[ind]==1
|
for ind in range(len(cropped_lines_meging_indexing))
|
||||||
else None
|
if cropped_lines_meging_indexing[ind] >= 0]
|
||||||
for ind in range(len(cropped_lines_meging_indexing))]
|
|
||||||
|
|
||||||
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
|
|
||||||
#print(extracted_texts_merged, len(extracted_texts_merged))
|
|
||||||
|
|
||||||
return EynollahOcrResult(
|
return EynollahOcrResult(
|
||||||
extracted_texts_merged=extracted_texts_merged,
|
extracted_texts_merged=extracted_texts_merged,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import math
|
import math
|
||||||
import copy
|
import copy
|
||||||
|
from itertools import islice
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
|
@ -502,3 +503,8 @@ def return_rnn_cnn_ocr_of_given_textlines(image,
|
||||||
ocr_textline_in_textregion.append(text_textline)
|
ocr_textline_in_textregion.append(text_textline)
|
||||||
ocr_all_textlines.append(ocr_textline_in_textregion)
|
ocr_all_textlines.append(ocr_textline_in_textregion)
|
||||||
return ocr_all_textlines
|
return ocr_all_textlines
|
||||||
|
|
||||||
|
def batched(iterable, n):
|
||||||
|
iterator = iter(iterable)
|
||||||
|
while batch := tuple(islice(iterator, n)):
|
||||||
|
yield batch
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue