diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 77ad98f..4371453 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -139,11 +139,14 @@ class Eynollah_ocr(Eynollah): cropped_lines = [] indexer_b_s = 0 - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + pixel_values_merged = self.model_zoo.get('trocr_processor')( + imgs, return_tensors="pt").pixel_values generated_ids_merged = self.model_zoo.get('ocr').generate( pixel_values_merged.to(self.device)) generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) + generated_ids_merged, + skip_special_tokens=True, + clean_up_tokenization_spaces=False) extracted_texts = extracted_texts + generated_text_merged @@ -162,11 +165,14 @@ class Eynollah_ocr(Eynollah): cropped_lines = [] indexer_b_s = 0 - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + pixel_values_merged = self.model_zoo.get('trocr_processor')( + imgs, return_tensors="pt").pixel_values generated_ids_merged = self.model_zoo.get('ocr').generate( pixel_values_merged.to(self.device)) generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) + generated_ids_merged, + skip_special_tokens=True, + clean_up_tokenization_spaces=False) extracted_texts = extracted_texts + generated_text_merged @@ -182,11 +188,14 @@ class Eynollah_ocr(Eynollah): cropped_lines = [] indexer_b_s = 0 - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + pixel_values_merged = self.model_zoo.get('trocr_processor')( + imgs, return_tensors="pt").pixel_values generated_ids_merged = self.model_zoo.get('ocr').generate( pixel_values_merged.to(self.device)) generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) + generated_ids_merged, + skip_special_tokens=True, + clean_up_tokenization_spaces=False) extracted_texts = extracted_texts + generated_text_merged @@ -194,22 +203,23 @@ class Eynollah_ocr(Eynollah): cropped_lines.append(img_crop) cropped_lines_meging_indexing.append(0) indexer_b_s+=1 - + if indexer_b_s==self.b_s: imgs = cropped_lines[:] cropped_lines = [] indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values + + pixel_values_merged = self.model_zoo.get('trocr_processor')( + imgs, return_tensors="pt").pixel_values generated_ids_merged = self.model_zoo.get('ocr').generate( pixel_values_merged.to(self.device)) generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) - + generated_ids_merged, + skip_special_tokens=True, + clean_up_tokenization_spaces=False) + extracted_texts = extracted_texts + generated_text_merged - - - + indexer_text_region = indexer_text_region +1 if indexer_b_s!=0: @@ -217,9 +227,14 @@ class Eynollah_ocr(Eynollah): cropped_lines = [] indexer_b_s = 0 - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True) + pixel_values_merged = self.model_zoo.get('trocr_processor')( + imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_zoo.get('ocr').generate( + pixel_values_merged.to(self.device)) + generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( + generated_ids_merged, + skip_special_tokens=True, + clean_up_tokenization_spaces=False) extracted_texts = extracted_texts + generated_text_merged @@ -750,6 +765,7 @@ class Eynollah_ocr(Eynollah): indexer_textregion = indexer_textregion + 1 ET.register_namespace("",page_ns) + self.logger.info("output filename: '%s'", out_file_ocr) page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None) def run(