mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
trocr: use beam search instead of greedy decoding
This commit is contained in:
parent
074753a98e
commit
ea41dcae1d
1 changed file with 8 additions and 3 deletions
|
|
@@ -90,7 +90,6 @@ class Eynollah_ocr(Eynollah):
|
||||||
page_ns,
|
page_ns,
|
||||||
tr_ocr_input_height_and_width,
|
tr_ocr_input_height_and_width,
|
||||||
) -> EynollahOcrResult:
|
) -> EynollahOcrResult:
|
||||||
import torch
|
|
||||||
|
|
||||||
total_bb_coordinates = []
|
total_bb_coordinates = []
|
||||||
cropped_lines = []
|
cropped_lines = []
|
||||||
|
|
@@ -148,10 +147,16 @@ class Eynollah_ocr(Eynollah):
|
||||||
imgs, return_tensors="pt").pixel_values
|
imgs, return_tensors="pt").pixel_values
|
||||||
output = self.model_zoo.get('ocr').generate(
|
output = self.model_zoo.get('ocr').generate(
|
||||||
pixel_values.to(self.device),
|
pixel_values.to(self.device),
|
||||||
|
# beam search instead of greedy decoding:
|
||||||
|
num_beams=4,
|
||||||
|
# also return probability
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
return_dict_in_generate=True)
|
return_dict_in_generate=True)
|
||||||
conf = torch.max(torch.softmax(torch.cat(
|
if output.sequences_scores is not None:
|
||||||
output.scores, dim=0), dim=1), dim=1).values.tolist()
|
# log-prob averaged over length
|
||||||
|
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
|
||||||
|
else:
|
||||||
|
conf = [1.0] * len(output.sequences)
|
||||||
text = self.model_zoo.get('trocr_processor').batch_decode(
|
text = self.model_zoo.get('trocr_processor').batch_decode(
|
||||||
output.sequences,
|
output.sequences,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue