trocr: extract confidence, too

This commit is contained in:
Robert Sachunsky 2026-05-21 17:25:39 +02:00
parent f3649adbf2
commit 000e4ac8d8

View file

@ -90,12 +90,14 @@ class Eynollah_ocr(Eynollah):
page_ns, page_ns,
tr_ocr_input_height_and_width, tr_ocr_input_height_and_width,
) -> EynollahOcrResult: ) -> EynollahOcrResult:
import torch
total_bb_coordinates = [] total_bb_coordinates = []
cropped_lines = [] cropped_lines = []
cropped_lines_region_indexer = [] cropped_lines_region_indexer = []
cropped_lines_meging_indexing = [] cropped_lines_meging_indexing = []
extracted_texts = [] extracted_texts = []
extracted_confs = []
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)): for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)): for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
@ -142,15 +144,20 @@ class Eynollah_ocr(Eynollah):
self.logger.debug("processing %d lines for %d regions", self.logger.debug("processing %d lines for %d regions",
len(cropped_lines), len(set(cropped_lines_region_indexer))) len(cropped_lines), len(set(cropped_lines_region_indexer)))
for imgs in batched(cropped_lines, self.b_s): for imgs in batched(cropped_lines, self.b_s):
pixel_values_merged = self.model_zoo.get('trocr_processor')( pixel_values = self.model_zoo.get('trocr_processor')(
imgs, return_tensors="pt").pixel_values imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_zoo.get('ocr').generate( output = self.model_zoo.get('ocr').generate(
pixel_values_merged.to(self.device)) pixel_values.to(self.device),
generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( output_scores=True,
generated_ids_merged, return_dict_in_generate=True)
conf = torch.max(torch.softmax(torch.cat(
output.scores, dim=0), dim=1), dim=1).values.tolist()
text = self.model_zoo.get('trocr_processor').batch_decode(
output.sequences,
skip_special_tokens=True, skip_special_tokens=True,
clean_up_tokenization_spaces=False) clean_up_tokenization_spaces=False)
extracted_texts = extracted_texts + generated_text_merged extracted_confs.extend(conf)
extracted_texts.extend(text)
del cropped_lines del cropped_lines
gc.collect() gc.collect()
@ -159,10 +166,15 @@ class Eynollah_ocr(Eynollah):
else extracted_texts[ind] + " " + extracted_texts[ind + 1] else extracted_texts[ind] + " " + extracted_texts[ind + 1]
for ind in range(len(cropped_lines_meging_indexing)) for ind in range(len(cropped_lines_meging_indexing))
if cropped_lines_meging_indexing[ind] >= 0] if cropped_lines_meging_indexing[ind] >= 0]
extracted_confs_merged = [extracted_confs[ind]
if cropped_lines_meging_indexing[ind] == 0
else 0.5 * (extracted_confs[ind] + extracted_confs[ind + 1])
for ind in range(len(cropped_lines_meging_indexing))
if cropped_lines_meging_indexing[ind] >= 0]
return EynollahOcrResult( return EynollahOcrResult(
extracted_texts_merged=extracted_texts_merged, extracted_texts_merged=extracted_texts_merged,
extracted_conf_value_merged=None, extracted_conf_value_merged=extracted_confs_merged,
cropped_lines_region_indexer=cropped_lines_region_indexer, cropped_lines_region_indexer=cropped_lines_region_indexer,
total_bb_coordinates=total_bb_coordinates, total_bb_coordinates=total_bb_coordinates,
) )
@ -618,6 +630,7 @@ class Eynollah_ocr(Eynollah):
has_textline = False has_textline = False
for child_textregion in nn: for child_textregion in nn:
# FIXME: should remove Word level, if it already exists
if child_textregion.tag.endswith("TextLine"): if child_textregion.tag.endswith("TextLine"):
is_textline_text = False is_textline_text = False