mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
same task name for transformer-ocr training and inference
This commit is contained in:
parent
a84ae67e7a
commit
c4434c7f7d
1 changed files with 4 additions and 4 deletions
|
|
@ -173,7 +173,7 @@ class sbb_predict:
|
|||
|
||||
assert isinstance(self.model, Model)
|
||||
|
||||
elif self.task == "trocr":
|
||||
elif self.task == "transformer-ocr":
|
||||
import torch
|
||||
from transformers import VisionEncoderDecoderModel
|
||||
from transformers import TrOCRProcessor
|
||||
|
|
@ -293,7 +293,7 @@ class sbb_predict:
|
|||
pred_texts = pred_texts[0].replace("[UNK]", "")
|
||||
return pred_texts
|
||||
|
||||
elif self.task == "trocr":
|
||||
elif self.task == "transformer-ocr":
|
||||
from PIL import Image
|
||||
image = Image.open(image_dir).convert("RGB")
|
||||
pixel_values = self.processor(image, return_tensors="pt").pixel_values
|
||||
|
|
@ -635,7 +635,7 @@ class sbb_predict:
|
|||
cv2.imwrite(self.save,res)
|
||||
elif self.task == "cnn-rnn-ocr":
|
||||
print(f"Detected text: {res}")
|
||||
elif self.task == "trocr":
|
||||
elif self.task == "transformer-ocr":
|
||||
print(f"Detected text: {res}")
|
||||
else:
|
||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||
|
|
@ -745,7 +745,7 @@ def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_fil
|
|||
config_params_model = json.load(f)
|
||||
|
||||
task = config_params_model['task']
|
||||
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "trocr":
|
||||
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr" and task != "transformer-ocr":
|
||||
if image and not save:
|
||||
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue