mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
trocr: use beam search instead of greedy decoding
This commit is contained in:
parent
074753a98e
commit
ea41dcae1d
1 changed file with 8 additions and 3 deletions
|
|
@@ -90,7 +90,6 @@ class Eynollah_ocr(Eynollah):
|
|||
page_ns,
|
||||
tr_ocr_input_height_and_width,
|
||||
) -> EynollahOcrResult:
|
||||
import torch
|
||||
|
||||
total_bb_coordinates = []
|
||||
cropped_lines = []
|
||||
|
|
@@ -148,10 +147,16 @@ class Eynollah_ocr(Eynollah):
|
|||
imgs, return_tensors="pt").pixel_values
|
||||
output = self.model_zoo.get('ocr').generate(
|
||||
pixel_values.to(self.device),
|
||||
# beam search instead of greedy decoding:
|
||||
num_beams=4,
|
||||
# also return probability
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True)
|
||||
conf = torch.max(torch.softmax(torch.cat(
|
||||
output.scores, dim=0), dim=1), dim=1).values.tolist()
|
||||
if output.sequences_scores is not None:
|
||||
# log-prob averaged over length
|
||||
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
|
||||
else:
|
||||
conf = [1.0] * len(output.sequences)
|
||||
text = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
output.sequences,
|
||||
skip_special_tokens=True,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue