diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py index c613fe2..dc7979a 100644 --- a/src/eynollah/training/inference.py +++ b/src/eynollah/training/inference.py @@ -173,7 +173,7 @@ class sbb_predict: assert isinstance(self.model, Model) - elif self.task == "trocr": + elif self.task == "transformer-ocr": import torch from transformers import VisionEncoderDecoderModel from transformers import TrOCRProcessor @@ -293,7 +293,7 @@ class sbb_predict: pred_texts = pred_texts[0].replace("[UNK]", "") return pred_texts - elif self.task == "trocr": + elif self.task == "transformer-ocr": from PIL import Image image = Image.open(image_dir).convert("RGB") pixel_values = self.processor(image, return_tensors="pt").pixel_values @@ -635,7 +635,7 @@ class sbb_predict: cv2.imwrite(self.save,res) elif self.task == "cnn-rnn-ocr": print(f"Detected text: {res}") - elif self.task == "trocr": + elif self.task == "transformer-ocr": print(f"Detected text: {res}") else: img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) @@ -745,7 +745,7 @@ def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_fil config_params_model = json.load(f) task = config_params_model['task'] - if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "trocr": + if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "transformer-ocr": if image and not save: print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s") sys.exit(1)