mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
Eynollah_ocr: correctly handle min_conf, improve writer…
- `min_conf_value_of_textline_text`: apply by skipping lines below threshold (instead of writing empty text), and delete their TextEquiv (if existing) - `write_ocr()`: simplify, and ensure consistency between line and region level text correctly
This commit is contained in:
parent
8ffc4ed8d3
commit
d2f2a1e06b
1 changed files with 50 additions and 87 deletions
|
|
@ -41,7 +41,7 @@ from .utils.utils_ocr import (
|
|||
@dataclass
|
||||
class EynollahOcrResult:
|
||||
extracted_texts_merged: List
|
||||
extracted_confs_merged: Optional[List]
|
||||
extracted_confs_merged: List
|
||||
cropped_lines_region_indexer: List
|
||||
total_bb_coordinates:List
|
||||
|
||||
|
|
@ -156,10 +156,6 @@ class Eynollah_ocr(Eynollah):
|
|||
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
|
||||
else:
|
||||
conf = [1.0] * len(output.sequences)
|
||||
if conf < self.min_conf_value_of_textline_text:
|
||||
extracted_confs.extend(0)
|
||||
extracted_texts.extend("")
|
||||
continue
|
||||
text = self.model_zoo.get('trocr_processor').batch_decode(
|
||||
output.sequences,
|
||||
skip_special_tokens=True,
|
||||
|
|
@ -327,17 +323,13 @@ class Eynollah_ocr(Eynollah):
|
|||
def nooov(x):
|
||||
return x != b'[UNK]'
|
||||
for pred, prob in zip(preds, probs):
|
||||
if prob < self.min_conf_value_of_textline_text:
|
||||
extracted_texts.append("")
|
||||
extracted_confs.append(0)
|
||||
else:
|
||||
text = b''.join(
|
||||
filter(nooov,
|
||||
map(bytes,
|
||||
(filter(None, char)
|
||||
for char in pred.tolist())))).decode('utf-8')
|
||||
extracted_texts.append(text)
|
||||
extracted_confs.append(prob)
|
||||
text = b''.join(
|
||||
filter(nooov,
|
||||
map(bytes,
|
||||
(filter(None, char)
|
||||
for char in pred.tolist())))).decode('utf-8')
|
||||
extracted_texts.append(text)
|
||||
extracted_confs.append(prob)
|
||||
del cropped_lines_rgb
|
||||
del cropped_lines_bin
|
||||
gc.collect()
|
||||
|
|
@ -375,7 +367,6 @@ class Eynollah_ocr(Eynollah):
|
|||
extracted_texts_merged = result.extracted_texts_merged
|
||||
extracted_confs_merged = result.extracted_confs_merged
|
||||
|
||||
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
|
||||
if out_image_with_text:
|
||||
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
|
||||
draw = ImageDraw.Draw(image_text)
|
||||
|
|
@ -403,78 +394,50 @@ class Eynollah_ocr(Eynollah):
|
|||
draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font)
|
||||
image_text.save(out_image_with_text)
|
||||
|
||||
text_by_textregion = []
|
||||
for ind in unique_cropped_lines_region_indexer:
|
||||
ind = np.array(cropped_lines_region_indexer)==ind
|
||||
extracted_texts_merged_un = np.array(extracted_texts_merged)[ind]
|
||||
if len(extracted_texts_merged_un)>1:
|
||||
text_by_textregion_ind = ""
|
||||
next_glue = ""
|
||||
for indt in range(len(extracted_texts_merged_un)):
|
||||
if (extracted_texts_merged_un[indt].endswith('⸗') or
|
||||
extracted_texts_merged_un[indt].endswith('-') or
|
||||
extracted_texts_merged_un[indt].endswith('¬')):
|
||||
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1]
|
||||
next_glue = ""
|
||||
else:
|
||||
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt]
|
||||
next_glue = " "
|
||||
text_by_textregion.append(text_by_textregion_ind)
|
||||
else:
|
||||
text_by_textregion.append(" ".join(extracted_texts_merged_un))
|
||||
cropped_lines_region_indexer = np.array(cropped_lines_region_indexer)
|
||||
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
|
||||
lines_indexer = np.flatnonzero(cropped_lines_region_indexer == n_region)
|
||||
if not len(lines_indexer):
|
||||
continue
|
||||
|
||||
indexer = 0
|
||||
indexer_textregion = 0
|
||||
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
|
||||
|
||||
is_textregion_text = False
|
||||
for childtest in nn:
|
||||
if childtest.tag.endswith("TextEquiv"):
|
||||
is_textregion_text = True
|
||||
|
||||
if not is_textregion_text:
|
||||
text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
|
||||
unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')
|
||||
|
||||
|
||||
has_textline = False
|
||||
for child_textregion in nn:
|
||||
# FIXME: should remove Word level, if it already exists
|
||||
if child_textregion.tag.endswith("TextLine"):
|
||||
|
||||
is_textline_text = False
|
||||
for childtest2 in child_textregion:
|
||||
if childtest2.tag.endswith("TextEquiv"):
|
||||
is_textline_text = True
|
||||
|
||||
|
||||
if not is_textline_text:
|
||||
text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
|
||||
if extracted_confs_merged:
|
||||
text_subelement.set('conf', f"{extracted_confs_merged[indexer]:.2f}")
|
||||
unicode_textline = ET.SubElement(text_subelement, 'Unicode')
|
||||
unicode_textline.text = extracted_texts_merged[indexer]
|
||||
else:
|
||||
for childtest3 in child_textregion:
|
||||
if childtest3.tag.endswith("TextEquiv"):
|
||||
for child_uc in childtest3:
|
||||
if child_uc.tag.endswith("Unicode"):
|
||||
if extracted_confs_merged:
|
||||
childtest3.set('conf', f"{extracted_confs_merged[indexer]:.2f}")
|
||||
child_uc.text = extracted_texts_merged[indexer]
|
||||
|
||||
indexer = indexer + 1
|
||||
has_textline = True
|
||||
if has_textline:
|
||||
if is_textregion_text:
|
||||
for child4 in nn:
|
||||
if child4.tag.endswith("TextEquiv"):
|
||||
for childtr_uc in child4:
|
||||
if childtr_uc.tag.endswith("Unicode"):
|
||||
childtr_uc.text = text_by_textregion[indexer_textregion]
|
||||
text_region = ""
|
||||
next_glue = ""
|
||||
for line_idx in lines_indexer:
|
||||
if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text:
|
||||
continue
|
||||
text_line = extracted_texts_merged[line_idx]
|
||||
if (text_line.endswith(('⸗', '-', '¬')) and
|
||||
# last line of a region can still be wrapped
|
||||
# around columns or pages
|
||||
line_idx < len(lines_indexer) - 1):
|
||||
text_region += next_glue + text_line[:-1]
|
||||
next_glue = ""
|
||||
else:
|
||||
unicode_textregion.text = text_by_textregion[indexer_textregion]
|
||||
indexer_textregion = indexer_textregion + 1
|
||||
text_region += next_glue + text_line
|
||||
next_glue = " "
|
||||
|
||||
region_textequiv = region.find('{%s}TextEquiv' % page_ns)
|
||||
if region_textequiv is None:
|
||||
region_textequiv = ET.SubElement(region, 'TextEquiv')
|
||||
region_teunicode = region_textequiv.find('{%s}Unicode' % page_ns)
|
||||
if region_teunicode is None:
|
||||
region_teunicode = ET.SubElement(region_textequiv, 'Unicode')
|
||||
region_teunicode.text = text_region
|
||||
|
||||
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
|
||||
line_textequiv = line.find('{%s}TextEquiv' % page_ns)
|
||||
if line_textequiv is None:
|
||||
line_textequiv = ET.SubElement(line, 'TextEquiv')
|
||||
line_teunicode = line_textequiv.find('{%s}Unicode' % page_ns)
|
||||
if line_teunicode is None:
|
||||
line_teunicode = ET.SubElement(line_textequiv, 'Unicode')
|
||||
|
||||
line_idx = lines_indexer[n_line]
|
||||
if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text:
|
||||
line.remove(line_textequiv)
|
||||
else:
|
||||
line_textequiv.set('conf', str(round(extracted_confs_merged[line_idx], 2)))
|
||||
line_teunicode.text = extracted_texts_merged[line_idx]
|
||||
|
||||
ET.register_namespace("",page_ns)
|
||||
self.logger.info("output filename: '%s'", out_file_ocr)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue