mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-26 07:39:22 +02:00
trocr: extract confidence, too
This commit is contained in:
parent
f3649adbf2
commit
000e4ac8d8
1 changed files with 20 additions and 7 deletions
|
|
@ -90,12 +90,14 @@ class Eynollah_ocr(Eynollah):
|
|||
page_ns,
|
||||
tr_ocr_input_height_and_width,
|
||||
) -> EynollahOcrResult:
|
||||
import torch
|
||||
|
||||
total_bb_coordinates = []
|
||||
cropped_lines = []
|
||||
cropped_lines_region_indexer = []
|
||||
cropped_lines_meging_indexing = []
|
||||
extracted_texts = []
|
||||
extracted_confs = []
|
||||
|
||||
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
|
||||
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
|
||||
|
|
@ -142,15 +144,20 @@ class Eynollah_ocr(Eynollah):
|
|||
self.logger.debug("processing %d lines for %d regions",
|
||||
len(cropped_lines), len(set(cropped_lines_region_indexer)))
|
||||
for imgs in batched(cropped_lines, self.b_s):
|
||||
pixel_values_merged = self.model_zoo.get('trocr_processor')(
|
||||
pixel_values = self.model_zoo.get('trocr_processor')(
|
||||
imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_zoo.get('ocr').generate(
|
||||
pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
generated_ids_merged,
|
||||
output = self.model_zoo.get('ocr').generate(
|
||||
pixel_values.to(self.device),
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True)
|
||||
conf = torch.max(torch.softmax(torch.cat(
|
||||
output.scores, dim=0), dim=1), dim=1).values.tolist()
|
||||
text = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
output.sequences,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False)
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
extracted_confs.extend(conf)
|
||||
extracted_texts.extend(text)
|
||||
del cropped_lines
|
||||
gc.collect()
|
||||
|
||||
|
|
@ -159,10 +166,15 @@ class Eynollah_ocr(Eynollah):
|
|||
else extracted_texts[ind] + " " + extracted_texts[ind + 1]
|
||||
for ind in range(len(cropped_lines_meging_indexing))
|
||||
if cropped_lines_meging_indexing[ind] >= 0]
|
||||
extracted_confs_merged = [extracted_confs[ind]
|
||||
if cropped_lines_meging_indexing[ind] == 0
|
||||
else 0.5 * (extracted_confs[ind] + extracted_confs[ind + 1])
|
||||
for ind in range(len(cropped_lines_meging_indexing))
|
||||
if cropped_lines_meging_indexing[ind] >= 0]
|
||||
|
||||
return EynollahOcrResult(
|
||||
extracted_texts_merged=extracted_texts_merged,
|
||||
extracted_conf_value_merged=None,
|
||||
extracted_conf_value_merged=extracted_confs_merged,
|
||||
cropped_lines_region_indexer=cropped_lines_region_indexer,
|
||||
total_bb_coordinates=total_bb_coordinates,
|
||||
)
|
||||
|
|
@ -618,6 +630,7 @@ class Eynollah_ocr(Eynollah):
|
|||
|
||||
has_textline = False
|
||||
for child_textregion in nn:
|
||||
# FIXME: should remove Word level, if it already exists
|
||||
if child_textregion.tag.endswith("TextLine"):
|
||||
|
||||
is_textline_text = False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue