mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
Eynollah_ocr: correctly handle min_conf, improve writer…
- `min_conf_value_of_textline_text`: apply by skipping lines below threshold (instead of writing empty text), and delete their TextEquiv (if existing) - `write_ocr()`: simplify, and ensure consistency between line and region level text correctly
This commit is contained in:
parent
8ffc4ed8d3
commit
d2f2a1e06b
1 changed files with 50 additions and 87 deletions
|
|
@ -41,7 +41,7 @@ from .utils.utils_ocr import (
|
||||||
@dataclass
|
@dataclass
|
||||||
class EynollahOcrResult:
|
class EynollahOcrResult:
|
||||||
extracted_texts_merged: List
|
extracted_texts_merged: List
|
||||||
extracted_confs_merged: Optional[List]
|
extracted_confs_merged: List
|
||||||
cropped_lines_region_indexer: List
|
cropped_lines_region_indexer: List
|
||||||
total_bb_coordinates:List
|
total_bb_coordinates:List
|
||||||
|
|
||||||
|
|
@ -156,10 +156,6 @@ class Eynollah_ocr(Eynollah):
|
||||||
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
|
conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist()
|
||||||
else:
|
else:
|
||||||
conf = [1.0] * len(output.sequences)
|
conf = [1.0] * len(output.sequences)
|
||||||
if conf < self.min_conf_value_of_textline_text:
|
|
||||||
extracted_confs.extend(0)
|
|
||||||
extracted_texts.extend("")
|
|
||||||
continue
|
|
||||||
text = self.model_zoo.get('trocr_processor').batch_decode(
|
text = self.model_zoo.get('trocr_processor').batch_decode(
|
||||||
output.sequences,
|
output.sequences,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
|
|
@ -327,10 +323,6 @@ class Eynollah_ocr(Eynollah):
|
||||||
def nooov(x):
|
def nooov(x):
|
||||||
return x != b'[UNK]'
|
return x != b'[UNK]'
|
||||||
for pred, prob in zip(preds, probs):
|
for pred, prob in zip(preds, probs):
|
||||||
if prob < self.min_conf_value_of_textline_text:
|
|
||||||
extracted_texts.append("")
|
|
||||||
extracted_confs.append(0)
|
|
||||||
else:
|
|
||||||
text = b''.join(
|
text = b''.join(
|
||||||
filter(nooov,
|
filter(nooov,
|
||||||
map(bytes,
|
map(bytes,
|
||||||
|
|
@ -375,7 +367,6 @@ class Eynollah_ocr(Eynollah):
|
||||||
extracted_texts_merged = result.extracted_texts_merged
|
extracted_texts_merged = result.extracted_texts_merged
|
||||||
extracted_confs_merged = result.extracted_confs_merged
|
extracted_confs_merged = result.extracted_confs_merged
|
||||||
|
|
||||||
unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
|
|
||||||
if out_image_with_text:
|
if out_image_with_text:
|
||||||
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
|
image_text = Image.new("RGB", (img.shape[1], img.shape[0]), "white")
|
||||||
draw = ImageDraw.Draw(image_text)
|
draw = ImageDraw.Draw(image_text)
|
||||||
|
|
@ -403,78 +394,50 @@ class Eynollah_ocr(Eynollah):
|
||||||
draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font)
|
draw.text((text_x, text_y), extracted_texts_merged[indexer_text], fill="black", font=font)
|
||||||
image_text.save(out_image_with_text)
|
image_text.save(out_image_with_text)
|
||||||
|
|
||||||
text_by_textregion = []
|
cropped_lines_region_indexer = np.array(cropped_lines_region_indexer)
|
||||||
for ind in unique_cropped_lines_region_indexer:
|
for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)):
|
||||||
ind = np.array(cropped_lines_region_indexer)==ind
|
lines_indexer = np.flatnonzero(cropped_lines_region_indexer == n_region)
|
||||||
extracted_texts_merged_un = np.array(extracted_texts_merged)[ind]
|
if not len(lines_indexer):
|
||||||
if len(extracted_texts_merged_un)>1:
|
continue
|
||||||
text_by_textregion_ind = ""
|
|
||||||
|
text_region = ""
|
||||||
next_glue = ""
|
next_glue = ""
|
||||||
for indt in range(len(extracted_texts_merged_un)):
|
for line_idx in lines_indexer:
|
||||||
if (extracted_texts_merged_un[indt].endswith('⸗') or
|
if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text:
|
||||||
extracted_texts_merged_un[indt].endswith('-') or
|
continue
|
||||||
extracted_texts_merged_un[indt].endswith('¬')):
|
text_line = extracted_texts_merged[line_idx]
|
||||||
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt][:-1]
|
if (text_line.endswith(('⸗', '-', '¬')) and
|
||||||
|
# last line of a region can still be wrapped
|
||||||
|
# around columns or pages
|
||||||
|
line_idx < len(lines_indexer) - 1):
|
||||||
|
text_region += next_glue + text_line[:-1]
|
||||||
next_glue = ""
|
next_glue = ""
|
||||||
else:
|
else:
|
||||||
text_by_textregion_ind += next_glue + extracted_texts_merged_un[indt]
|
text_region += next_glue + text_line
|
||||||
next_glue = " "
|
next_glue = " "
|
||||||
text_by_textregion.append(text_by_textregion_ind)
|
|
||||||
|
region_textequiv = region.find('{%s}TextEquiv' % page_ns)
|
||||||
|
if region_textequiv is None:
|
||||||
|
region_textequiv = ET.SubElement(region, 'TextEquiv')
|
||||||
|
region_teunicode = region_textequiv.find('{%s}Unicode' % page_ns)
|
||||||
|
if region_teunicode is None:
|
||||||
|
region_teunicode = ET.SubElement(region_textequiv, 'Unicode')
|
||||||
|
region_teunicode.text = text_region
|
||||||
|
|
||||||
|
for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)):
|
||||||
|
line_textequiv = line.find('{%s}TextEquiv' % page_ns)
|
||||||
|
if line_textequiv is None:
|
||||||
|
line_textequiv = ET.SubElement(line, 'TextEquiv')
|
||||||
|
line_teunicode = line_textequiv.find('{%s}Unicode' % page_ns)
|
||||||
|
if line_teunicode is None:
|
||||||
|
line_teunicode = ET.SubElement(line_textequiv, 'Unicode')
|
||||||
|
|
||||||
|
line_idx = lines_indexer[n_line]
|
||||||
|
if extracted_confs_merged[line_idx] < self.min_conf_value_of_textline_text:
|
||||||
|
line.remove(line_textequiv)
|
||||||
else:
|
else:
|
||||||
text_by_textregion.append(" ".join(extracted_texts_merged_un))
|
line_textequiv.set('conf', str(round(extracted_confs_merged[line_idx], 2)))
|
||||||
|
line_teunicode.text = extracted_texts_merged[line_idx]
|
||||||
indexer = 0
|
|
||||||
indexer_textregion = 0
|
|
||||||
for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'):
|
|
||||||
|
|
||||||
is_textregion_text = False
|
|
||||||
for childtest in nn:
|
|
||||||
if childtest.tag.endswith("TextEquiv"):
|
|
||||||
is_textregion_text = True
|
|
||||||
|
|
||||||
if not is_textregion_text:
|
|
||||||
text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
|
|
||||||
unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')
|
|
||||||
|
|
||||||
|
|
||||||
has_textline = False
|
|
||||||
for child_textregion in nn:
|
|
||||||
# FIXME: should remove Word level, if it already exists
|
|
||||||
if child_textregion.tag.endswith("TextLine"):
|
|
||||||
|
|
||||||
is_textline_text = False
|
|
||||||
for childtest2 in child_textregion:
|
|
||||||
if childtest2.tag.endswith("TextEquiv"):
|
|
||||||
is_textline_text = True
|
|
||||||
|
|
||||||
|
|
||||||
if not is_textline_text:
|
|
||||||
text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
|
|
||||||
if extracted_confs_merged:
|
|
||||||
text_subelement.set('conf', f"{extracted_confs_merged[indexer]:.2f}")
|
|
||||||
unicode_textline = ET.SubElement(text_subelement, 'Unicode')
|
|
||||||
unicode_textline.text = extracted_texts_merged[indexer]
|
|
||||||
else:
|
|
||||||
for childtest3 in child_textregion:
|
|
||||||
if childtest3.tag.endswith("TextEquiv"):
|
|
||||||
for child_uc in childtest3:
|
|
||||||
if child_uc.tag.endswith("Unicode"):
|
|
||||||
if extracted_confs_merged:
|
|
||||||
childtest3.set('conf', f"{extracted_confs_merged[indexer]:.2f}")
|
|
||||||
child_uc.text = extracted_texts_merged[indexer]
|
|
||||||
|
|
||||||
indexer = indexer + 1
|
|
||||||
has_textline = True
|
|
||||||
if has_textline:
|
|
||||||
if is_textregion_text:
|
|
||||||
for child4 in nn:
|
|
||||||
if child4.tag.endswith("TextEquiv"):
|
|
||||||
for childtr_uc in child4:
|
|
||||||
if childtr_uc.tag.endswith("Unicode"):
|
|
||||||
childtr_uc.text = text_by_textregion[indexer_textregion]
|
|
||||||
else:
|
|
||||||
unicode_textregion.text = text_by_textregion[indexer_textregion]
|
|
||||||
indexer_textregion = indexer_textregion + 1
|
|
||||||
|
|
||||||
ET.register_namespace("",page_ns)
|
ET.register_namespace("",page_ns)
|
||||||
self.logger.info("output filename: '%s'", out_file_ocr)
|
self.logger.info("output filename: '%s'", out_file_ocr)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue