diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py
index b5228ff..173ba46 100644
--- a/src/eynollah/eynollah_ocr.py
+++ b/src/eynollah/eynollah_ocr.py
@@ -658,7 +658,7 @@ class Eynollah_ocr:
         if out_image_with_text:
             image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
             draw = ImageDraw.Draw(image_text)
-            font = get_font()
+            font = get_font(font_size=40)
 
             for indexer_text, bb_ind in enumerate(total_bb_coordinates):
                 x_bb = bb_ind[0]
diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py
index 1c330f1..28f3f1c 100644
--- a/src/eynollah/training/generate_gt_for_training.py
+++ b/src/eynollah/training/generate_gt_for_training.py
@@ -6,6 +6,7 @@ from pathlib import Path
 from PIL import Image, ImageDraw, ImageFont
 import cv2
 import numpy as np
+from eynollah.utils.font import get_font
 
 from eynollah.training.gt_gen_utils import (
     filter_contours_area_of_image,
@@ -514,8 +515,8 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
     else:
         xml_files_ind = [xml_file]
 
-    font_path = "Charis-7.000/Charis-Regular.ttf"  # Make sure this file exists!
-    font = ImageFont.truetype(font_path, 40)
+    ### font_path = "Charis-7.000/Charis-Regular.ttf"  # Make sure this file exists!
+    font = get_font(font_size=40)  # was ImageFont.truetype(font_path, 40)
 
     for ind_xml in tqdm(xml_files_ind):
         indexer = 0
@@ -552,11 +553,11 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
             is_vertical = h > 2*w  # Check orientation
 
-            font = fit_text_single_line(draw, ocr_texts[index], font_path, w, int(h*0.4))
+            font = fit_text_single_line(draw, ocr_texts[index], w, int(h*0.4))
 
             if is_vertical:
-                vertical_font = fit_text_single_line(draw, ocr_texts[index], font_path, h, int(w * 0.8))
+                vertical_font = fit_text_single_line(draw, ocr_texts[index], h, int(w * 0.8))
 
                 text_img = Image.new("RGBA", (h, w), (255, 255, 255, 0))  # Note: dimensions are swapped
                 text_draw = ImageDraw.Draw(text_img)
diff --git a/src/eynollah/training/gt_gen_utils.py b/src/eynollah/training/gt_gen_utils.py
index 50bb6fa..1e15b3b 100644
--- a/src/eynollah/training/gt_gen_utils.py
+++ b/src/eynollah/training/gt_gen_utils.py
@@ -7,7 +7,7 @@ import cv2
 from shapely import geometry
 from pathlib import Path
 from PIL import ImageFont
-
+from eynollah.utils.font import get_font
 
 KERNEL = np.ones((5, 5), np.uint8)
 
@@ -350,11 +350,11 @@ def get_textline_contours_and_ocr_text(xml_file):
                 ocr_textlines.append(ocr_text_in[0])
     return co_use_case, y_len, x_len, ocr_textlines
 
-def fit_text_single_line(draw, text, font_path, max_width, max_height):
+def fit_text_single_line(draw, text, max_width, max_height):
     initial_font_size = 50
     font_size = initial_font_size
     while font_size > 10:  # Minimum font size
-        font = ImageFont.truetype(font_path, font_size)
+        font = get_font(font_size=font_size)  # was ImageFont.truetype(font_path, font_size)
         text_bbox = draw.textbbox((0, 0), text, font=font)  # Get text bounding box
         text_width = text_bbox[2] - text_bbox[0]
         text_height = text_bbox[3] - text_bbox[1]
@@ -364,7 +364,7 @@
 
         font_size -= 2  # Reduce font size and retry
 
-    return ImageFont.truetype(font_path, 10)  # Smallest font fallback
+    return get_font(font_size=10)  # Smallest font fallback (was ImageFont.truetype(font_path, 10))
 
 def get_layout_contours_for_visualization(xml_file):
     tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
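Note on the three files above: fit_text_single_line() no longer takes a font_path; every size probe (50 down to 10 in steps of 2) now goes through the packaged get_font() helper. A minimal usage sketch, assuming the eynollah package is installed; the canvas size, box limits, and sample text are arbitrary examples, not values from the patch:

    from PIL import Image, ImageDraw
    from eynollah.training.gt_gen_utils import fit_text_single_line

    img = Image.new("RGB", (400, 60), "white")   # arbitrary example canvas
    draw = ImageDraw.Draw(img)
    # Returns the largest bundled-font size (at most 50) whose rendered text
    # fits in a 400x24 px box; note there is no font_path argument anymore.
    font = fit_text_single_line(draw, "example text line", 400, 24)
    draw.text((0, 0), "example text line", font=font, fill="black")
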
diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py
index c6ad186..19265bc 100644
--- a/src/eynollah/training/inference.py
+++ b/src/eynollah/training/inference.py
@@ -170,6 +170,25 @@ class sbb_predict:
                 self.model = tf.keras.models.Model(
                     self.model.get_layer(name = "image").input,
                     self.model.get_layer(name = "dense2").output)
+
+            assert isinstance(self.model, Model)
+
+        elif self.task == "trocr":
+            import torch
+            from transformers import VisionEncoderDecoderModel
+            from transformers import TrOCRProcessor
+
+            self.model = VisionEncoderDecoderModel.from_pretrained(self.model_dir)
+            self.processor = TrOCRProcessor.from_pretrained(self.model_dir)
+
+            if self.cpu:
+                self.device = torch.device('cpu')
+            else:
+                self.device = torch.device('cuda:0')
+
+            self.model.to(self.device)
+
+            assert isinstance(self.model, torch.nn.Module)
         else:
             config = tf.compat.v1.ConfigProto()
             config.gpu_options.allow_growth = True
@@ -184,7 +203,8 @@ class sbb_predict:
             self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
             self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
 
-        assert isinstance(self.model, Model)
+
+            assert isinstance(self.model, Model)
 
     def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
         if task == "binarization":
@@ -235,10 +255,9 @@ class sbb_predict:
         return added_image, layout_only
 
     def predict(self, image_dir):
-        assert isinstance(self.model, Model)
         if self.task == 'classification':
             classes_names = self.config_params_model['classification_classes_name']
-            img_1ch = img=cv2.imread(image_dir, 0)
+            img_1ch = cv2.imread(image_dir, 0)
 
             img_1ch = img_1ch / 255.0
             img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
@@ -273,6 +292,15 @@ class sbb_predict:
             pred_texts = decode_batch_predictions(preds, num_to_char)
             pred_texts = pred_texts[0].replace("[UNK]", "")
             return pred_texts
+
+        elif self.task == "trocr":
+            from PIL import Image
+            image = Image.open(image_dir).convert("RGB")
+            pixel_values = self.processor(image, return_tensors="pt").pixel_values
+            generated_ids = self.model.generate(pixel_values.to(self.device))
+            return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
         elif self.task == 'reading_order':
@@ -607,6 +635,8 @@ class sbb_predict:
                 cv2.imwrite(self.save,res)
         elif self.task == "cnn-rnn-ocr":
             print(f"Detected text: {res}")
+        elif self.task == "trocr":
+            print(f"Detected text: {res}")
         else:
             img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
             if self.save:
@@ -710,10 +740,14 @@
 )
 def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
     assert image or dir_in, "Either a single image -i or a dir_in -di is required"
-    with open(os.path.join(model,'config.json')) as f:
-        config_params_model = json.load(f)
+    try:
+        with open(os.path.join(model,'config_eynollah.json')) as f:
+            config_params_model = json.load(f)
+    except FileNotFoundError:
+        with open(os.path.join(model,'config.json')) as f:
+            config_params_model = json.load(f)
     task = config_params_model['task']
-    if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
+    if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "trocr":
         if image and not save:
             print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
             sys.exit(1)
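Note on inference.py: the new "trocr" branches are the standard transformers encode/generate/decode round trip. A self-contained sketch of the same flow; the checkpoint directory and image path are placeholders, and device selection here uses torch.cuda.is_available() in place of the CLI's --cpu flag:

    import torch
    from PIL import Image
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel

    model_dir = "path/to/trocr_checkpoint"   # placeholder: dir with a fine-tuned TrOCR model
    processor = TrOCRProcessor.from_pretrained(model_dir)
    model = VisionEncoderDecoderModel.from_pretrained(model_dir)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    image = Image.open("line.png").convert("RGB")   # placeholder text-line image
    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values.to(device))
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(f"Detected text: {text}")
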
diff --git a/src/eynollah/utils/font.py b/src/eynollah/utils/font.py
index 939933e..0354317 100644
--- a/src/eynollah/utils/font.py
+++ b/src/eynollah/utils/font.py
@@ -9,8 +9,8 @@ else:
     import importlib.resources as importlib_resources
 
 
-def get_font():
+def get_font(font_size):
     #font_path = "Charis-7.000/Charis-Regular.ttf" # Make sure this file exists!
     font = importlib_resources.files(__package__) / "../Charis-Regular.ttf"
     with importlib_resources.as_file(font) as font:
-        return ImageFont.truetype(font=font, size=40)
+        return ImageFont.truetype(font=font, size=font_size)
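Note: get_font() now requires an explicit size instead of hard-coding 40, which is why every caller above passes font_size=... . A quick smoke test, assuming the bundled Charis-Regular.ttf ships with the package as the resource lookup above implies:

    from PIL import Image, ImageDraw
    from eynollah.utils.font import get_font

    font = get_font(font_size=40)   # resolves the packaged Charis-Regular.ttf
    img = Image.new("RGB", (300, 60), "white")
    ImageDraw.Draw(img).text((0, 0), "Charis sample", font=font, fill="black")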