trocr: use beam search instead of greedy decoding

This commit is contained in:
Robert Sachunsky 2026-05-21 17:52:27 +02:00
parent 074753a98e
commit ea41dcae1d

View file

@@ -90,7 +90,6 @@ class Eynollah_ocr(Eynollah):
page_ns, page_ns,
tr_ocr_input_height_and_width, tr_ocr_input_height_and_width,
) -> EynollahOcrResult: ) -> EynollahOcrResult:
import torch
total_bb_coordinates = [] total_bb_coordinates = []
cropped_lines = [] cropped_lines = []
@@ -148,10 +147,16 @@ class Eynollah_ocr(Eynollah):
imgs, return_tensors="pt").pixel_values imgs, return_tensors="pt").pixel_values
output = self.model_zoo.get('ocr').generate( output = self.model_zoo.get('ocr').generate(
pixel_values.to(self.device), pixel_values.to(self.device),
# beam search instead of greedy decoding:
num_beams=4,
# also return probability
output_scores=True, output_scores=True,
return_dict_in_generate=True) return_dict_in_generate=True)
conf = torch.max(torch.softmax(torch.cat( if output.sequences_scores is not None:
output.scores, dim=0), dim=1), dim=1).values.tolist() # log-prob averaged over length
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
else:
conf = [1.0] * len(output.sequences)
text = self.model_zoo.get('trocr_processor').batch_decode( text = self.model_zoo.get('trocr_processor').batch_decode(
output.sequences, output.sequences,
skip_special_tokens=True, skip_special_tokens=True,