diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index faeb042..b94853b 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -90,7 +90,6 @@ class Eynollah_ocr(Eynollah): page_ns, tr_ocr_input_height_and_width, ) -> EynollahOcrResult: - import torch total_bb_coordinates = [] cropped_lines = [] @@ -148,10 +147,16 @@ class Eynollah_ocr(Eynollah): imgs, return_tensors="pt").pixel_values output = self.model_zoo.get('ocr').generate( pixel_values.to(self.device), + # beam search instead of greedy decoding: + num_beams=4, + # also return probability output_scores=True, return_dict_in_generate=True) - conf = torch.max(torch.softmax(torch.cat( - output.scores, dim=0), dim=1), dim=1).values.tolist() + if output.sequences_scores is not None: + # log-prob averaged over length + conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist() + else: + conf = [1.0] * len(output.sequences) text = self.model_zoo.get('trocr_processor').batch_decode( output.sequences, skip_special_tokens=True,