same task name for transformer-ocr training and inference

This commit is contained in:
vahidrezanezhad 2026-02-19 13:59:16 +01:00
parent a84ae67e7a
commit c4434c7f7d

View file

@ -173,7 +173,7 @@ class sbb_predict:
assert isinstance(self.model, Model)
elif self.task == "trocr":
elif self.task == "transformer-ocr":
import torch
from transformers import VisionEncoderDecoderModel
from transformers import TrOCRProcessor
@ -293,7 +293,7 @@ class sbb_predict:
pred_texts = pred_texts[0].replace("[UNK]", "")
return pred_texts
elif self.task == "trocr":
elif self.task == "transformer-ocr":
from PIL import Image
image = Image.open(image_dir).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
@ -635,7 +635,7 @@ class sbb_predict:
cv2.imwrite(self.save,res)
elif self.task == "cnn-rnn-ocr":
print(f"Detected text: {res}")
elif self.task == "trocr":
elif self.task == "transformer-ocr":
print(f"Detected text: {res}")
else:
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
@ -745,7 +745,7 @@ def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_fil
config_params_model = json.load(f)
task = config_params_model['task']
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "trocr":
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "transformer-ocr":
if image and not save:
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
sys.exit(1)