trocr inference is integrated - works on CPU cause seg fault on GPU

This commit is contained in:
vahidrezanezhad 2026-02-18 15:04:54 +01:00
parent 733462381c
commit b426f7f152
5 changed files with 52 additions and 17 deletions

View file

@ -658,7 +658,7 @@ class Eynollah_ocr:
if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text)
font = get_font()
font = get_font(font_size=40)
for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0]

View file

@ -6,6 +6,7 @@ from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
from eynollah.utils.font import get_font
from eynollah.training.gt_gen_utils import (
filter_contours_area_of_image,
@ -514,8 +515,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
else:
xml_files_ind = [xml_file]
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = ImageFont.truetype(font_path, 40)
###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)
for ind_xml in tqdm(xml_files_ind):
indexer = 0
@ -552,11 +553,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
is_vertical = h > 2*w # Check orientation
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) )
font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )
if is_vertical:
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
text_draw = ImageDraw.Draw(text_img)

View file

@ -7,7 +7,7 @@ import cv2
from shapely import geometry
from pathlib import Path
from PIL import ImageFont
from eynollah.utils.font import get_font
KERNEL = np.ones((5, 5), np.uint8)
@ -350,11 +350,11 @@ def get_textline_contours_and_ocr_text(xml_file):
ocr_textlines.append(ocr_text_in[0])
return co_use_case, y_len, x_len, ocr_textlines
def fit_text_single_line(draw, text, font_path, max_width, max_height):
def fit_text_single_line(draw, text, max_width, max_height):
initial_font_size = 50
font_size = initial_font_size
while font_size > 10: # Minimum font size
font = ImageFont.truetype(font_path, font_size)
font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
@ -364,7 +364,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
font_size -= 2 # Reduce font size and retry
return ImageFont.truetype(font_path, 10) # Smallest font fallback
return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback
def get_layout_contours_for_visualization(xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))

View file

@ -170,6 +170,25 @@ class sbb_predict:
self.model = tf.keras.models.Model(
self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output)
assert isinstance(self.model, Model)
elif self.task == "trocr":
import torch
from transformers import VisionEncoderDecoderModel
from transformers import TrOCRProcessor
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
if self.cpu:
self.device = torch.device('cpu')
else:
self.device = torch.device('cuda:0')
self.model.to(self.device)
assert isinstance(self.model, torch.nn.Module)
else:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
@ -184,7 +203,8 @@ class sbb_predict:
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
assert isinstance(self.model, Model)
assert isinstance(self.model, Model)
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization":
@ -235,10 +255,9 @@ class sbb_predict:
return added_image, layout_only
def predict(self, image_dir):
assert isinstance(self.model, Model)
if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name']
img_1ch = img=cv2.imread(image_dir, 0)
img_1ch =cv2.imread(image_dir, 0)
img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
@ -273,6 +292,15 @@ class sbb_predict:
pred_texts = decode_batch_predictions(preds, num_to_char)
pred_texts = pred_texts[0].replace("[UNK]", "")
return pred_texts
elif self.task == "trocr":
from PIL import Image
image = Image.open(image_dir).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
generated_ids = self.model.generate(pixel_values.to(self.device))
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
elif self.task == 'reading_order':
@ -607,6 +635,8 @@ class sbb_predict:
cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr":
print(f"Detected text: {res}")
elif self.task == "trocr":
print(f"Detected text: {res}")
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save:
@ -710,10 +740,14 @@ class sbb_predict:
)
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
try:
with open(os.path.join(model,'config_eynollah.json')) as f:
config_params_model = json.load(f)
except:
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
task = config_params_model['task']
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "trocr":
if image and not save:
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
sys.exit(1)

View file

@ -9,8 +9,8 @@ else:
import importlib.resources as importlib_resources
def get_font():
def get_font(font_size):
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
with importlib_resources.as_file(font) as font:
return ImageFont.truetype(font=font, size=40)
return ImageFont.truetype(font=font, size=font_size)