TrOCR inference is integrated — works on CPU; causes a segfault on GPU

This commit is contained in:
vahidrezanezhad 2026-02-18 15:04:54 +01:00
parent 733462381c
commit b426f7f152
5 changed files with 52 additions and 17 deletions

View file

@ -658,7 +658,7 @@ class Eynollah_ocr:
if out_image_with_text: if out_image_with_text:
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white") image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
draw = ImageDraw.Draw(image_text) draw = ImageDraw.Draw(image_text)
font = get_font() font = get_font(font_size=40)
for indexer_text, bb_ind in enumerate(total_bb_coordinates): for indexer_text, bb_ind in enumerate(total_bb_coordinates):
x_bb = bb_ind[0] x_bb = bb_ind[0]

View file

@ -6,6 +6,7 @@ from pathlib import Path
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import cv2 import cv2
import numpy as np import numpy as np
from eynollah.utils.font import get_font
from eynollah.training.gt_gen_utils import ( from eynollah.training.gt_gen_utils import (
filter_contours_area_of_image, filter_contours_area_of_image,
@ -514,8 +515,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
else: else:
xml_files_ind = [xml_file] xml_files_ind = [xml_file]
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists! ###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
font = ImageFont.truetype(font_path, 40) font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)
for ind_xml in tqdm(xml_files_ind): for ind_xml in tqdm(xml_files_ind):
indexer = 0 indexer = 0
@ -552,11 +553,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
is_vertical = h > 2*w # Check orientation is_vertical = h > 2*w # Check orientation
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) ) font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )
if is_vertical: if is_vertical:
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8)) vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
text_draw = ImageDraw.Draw(text_img) text_draw = ImageDraw.Draw(text_img)

View file

@ -7,7 +7,7 @@ import cv2
from shapely import geometry from shapely import geometry
from pathlib import Path from pathlib import Path
from PIL import ImageFont from PIL import ImageFont
from eynollah.utils.font import get_font
KERNEL = np.ones((5, 5), np.uint8) KERNEL = np.ones((5, 5), np.uint8)
@ -350,11 +350,11 @@ def get_textline_contours_and_ocr_text(xml_file):
ocr_textlines.append(ocr_text_in[0]) ocr_textlines.append(ocr_text_in[0])
return co_use_case, y_len, x_len, ocr_textlines return co_use_case, y_len, x_len, ocr_textlines
def fit_text_single_line(draw, text, font_path, max_width, max_height): def fit_text_single_line(draw, text, max_width, max_height):
initial_font_size = 50 initial_font_size = 50
font_size = initial_font_size font_size = initial_font_size
while font_size > 10: # Minimum font size while font_size > 10: # Minimum font size
font = ImageFont.truetype(font_path, font_size) font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
text_width = text_bbox[2] - text_bbox[0] text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1] text_height = text_bbox[3] - text_bbox[1]
@ -364,7 +364,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
font_size -= 2 # Reduce font size and retry font_size -= 2 # Reduce font size and retry
return ImageFont.truetype(font_path, 10) # Smallest font fallback return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback
def get_layout_contours_for_visualization(xml_file): def get_layout_contours_for_visualization(xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))

View file

@ -170,6 +170,25 @@ class sbb_predict:
self.model = tf.keras.models.Model( self.model = tf.keras.models.Model(
self.model.get_layer(name = "image").input, self.model.get_layer(name = "image").input,
self.model.get_layer(name = "dense2").output) self.model.get_layer(name = "dense2").output)
assert isinstance(self.model, Model)
elif self.task == "trocr":
import torch
from transformers import VisionEncoderDecoderModel
from transformers import TrOCRProcessor
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
if self.cpu:
self.device = torch.device('cpu')
else:
self.device = torch.device('cuda:0')
self.model.to(self.device)
assert isinstance(self.model, torch.nn.Module)
else: else:
config = tf.compat.v1.ConfigProto() config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
@ -184,7 +203,8 @@ class sbb_predict:
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2] self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3] self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
assert isinstance(self.model, Model)
assert isinstance(self.model, Model)
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]: def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization": if task == "binarization":
@ -235,10 +255,9 @@ class sbb_predict:
return added_image, layout_only return added_image, layout_only
def predict(self, image_dir): def predict(self, image_dir):
assert isinstance(self.model, Model)
if self.task == 'classification': if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name'] classes_names = self.config_params_model['classification_classes_name']
img_1ch = img=cv2.imread(image_dir, 0) img_1ch =cv2.imread(image_dir, 0)
img_1ch = img_1ch / 255.0 img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST) img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
@ -273,6 +292,15 @@ class sbb_predict:
pred_texts = decode_batch_predictions(preds, num_to_char) pred_texts = decode_batch_predictions(preds, num_to_char)
pred_texts = pred_texts[0].replace("[UNK]", "") pred_texts = pred_texts[0].replace("[UNK]", "")
return pred_texts return pred_texts
elif self.task == "trocr":
from PIL import Image
image = Image.open(image_dir).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
generated_ids = self.model.generate(pixel_values.to(self.device))
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
elif self.task == 'reading_order': elif self.task == 'reading_order':
@ -607,6 +635,8 @@ class sbb_predict:
cv2.imwrite(self.save,res) cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr": elif self.task == "cnn-rnn-ocr":
print(f"Detected text: {res}") print(f"Detected text: {res}")
elif self.task == "trocr":
print(f"Detected text: {res}")
else: else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
if self.save: if self.save:
@ -710,10 +740,14 @@ class sbb_predict:
) )
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area): def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
assert image or dir_in, "Either a single image -i or a dir_in -di is required" assert image or dir_in, "Either a single image -i or a dir_in -di is required"
with open(os.path.join(model,'config.json')) as f: try:
config_params_model = json.load(f) with open(os.path.join(model,'config_eynollah.json')) as f:
config_params_model = json.load(f)
except:
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
task = config_params_model['task'] task = config_params_model['task']
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr": if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "trocr":
if image and not save: if image and not save:
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s") print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
sys.exit(1) sys.exit(1)

View file

@ -9,8 +9,8 @@ else:
import importlib.resources as importlib_resources import importlib.resources as importlib_resources
def get_font(font_size=40):
    """Load the bundled Charis-Regular TrueType font at the requested size.

    Args:
        font_size: Point size for the returned font. Defaults to 40, the
            size this helper hard-coded before the parameter was added, so
            legacy ``get_font()`` call sites keep working unchanged.

    Returns:
        A ``PIL.ImageFont.FreeTypeFont`` loaded from the font file that
        ships one directory above this package module.
    """
    #font_path = "Charis-7.000/Charis-Regular.ttf"  # Make sure this file exists!
    font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
    # as_file() materializes the resource on disk (a no-op for a regular
    # filesystem install) so truetype() can open it by path.
    with importlib_resources.as_file(font) as font:
        return ImageFont.truetype(font=font, size=font_size)