diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index f1b155b..faeb042 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -90,12 +90,14 @@ class Eynollah_ocr(Eynollah): page_ns, tr_ocr_input_height_and_width, ) -> EynollahOcrResult: + import torch total_bb_coordinates = [] cropped_lines = [] cropped_lines_region_indexer = [] cropped_lines_meging_indexing = [] extracted_texts = [] + extracted_confs = [] for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)): for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)): @@ -142,15 +144,20 @@ class Eynollah_ocr(Eynollah): self.logger.debug("processing %d lines for %d regions", len(cropped_lines), len(set(cropped_lines_region_indexer))) for imgs in batched(cropped_lines, self.b_s): - pixel_values_merged = self.model_zoo.get('trocr_processor')( + pixel_values = self.model_zoo.get('trocr_processor')( imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, + output = self.model_zoo.get('ocr').generate( + pixel_values.to(self.device), + output_scores=True, + return_dict_in_generate=True) + conf = torch.max(torch.softmax(torch.cat( + output.scores, dim=0), dim=1), dim=1).values.tolist() + text = self.model_zoo.get('trocr_processor').batch_decode( + output.sequences, skip_special_tokens=True, clean_up_tokenization_spaces=False) - extracted_texts = extracted_texts + generated_text_merged + extracted_confs.extend(conf) + extracted_texts.extend(text) del cropped_lines gc.collect() @@ -159,10 +166,15 @@ class Eynollah_ocr(Eynollah): else extracted_texts[ind] + " " + extracted_texts[ind + 1] for ind in range(len(cropped_lines_meging_indexing)) if cropped_lines_meging_indexing[ind] >= 0] + extracted_confs_merged = [extracted_confs[ind] + if cropped_lines_meging_indexing[ind] == 0 + else 0.5 * (extracted_confs[ind] + extracted_confs[ind + 1]) + for ind in range(len(cropped_lines_meging_indexing)) + if cropped_lines_meging_indexing[ind] >= 0] return EynollahOcrResult( extracted_texts_merged=extracted_texts_merged, - extracted_conf_value_merged=None, + extracted_conf_value_merged=extracted_confs_merged, cropped_lines_region_indexer=cropped_lines_region_indexer, total_bb_coordinates=total_bb_coordinates, ) @@ -618,6 +630,7 @@ class Eynollah_ocr(Eynollah): has_textline = False for child_textregion in nn: + # FIXME: should remove Word level, if it already exists if child_textregion.tag.endswith("TextLine"): is_textline_text = False