mirror of
https://github.com/qurator-spk/dinglehopper.git
synced 2025-06-08 11:20:26 +02:00
Add multiprocessing to flexible_character_accuracy
This commit is contained in:
parent
c4f75d5264
commit
b9259b9d01
2 changed files with 39 additions and 39 deletions
|
@ -13,12 +13,12 @@ Note that we deviated from the original algorithm at some places.
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from functools import lru_cache, reduce
|
from functools import lru_cache, reduce, partial
|
||||||
from itertools import product, takewhile
|
from itertools import product, takewhile
|
||||||
from typing import List, Tuple, Optional
|
from multiprocessing import cpu_count, get_context
|
||||||
|
from typing import List, Tuple, Optional, Union
|
||||||
|
|
||||||
from Levenshtein import editops
|
from Levenshtein import editops
|
||||||
from multimethod import multimethod
|
|
||||||
|
|
||||||
from . import ExtractedText
|
from . import ExtractedText
|
||||||
|
|
||||||
|
@ -38,45 +38,45 @@ else:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@multimethod
|
|
||||||
def flexible_character_accuracy(
|
def flexible_character_accuracy(
|
||||||
gt: ExtractedText, ocr: ExtractedText
|
gt: Union[str, ExtractedText],
|
||||||
|
ocr: Union[str, ExtractedText],
|
||||||
|
n_cpu: int = cpu_count(),
|
||||||
) -> Tuple[float, List[Match]]:
|
) -> Tuple[float, List[Match]]:
|
||||||
"""Calculate the flexible character accuracy.
|
"""Calculate the flexible character accuracy.
|
||||||
|
|
||||||
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
|
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
|
||||||
|
|
||||||
:param gt: The ground truth ExtractedText object.
|
|
||||||
:param ocr: The ExtractedText object to compare the ground truth with.
|
|
||||||
:return: Score between 0 and 1 and match objects.
|
|
||||||
"""
|
|
||||||
return flexible_character_accuracy(gt.text, ocr.text)
|
|
||||||
|
|
||||||
|
|
||||||
@multimethod
|
|
||||||
def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
|
|
||||||
"""Calculate the flexible character accuracy.
|
|
||||||
|
|
||||||
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
|
|
||||||
|
|
||||||
:param gt: The ground truth text.
|
:param gt: The ground truth text.
|
||||||
:param ocr: The text to compare the ground truth with.
|
:param ocr: The text to compare the ground truth with.
|
||||||
|
:param n_cpu: The number of CPUs to use for multiprocessing.
|
||||||
:return: Score between 0 and 1 and match objects.
|
:return: Score between 0 and 1 and match objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(gt, ExtractedText):
|
||||||
|
gt = gt.text
|
||||||
|
if isinstance(ocr, ExtractedText):
|
||||||
|
ocr = ocr.text
|
||||||
|
|
||||||
best_score = -sys.maxsize
|
best_score = -sys.maxsize
|
||||||
best_matches = []
|
best_matches = []
|
||||||
# TODO: should this be configurable?
|
# TODO: should this be configurable?
|
||||||
combinations = product(
|
coeffs = (
|
||||||
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
|
Coefficients(
|
||||||
)
|
|
||||||
# TODO: place to parallelize the algorithm?
|
|
||||||
for (edit_dist, length_diff, offset, length) in combinations:
|
|
||||||
coef = Coefficients(
|
|
||||||
edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length
|
edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length
|
||||||
)
|
)
|
||||||
|
for edit_dist, length_diff, offset, length in product(
|
||||||
|
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
with get_context("spawn").Pool(processes=n_cpu) as pool:
|
||||||
# Steps 1 - 6 of the flexible character accuracy algorithm.
|
# Steps 1 - 6 of the flexible character accuracy algorithm.
|
||||||
matches = match_with_coefficients(gt, ocr, coef)
|
# We only use multiprocessing if we have more than 2 cpus available.
|
||||||
|
# Otherwise the overhead for creating processes and filling caches is too big.
|
||||||
|
map_fun = partial(pool.imap_unordered, chunksize=10) if n_cpu > 2 else map
|
||||||
|
for matches in map_fun(
|
||||||
|
partial(match_with_coefficients, gt=gt, ocr=ocr), coeffs
|
||||||
|
):
|
||||||
# Step 7 of the flexible character accuracy algorithm.
|
# Step 7 of the flexible character accuracy algorithm.
|
||||||
score = character_accuracy_for_matches(matches)
|
score = character_accuracy_for_matches(matches)
|
||||||
if score > best_score:
|
if score > best_score:
|
||||||
|
@ -88,7 +88,7 @@ def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
|
||||||
return best_score, best_matches
|
return best_score, best_matches
|
||||||
|
|
||||||
|
|
||||||
def match_with_coefficients(gt: str, ocr: str, coef: Coefficients) -> List[Match]:
|
def match_with_coefficients(coef: Coefficients, gt: str, ocr: str) -> List[Match]:
|
||||||
"""Match ground truth with ocr and consider a given set of coefficients.
|
"""Match ground truth with ocr and consider a given set of coefficients.
|
||||||
|
|
||||||
Reference: contains steps 1 - 6 of the flexible character accuracy algorithm.
|
Reference: contains steps 1 - 6 of the flexible character accuracy algorithm.
|
||||||
|
|
|
@ -104,7 +104,7 @@ def extended_case_to_text(gt, ocr):
|
||||||
|
|
||||||
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
|
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
|
||||||
def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
|
def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
|
||||||
score, _ = flexible_character_accuracy(gt, ocr)
|
score, _ = flexible_character_accuracy(gt, ocr, 1)
|
||||||
assert score == pytest.approx(all_line_score)
|
assert score == pytest.approx(all_line_score)
|
||||||
|
|
||||||
|
|
||||||
|
@ -132,7 +132,7 @@ def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_sco
|
||||||
|
|
||||||
gt_text = get_extracted_text(gt)
|
gt_text = get_extracted_text(gt)
|
||||||
ocr_text = get_extracted_text(ocr)
|
ocr_text = get_extracted_text(ocr)
|
||||||
score, _ = flexible_character_accuracy(gt_text, ocr_text)
|
score, _ = flexible_character_accuracy(gt_text, ocr_text, 1)
|
||||||
assert score == pytest.approx(all_line_score)
|
assert score == pytest.approx(all_line_score)
|
||||||
|
|
||||||
|
|
||||||
|
@ -186,7 +186,7 @@ def test_flexible_character_accuracy(config, ocr):
|
||||||
)
|
)
|
||||||
expected_score = character_accuracy(expected_dist)
|
expected_score = character_accuracy(expected_dist)
|
||||||
|
|
||||||
result, matches = flexible_character_accuracy(gt, ocr)
|
result, matches = flexible_character_accuracy(gt, ocr, 1)
|
||||||
agg = reduce(
|
agg = reduce(
|
||||||
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
|
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
|
||||||
)
|
)
|
||||||
|
@ -201,7 +201,7 @@ def test_flexible_character_accuracy_extended(
|
||||||
):
|
):
|
||||||
"""Tests from figure 4 in the 10.1016/j.patrec.2020.02.003."""
|
"""Tests from figure 4 in the 10.1016/j.patrec.2020.02.003."""
|
||||||
gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
|
gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
|
||||||
result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence)
|
result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence, 1)
|
||||||
assert result == pytest.approx(all_line_score, abs=0.001)
|
assert result == pytest.approx(all_line_score, abs=0.001)
|
||||||
|
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ def test_match_with_coefficients(gt, ocr, first_line_score, all_line_score):
|
||||||
coef = Coefficients()
|
coef = Coefficients()
|
||||||
if not isinstance(gt, str):
|
if not isinstance(gt, str):
|
||||||
gt, ocr = extended_case_to_text(gt, ocr)
|
gt, ocr = extended_case_to_text(gt, ocr)
|
||||||
matches = match_with_coefficients(gt, ocr, coef)
|
matches = match_with_coefficients(coef, gt, ocr)
|
||||||
score = character_accuracy_for_matches(matches)
|
score = character_accuracy_for_matches(matches)
|
||||||
assert score == pytest.approx(all_line_score, abs=0.001)
|
assert score == pytest.approx(all_line_score, abs=0.001)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue