Add multiprocessing to flexible_character_accuracy

pull/47/head
Benjamin Rosemann 4 years ago
parent c4f75d5264
commit b9259b9d01

@ -13,12 +13,12 @@ Note that we deviated from the original algorithm at some places.
import sys import sys
from collections import Counter from collections import Counter
from functools import lru_cache, reduce from functools import lru_cache, reduce, partial
from itertools import product, takewhile from itertools import product, takewhile
from typing import List, Tuple, Optional from multiprocessing import cpu_count, get_context
from typing import List, Tuple, Optional, Union
from Levenshtein import editops from Levenshtein import editops
from multimethod import multimethod
from . import ExtractedText from . import ExtractedText
@ -38,57 +38,57 @@ else:
) )
@multimethod
def flexible_character_accuracy( def flexible_character_accuracy(
gt: ExtractedText, ocr: ExtractedText gt: Union[str, ExtractedText],
ocr: Union[str, ExtractedText],
n_cpu: int = cpu_count(),
) -> Tuple[float, List[Match]]: ) -> Tuple[float, List[Match]]:
"""Calculate the flexible character accuracy. """Calculate the flexible character accuracy.
Reference: contains steps 1-7 of the flexible character accuracy algorithm. Reference: contains steps 1-7 of the flexible character accuracy algorithm.
:param gt: The ground truth ExtractedText object.
:param ocr: The ExtractedText object to compare the ground truth with.
:return: Score between 0 and 1 and match objects.
"""
return flexible_character_accuracy(gt.text, ocr.text)
@multimethod
def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
"""Calculate the flexible character accuracy.
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
:param gt: The ground truth text. :param gt: The ground truth text.
:param ocr: The text to compare the ground truth with. :param ocr: The text to compare the ground truth with.
:param n_cpu: number of CPUs to use for multiprocessing.
:return: Score between 0 and 1 and match objects. :return: Score between 0 and 1 and match objects.
""" """
if isinstance(gt, ExtractedText):
gt = gt.text
if isinstance(ocr, ExtractedText):
ocr = ocr.text
best_score = -sys.maxsize best_score = -sys.maxsize
best_matches = [] best_matches = []
# TODO: should this be configurable? # TODO: should this be configurable?
combinations = product( coeffs = (
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1) Coefficients(
)
# TODO: place to parallelize the algorithm?
for (edit_dist, length_diff, offset, length) in combinations:
coef = Coefficients(
edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length
) )
for edit_dist, length_diff, offset, length in product(
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
)
)
with get_context("spawn").Pool(processes=n_cpu) as pool:
# Steps 1 - 6 of the flexible character accuracy algorithm. # Steps 1 - 6 of the flexible character accuracy algorithm.
matches = match_with_coefficients(gt, ocr, coef) # We only use multiprocessing if we have more than 2 cpus available.
# Step 7 of the flexible character accuracy algorithm. # Otherwise the overhead for creating processes and filling caches is too big.
score = character_accuracy_for_matches(matches) map_fun = partial(pool.imap_unordered, chunksize=10) if n_cpu > 2 else map
if score > best_score: for matches in map_fun(
best_score = score partial(match_with_coefficients, gt=gt, ocr=ocr), coeffs
best_matches = matches ):
# early breaking: we only need one perfect fit # Step 7 of the flexible character accuracy algorithm.
if best_score >= 1: score = character_accuracy_for_matches(matches)
break if score > best_score:
best_score = score
best_matches = matches
# early breaking: we only need one perfect fit
if best_score >= 1:
break
return best_score, best_matches return best_score, best_matches
def match_with_coefficients(gt: str, ocr: str, coef: Coefficients) -> List[Match]: def match_with_coefficients(coef: Coefficients, gt: str, ocr: str) -> List[Match]:
"""Match ground truth with ocr and consider a given set of coefficients. """Match ground truth with ocr and consider a given set of coefficients.
Reference: contains steps 1 - 6 of the flexible character accuracy algorithm. Reference: contains steps 1 - 6 of the flexible character accuracy algorithm.

@ -104,7 +104,7 @@ def extended_case_to_text(gt, ocr):
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) @pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score): def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
score, _ = flexible_character_accuracy(gt, ocr) score, _ = flexible_character_accuracy(gt, ocr, 1)
assert score == pytest.approx(all_line_score) assert score == pytest.approx(all_line_score)
@ -132,7 +132,7 @@ def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_sco
gt_text = get_extracted_text(gt) gt_text = get_extracted_text(gt)
ocr_text = get_extracted_text(ocr) ocr_text = get_extracted_text(ocr)
score, _ = flexible_character_accuracy(gt_text, ocr_text) score, _ = flexible_character_accuracy(gt_text, ocr_text, 1)
assert score == pytest.approx(all_line_score) assert score == pytest.approx(all_line_score)
@ -186,7 +186,7 @@ def test_flexible_character_accuracy(config, ocr):
) )
expected_score = character_accuracy(expected_dist) expected_score = character_accuracy(expected_dist)
result, matches = flexible_character_accuracy(gt, ocr) result, matches = flexible_character_accuracy(gt, ocr, 1)
agg = reduce( agg = reduce(
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter() lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
) )
@ -201,7 +201,7 @@ def test_flexible_character_accuracy_extended(
): ):
"""Tests from figure 4 in the 10.1016/j.patrec.2020.02.003.""" """Tests from figure 4 in the 10.1016/j.patrec.2020.02.003."""
gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr) gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence) result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence, 1)
assert result == pytest.approx(all_line_score, abs=0.001) assert result == pytest.approx(all_line_score, abs=0.001)
@ -210,7 +210,7 @@ def test_match_with_coefficients(gt, ocr, first_line_score, all_line_score):
coef = Coefficients() coef = Coefficients()
if not isinstance(gt, str): if not isinstance(gt, str):
gt, ocr = extended_case_to_text(gt, ocr) gt, ocr = extended_case_to_text(gt, ocr)
matches = match_with_coefficients(gt, ocr, coef) matches = match_with_coefficients(coef, gt, ocr)
score = character_accuracy_for_matches(matches) score = character_accuracy_for_matches(matches)
assert score == pytest.approx(all_line_score, abs=0.001) assert score == pytest.approx(all_line_score, abs=0.001)

Loading…
Cancel
Save