TrOCR inference is integrated - works on CPU, but causes a segfault on GPU

This commit is contained in:
vahidrezanezhad 2026-02-18 15:04:54 +01:00 committed by kba
parent a11c833fc1
commit 499e3d0715
5 changed files with 39 additions and 13 deletions

View file

@ -345,7 +345,7 @@ class Eynollah_ocr(Eynollah):
if out_image_with_text: if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text) draw = ImageDraw.Draw(image_text)
font = get_font() font = get_font(font_size=40)
for indexer_text, bb_ind in enumerate(total_bb_coordinates): for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0] x_bb = bb_ind[0]

View file

@ -6,6 +6,7 @@ from pathlib import Path
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import cv2 import cv2
import numpy as np import numpy as np
from eynollah.utils.font import get_font
from .gt_gen_utils import ( from .gt_gen_utils import (
filter_contours_area_of_image, filter_contours_area_of_image,
@ -552,8 +553,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
else: else:
xml_files_ind = [xml_file] xml_files_ind = [xml_file]
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! ###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = ImageFont.truetype(font_path, 40) font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)
for ind_xml in tqdm(xml_files_ind): for ind_xml in tqdm(xml_files_ind):
indexer = 0 indexer = 0
@ -590,11 +591,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
is_vertical = h > 2*w # Check orientation is_vertical = h > 2*w # Check orientation
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )
if is_vertical: if is_vertical:
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8)) vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
text_draw = ImageDraw.Draw(text_img) text_draw = ImageDraw.Draw(text_img)

View file

@ -8,7 +8,7 @@ from shapely import geometry
from pathlib import Path from pathlib import Path
from PIL import ImageFont from PIL import ImageFont
from ocrd_utils import bbox_from_points from ocrd_utils import bbox_from_points
from eynollah.utils.font import get_font
KERNEL = np.ones((5, 5), np.uint8) KERNEL = np.ones((5, 5), np.uint8)
NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15' NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'
@ -352,11 +352,11 @@ def get_textline_contours_and_ocr_text(xml_file):
ocr_textlines.append(ocr_text_in[0]) ocr_textlines.append(ocr_text_in[0])
return co_use_case, y_len, x_len, ocr_textlines return co_use_case, y_len, x_len, ocr_textlines
def fit_text_single_line(draw, text, font_path, max_width, max_height): def fit_text_single_line(draw, text, max_width, max_height):
initial_font_size = 50 initial_font_size = 50
font_size = initial_font_size font_size = initial_font_size
while font_size > 10: # Minimum font size while font_size > 10: # Minimum font size
font = ImageFont.truetype(font_path, font_size) font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
text_width = text_bbox[2] - text_bbox[0] text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1] text_height = text_bbox[3] - text_bbox[1]
@ -366,7 +366,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
font_size -= 2 # Reduce font size and retry font_size -= 2 # Reduce font size and retry
return ImageFont.truetype(font_path, 10) # Smallest font fallback return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback
def get_layout_contours_for_visualization(xml_file): def get_layout_contours_for_visualization(xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))

View file

@ -132,15 +132,31 @@ class SBBPredict:
self.model = Model( self.model = Model(
self.model.get_layer(name = "image").input, self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output) self.model.get_layer(name = "dense2").output)
assert isinstance(self.model, Model)
elif self.task == "transformer-ocr":
import torch
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
if self.cpu:
self.device = torch.device('cpu')
else:
self.device = torch.device('cuda:0')
self.model.to(self.device)
assert isinstance(self.model, torch.nn.Module)
else: else:
self.model = load_model(self.model_dir, compile=False, self.model = load_model(self.model_dir, compile=False,
custom_objects={"PatchEncoder": PatchEncoder, custom_objects={"PatchEncoder": PatchEncoder,
"Patches": Patches}) "Patches": Patches})
assert isinstance(self.model, Model)
##if self.weights_dir!=None: ##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir) ##self.model.load_weights(self.weights_dir)
assert isinstance(self.model, Model)
if self.task != 'classification' and self.task != 'reading_order': if self.task != 'classification' and self.task != 'reading_order':
last = self.model.layers[-1] last = self.model.layers[-1]
self.img_height = last.output_shape[1] self.img_height = last.output_shape[1]
@ -231,6 +247,13 @@ class SBBPredict:
pred_texts = pred_texts[0].replace("[UNK]", "") pred_texts = pred_texts[0].replace("[UNK]", "")
return pred_texts return pred_texts
elif self.task == "transformer-ocr":
from PIL import Image
image = Image.open(image_dir).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
generated_ids = self.model.generate(pixel_values.to(self.device))
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
elif self.task == 'reading_order': elif self.task == 'reading_order':
img_height = self.config_params_model['input_height'] img_height = self.config_params_model['input_height']
@ -566,6 +589,8 @@ class SBBPredict:
cv2.imwrite(self.save,res) cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr": elif self.task == "cnn-rnn-ocr":
print(f"Detected text: {res}") print(f"Detected text: {res}")
elif self.task == "transformer-ocr":
print(f"Detected text: {res}")
else: else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save: if self.save:
@ -672,7 +697,7 @@ def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_fil
with open(os.path.join(model,'config.json')) as f: with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f) config_params_model = json.load(f)
task = config_params_model['task'] task = config_params_model['task']
if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]: if task not in ['classification', 'reading_order', "cnn-rnn-ocr", "transformer-ocr"]:
assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s" assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s"
assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o" assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o"
x = SBBPredict(image, dir_in, model, task, config_params_model, x = SBBPredict(image, dir_in, model, task, config_params_model,

View file

@ -9,8 +9,8 @@ else:
import importlib.resources as importlib_resources import importlib.resources as importlib_resources
def get_font(): def get_font(font_size):
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf" font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
with importlib_resources.as_file(font) as font: with importlib_resources.as_file(font) as font:
return ImageFont.truetype(font=font, size=40) return ImageFont.truetype(font=font, size=font_size)