mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
trocr inference is integrated - works on CPU cause seg fault on GPU
This commit is contained in:
parent
a11c833fc1
commit
499e3d0715
5 changed files with 39 additions and 13 deletions
|
|
@ -345,7 +345,7 @@ class Eynollah_ocr(Eynollah):
|
||||||
if out_image_with_text:
|
if out_image_with_text:
|
||||||
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
|
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
|
||||||
draw = ImageDraw.Draw(image_text)
|
draw = ImageDraw.Draw(image_text)
|
||||||
font = get_font()
|
font = get_font(font_size=40)
|
||||||
|
|
||||||
for indexer_text, bb_ind in enumerate(total_bb_coordinates):
|
for indexer_text, bb_ind in enumerate(total_bb_coordinates):
|
||||||
x_bb = bb_ind[0]
|
x_bb = bb_ind[0]
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from eynollah.utils.font import get_font
|
||||||
|
|
||||||
from .gt_gen_utils import (
|
from .gt_gen_utils import (
|
||||||
filter_contours_area_of_image,
|
filter_contours_area_of_image,
|
||||||
|
|
@ -552,8 +553,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
|
||||||
else:
|
else:
|
||||||
xml_files_ind = [xml_file]
|
xml_files_ind = [xml_file]
|
||||||
|
|
||||||
font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
###font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
||||||
font = ImageFont.truetype(font_path, 40)
|
font = get_font(font_size=40)#ImageFont.truetype(font_path, 40)
|
||||||
|
|
||||||
for ind_xml in tqdm(xml_files_ind):
|
for ind_xml in tqdm(xml_files_ind):
|
||||||
indexer = 0
|
indexer = 0
|
||||||
|
|
@ -590,11 +591,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
|
||||||
|
|
||||||
|
|
||||||
is_vertical = h > 2*w # Check orientation
|
is_vertical = h > 2*w # Check orientation
|
||||||
font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4) )
|
font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4) )
|
||||||
|
|
||||||
if is_vertical:
|
if is_vertical:
|
||||||
|
|
||||||
vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
|
vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
|
||||||
|
|
||||||
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
|
text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0)) # Note: dimensions are swapped
|
||||||
text_draw = ImageDraw.Draw(text_img)
|
text_draw = ImageDraw.Draw(text_img)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from shapely import geometry
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from PIL import ImageFont
|
from PIL import ImageFont
|
||||||
from ocrd_utils import bbox_from_points
|
from ocrd_utils import bbox_from_points
|
||||||
|
from eynollah.utils.font import get_font
|
||||||
|
|
||||||
KERNEL = np.ones((5, 5), np.uint8)
|
KERNEL = np.ones((5, 5), np.uint8)
|
||||||
NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'
|
NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'
|
||||||
|
|
@ -352,11 +352,11 @@ def get_textline_contours_and_ocr_text(xml_file):
|
||||||
ocr_textlines.append(ocr_text_in[0])
|
ocr_textlines.append(ocr_text_in[0])
|
||||||
return co_use_case, y_len, x_len, ocr_textlines
|
return co_use_case, y_len, x_len, ocr_textlines
|
||||||
|
|
||||||
def fit_text_single_line(draw, text, font_path, max_width, max_height):
|
def fit_text_single_line(draw, text, max_width, max_height):
|
||||||
initial_font_size = 50
|
initial_font_size = 50
|
||||||
font_size = initial_font_size
|
font_size = initial_font_size
|
||||||
while font_size > 10: # Minimum font size
|
while font_size > 10: # Minimum font size
|
||||||
font = ImageFont.truetype(font_path, font_size)
|
font = get_font(font_size=font_size)# ImageFont.truetype(font_path, font_size)
|
||||||
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
|
text_bbox = draw.textbbox((0, 0), text, font=font) # Get text bounding box
|
||||||
text_width = text_bbox[2] - text_bbox[0]
|
text_width = text_bbox[2] - text_bbox[0]
|
||||||
text_height = text_bbox[3] - text_bbox[1]
|
text_height = text_bbox[3] - text_bbox[1]
|
||||||
|
|
@ -366,7 +366,7 @@ def fit_text_single_line(draw, text, font_path, max_width, max_height):
|
||||||
|
|
||||||
font_size -= 2 # Reduce font size and retry
|
font_size -= 2 # Reduce font size and retry
|
||||||
|
|
||||||
return ImageFont.truetype(font_path, 10) # Smallest font fallback
|
return get_font(font_size=10)#ImageFont.truetype(font_path, 10) # Smallest font fallback
|
||||||
|
|
||||||
def get_layout_contours_for_visualization(xml_file):
|
def get_layout_contours_for_visualization(xml_file):
|
||||||
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
||||||
|
|
|
||||||
|
|
@ -132,15 +132,31 @@ class SBBPredict:
|
||||||
self.model = Model(
|
self.model = Model(
|
||||||
self.model.get_layer(name = "image").input,
|
self.model.get_layer(name = "image").input,
|
||||||
self.model.get_layer(name = "dense2").output)
|
self.model.get_layer(name = "dense2").output)
|
||||||
|
assert isinstance(self.model, Model)
|
||||||
|
elif self.task == "transformer-ocr":
|
||||||
|
import torch
|
||||||
|
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
|
||||||
|
|
||||||
|
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
|
||||||
|
self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
|
||||||
|
|
||||||
|
if self.cpu:
|
||||||
|
self.device = torch.device('cpu')
|
||||||
|
else:
|
||||||
|
self.device = torch.device('cuda:0')
|
||||||
|
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
assert isinstance(self.model, torch.nn.Module)
|
||||||
else:
|
else:
|
||||||
self.model = load_model(self.model_dir, compile=False,
|
self.model = load_model(self.model_dir, compile=False,
|
||||||
custom_objects={"PatchEncoder": PatchEncoder,
|
custom_objects={"PatchEncoder": PatchEncoder,
|
||||||
"Patches": Patches})
|
"Patches": Patches})
|
||||||
|
assert isinstance(self.model, Model)
|
||||||
|
|
||||||
##if self.weights_dir!=None:
|
##if self.weights_dir!=None:
|
||||||
##self.model.load_weights(self.weights_dir)
|
##self.model.load_weights(self.weights_dir)
|
||||||
|
|
||||||
assert isinstance(self.model, Model)
|
|
||||||
if self.task != 'classification' and self.task != 'reading_order':
|
if self.task != 'classification' and self.task != 'reading_order':
|
||||||
last = self.model.layers[-1]
|
last = self.model.layers[-1]
|
||||||
self.img_height = last.output_shape[1]
|
self.img_height = last.output_shape[1]
|
||||||
|
|
@ -230,6 +246,13 @@ class SBBPredict:
|
||||||
pred_texts = decode_batch_predictions(preds, num_to_char)
|
pred_texts = decode_batch_predictions(preds, num_to_char)
|
||||||
pred_texts = pred_texts[0].replace("[UNK]", "")
|
pred_texts = pred_texts[0].replace("[UNK]", "")
|
||||||
return pred_texts
|
return pred_texts
|
||||||
|
|
||||||
|
elif self.task == "transformer-ocr":
|
||||||
|
from PIL import Image
|
||||||
|
image = Image.open(image_dir).convert("RGB")
|
||||||
|
pixel_values = self.processor(image, return_tensors="pt").pixel_values
|
||||||
|
generated_ids = self.model.generate(pixel_values.to(self.device))
|
||||||
|
return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
|
|
||||||
elif self.task == 'reading_order':
|
elif self.task == 'reading_order':
|
||||||
|
|
@ -566,6 +589,8 @@ class SBBPredict:
|
||||||
cv2.imwrite(self.save,res)
|
cv2.imwrite(self.save,res)
|
||||||
elif self.task == "cnn-rnn-ocr":
|
elif self.task == "cnn-rnn-ocr":
|
||||||
print(f"Detected text: {res}")
|
print(f"Detected text: {res}")
|
||||||
|
elif self.task == "transformer-ocr":
|
||||||
|
print(f"Detected text: {res}")
|
||||||
else:
|
else:
|
||||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
if self.save:
|
if self.save:
|
||||||
|
|
@ -672,7 +697,7 @@ def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_fil
|
||||||
with open(os.path.join(model,'config.json')) as f:
|
with open(os.path.join(model,'config.json')) as f:
|
||||||
config_params_model = json.load(f)
|
config_params_model = json.load(f)
|
||||||
task = config_params_model['task']
|
task = config_params_model['task']
|
||||||
if task not in ['classification', 'reading_order', "cnn-rnn-ocr"]:
|
if task not in ['classification', 'reading_order', "cnn-rnn-ocr", "transformer-ocr"]:
|
||||||
assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s"
|
assert not image or save, "For segmentation or binarization, an input single image -i also requires an output filename -s"
|
||||||
assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o"
|
assert not dir_in or out, "For segmentation or binarization, an input directory -di also requires an output directory -o"
|
||||||
x = SBBPredict(image, dir_in, model, task, config_params_model,
|
x = SBBPredict(image, dir_in, model, task, config_params_model,
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ else:
|
||||||
import importlib.resources as importlib_resources
|
import importlib.resources as importlib_resources
|
||||||
|
|
||||||
|
|
||||||
def get_font():
|
def get_font(font_size):
|
||||||
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
#font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
|
||||||
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
|
font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
|
||||||
with importlib_resources.as_file(font) as font:
|
with importlib_resources.as_file(font) as font:
|
||||||
return ImageFont.truetype(font=font, size=40)
|
return ImageFont.truetype(font=font, size=font_size)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue