mirror of
https://github.com/qurator-spk/dinglehopper.git
synced 2025-06-09 11:50:00 +02:00
Add multiprocessing to flexible_character_accuracy
This commit is contained in:
parent
c4f75d5264
commit
b9259b9d01
2 changed files with 39 additions and 39 deletions
|
@ -13,12 +13,12 @@ Note that we deviated from the original algorithm at some places.
|
|||
|
||||
import sys
|
||||
from collections import Counter
|
||||
from functools import lru_cache, reduce
|
||||
from functools import lru_cache, reduce, partial
|
||||
from itertools import product, takewhile
|
||||
from typing import List, Tuple, Optional
|
||||
from multiprocessing import cpu_count, get_context
|
||||
from typing import List, Tuple, Optional, Union
|
||||
|
||||
from Levenshtein import editops
|
||||
from multimethod import multimethod
|
||||
|
||||
from . import ExtractedText
|
||||
|
||||
|
@ -38,57 +38,57 @@ else:
|
|||
)
|
||||
|
||||
|
||||
@multimethod
|
||||
def flexible_character_accuracy(
|
||||
gt: ExtractedText, ocr: ExtractedText
|
||||
gt: Union[str, ExtractedText],
|
||||
ocr: Union[str, ExtractedText],
|
||||
n_cpu: int = cpu_count(),
|
||||
) -> Tuple[float, List[Match]]:
|
||||
"""Calculate the flexible character accuracy.
|
||||
|
||||
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
|
||||
|
||||
:param gt: The ground truth ExtractedText object.
|
||||
:param ocr: The ExtractedText object to compare the ground truth with.
|
||||
:return: Score between 0 and 1 and match objects.
|
||||
"""
|
||||
return flexible_character_accuracy(gt.text, ocr.text)
|
||||
|
||||
|
||||
@multimethod
|
||||
def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
|
||||
"""Calculate the flexible character accuracy.
|
||||
|
||||
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
|
||||
|
||||
:param gt: The ground truth text.
|
||||
:param ocr: The text to compare the ground truth with.
|
||||
:param n_cpu: numbers of cpus to use for multiprocessing.
|
||||
:return: Score between 0 and 1 and match objects.
|
||||
"""
|
||||
|
||||
if isinstance(gt, ExtractedText):
|
||||
gt = gt.text
|
||||
if isinstance(ocr, ExtractedText):
|
||||
ocr = ocr.text
|
||||
|
||||
best_score = -sys.maxsize
|
||||
best_matches = []
|
||||
# TODO: should this be configurable?
|
||||
combinations = product(
|
||||
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
|
||||
)
|
||||
# TODO: place to parallelize the algorithm?
|
||||
for (edit_dist, length_diff, offset, length) in combinations:
|
||||
coef = Coefficients(
|
||||
coeffs = (
|
||||
Coefficients(
|
||||
edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length
|
||||
)
|
||||
for edit_dist, length_diff, offset, length in product(
|
||||
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
|
||||
)
|
||||
)
|
||||
with get_context("spawn").Pool(processes=n_cpu) as pool:
|
||||
# Steps 1 - 6 of the flexible character accuracy algorithm.
|
||||
matches = match_with_coefficients(gt, ocr, coef)
|
||||
# Step 7 of the flexible character accuracy algorithm.
|
||||
score = character_accuracy_for_matches(matches)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_matches = matches
|
||||
# early breaking: we only need one perfect fit
|
||||
if best_score >= 1:
|
||||
break
|
||||
# We only use multiprocessing if we have more than 2 cpus available.
|
||||
# Otherwise the overhead for creating processes and filling caches is too big.
|
||||
map_fun = partial(pool.imap_unordered, chunksize=10) if n_cpu > 2 else map
|
||||
for matches in map_fun(
|
||||
partial(match_with_coefficients, gt=gt, ocr=ocr), coeffs
|
||||
):
|
||||
# Step 7 of the flexible character accuracy algorithm.
|
||||
score = character_accuracy_for_matches(matches)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_matches = matches
|
||||
# early breaking: we only need one perfect fit
|
||||
if best_score >= 1:
|
||||
break
|
||||
return best_score, best_matches
|
||||
|
||||
|
||||
def match_with_coefficients(gt: str, ocr: str, coef: Coefficients) -> List[Match]:
|
||||
def match_with_coefficients(coef: Coefficients, gt: str, ocr: str) -> List[Match]:
|
||||
"""Match ground truth with ocr and consider a given set of coefficients.
|
||||
|
||||
Reference: contains steps 1 - 6 of the flexible character accuracy algorithm.
|
||||
|
|
|
@ -104,7 +104,7 @@ def extended_case_to_text(gt, ocr):
|
|||
|
||||
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
|
||||
def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
|
||||
score, _ = flexible_character_accuracy(gt, ocr)
|
||||
score, _ = flexible_character_accuracy(gt, ocr, 1)
|
||||
assert score == pytest.approx(all_line_score)
|
||||
|
||||
|
||||
|
@ -132,7 +132,7 @@ def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_sco
|
|||
|
||||
gt_text = get_extracted_text(gt)
|
||||
ocr_text = get_extracted_text(ocr)
|
||||
score, _ = flexible_character_accuracy(gt_text, ocr_text)
|
||||
score, _ = flexible_character_accuracy(gt_text, ocr_text, 1)
|
||||
assert score == pytest.approx(all_line_score)
|
||||
|
||||
|
||||
|
@ -186,7 +186,7 @@ def test_flexible_character_accuracy(config, ocr):
|
|||
)
|
||||
expected_score = character_accuracy(expected_dist)
|
||||
|
||||
result, matches = flexible_character_accuracy(gt, ocr)
|
||||
result, matches = flexible_character_accuracy(gt, ocr, 1)
|
||||
agg = reduce(
|
||||
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
|
||||
)
|
||||
|
@ -201,7 +201,7 @@ def test_flexible_character_accuracy_extended(
|
|||
):
|
||||
"""Tests from figure 4 in the 10.1016/j.patrec.2020.02.003."""
|
||||
gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
|
||||
result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence)
|
||||
result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence, 1)
|
||||
assert result == pytest.approx(all_line_score, abs=0.001)
|
||||
|
||||
|
||||
|
@ -210,7 +210,7 @@ def test_match_with_coefficients(gt, ocr, first_line_score, all_line_score):
|
|||
coef = Coefficients()
|
||||
if not isinstance(gt, str):
|
||||
gt, ocr = extended_case_to_text(gt, ocr)
|
||||
matches = match_with_coefficients(gt, ocr, coef)
|
||||
matches = match_with_coefficients(coef, gt, ocr)
|
||||
score = character_accuracy_for_matches(matches)
|
||||
assert score == pytest.approx(all_line_score, abs=0.001)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue