trocr: use beam search instead of greedy decoding

This commit is contained in:
Robert Sachunsky 2026-05-21 17:52:27 +02:00
parent 074753a98e
commit ea41dcae1d

View file

@ -90,7 +90,6 @@ class Eynollah_ocr(Eynollah):
page_ns,
tr_ocr_input_height_and_width,
) -> EynollahOcrResult:
import torch
total_bb_coordinates = []
cropped_lines = []
@ -148,10 +147,16 @@ class Eynollah_ocr(Eynollah):
imgs, return_tensors="pt").pixel_values
output = self.model_zoo.get('ocr').generate(
pixel_values.to(self.device),
# beam search instead of greedy decoding:
num_beams=4,
# also return probability
output_scores=True,
return_dict_in_generate=True)
conf = torch.max(torch.softmax(torch.cat(
output.scores, dim=0), dim=1), dim=1).values.tolist()
if output.sequences_scores is not None:
# log-prob averaged over length
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
else:
conf = [1.0] * len(output.sequences)
text = self.model_zoo.get('trocr_processor').batch_decode(
output.sequences,
skip_special_tokens=True,