Add multiprocessing to flexible_character_accuracy

pull/47/head
Benjamin Rosemann 4 years ago
parent c4f75d5264
commit b9259b9d01

@ -13,12 +13,12 @@ Note that we deviated from the original algorithm at some places.
import sys
from collections import Counter
from functools import lru_cache, reduce
from functools import lru_cache, reduce, partial
from itertools import product, takewhile
from typing import List, Tuple, Optional
from multiprocessing import cpu_count, get_context
from typing import List, Tuple, Optional, Union
from Levenshtein import editops
from multimethod import multimethod
from . import ExtractedText
@ -38,57 +38,57 @@ else:
)
@multimethod
def flexible_character_accuracy(
gt: ExtractedText, ocr: ExtractedText
gt: Union[str, ExtractedText],
ocr: Union[str, ExtractedText],
n_cpu: int = cpu_count(),
) -> Tuple[float, List[Match]]:
"""Calculate the flexible character accuracy.
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
:param gt: The ground truth ExtractedText object.
:param ocr: The ExtractedText object to compare the ground truth with.
:return: Score between 0 and 1 and match objects.
"""
return flexible_character_accuracy(gt.text, ocr.text)
@multimethod
def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
"""Calculate the flexible character accuracy.
Reference: contains steps 1-7 of the flexible character accuracy algorithm.
:param gt: The ground truth text.
:param ocr: The text to compare the ground truth with.
:param n_cpu: numbers of cpus to use for multiprocessing.
:return: Score between 0 and 1 and match objects.
"""
if isinstance(gt, ExtractedText):
gt = gt.text
if isinstance(ocr, ExtractedText):
ocr = ocr.text
best_score = -sys.maxsize
best_matches = []
# TODO: should this be configurable?
combinations = product(
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
)
# TODO: place to parallelize the algorithm?
for (edit_dist, length_diff, offset, length) in combinations:
coef = Coefficients(
coeffs = (
Coefficients(
edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length
)
for edit_dist, length_diff, offset, length in product(
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
)
)
with get_context("spawn").Pool(processes=n_cpu) as pool:
# Steps 1 - 6 of the flexible character accuracy algorithm.
matches = match_with_coefficients(gt, ocr, coef)
# Step 7 of the flexible character accuracy algorithm.
score = character_accuracy_for_matches(matches)
if score > best_score:
best_score = score
best_matches = matches
# early breaking: we only need one perfect fit
if best_score >= 1:
break
# We only use multiprocessing if we have more than 2 cpus available.
# Otherwise the overhead for creating processes and filling caches is too big.
map_fun = partial(pool.imap_unordered, chunksize=10) if n_cpu > 2 else map
for matches in map_fun(
partial(match_with_coefficients, gt=gt, ocr=ocr), coeffs
):
# Step 7 of the flexible character accuracy algorithm.
score = character_accuracy_for_matches(matches)
if score > best_score:
best_score = score
best_matches = matches
# early breaking: we only need one perfect fit
if best_score >= 1:
break
return best_score, best_matches
def match_with_coefficients(gt: str, ocr: str, coef: Coefficients) -> List[Match]:
def match_with_coefficients(coef: Coefficients, gt: str, ocr: str) -> List[Match]:
"""Match ground truth with ocr and consider a given set of coefficients.
Reference: contains steps 1 - 6 of the flexible character accuracy algorithm.

@ -104,7 +104,7 @@ def extended_case_to_text(gt, ocr):
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score):
score, _ = flexible_character_accuracy(gt, ocr)
score, _ = flexible_character_accuracy(gt, ocr, 1)
assert score == pytest.approx(all_line_score)
@ -132,7 +132,7 @@ def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_sco
gt_text = get_extracted_text(gt)
ocr_text = get_extracted_text(ocr)
score, _ = flexible_character_accuracy(gt_text, ocr_text)
score, _ = flexible_character_accuracy(gt_text, ocr_text, 1)
assert score == pytest.approx(all_line_score)
@ -186,7 +186,7 @@ def test_flexible_character_accuracy(config, ocr):
)
expected_score = character_accuracy(expected_dist)
result, matches = flexible_character_accuracy(gt, ocr)
result, matches = flexible_character_accuracy(gt, ocr, 1)
agg = reduce(
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
)
@ -201,7 +201,7 @@ def test_flexible_character_accuracy_extended(
):
"""Tests from figure 4 in the 10.1016/j.patrec.2020.02.003."""
gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence)
result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence, 1)
assert result == pytest.approx(all_line_score, abs=0.001)
@ -210,7 +210,7 @@ def test_match_with_coefficients(gt, ocr, first_line_score, all_line_score):
coef = Coefficients()
if not isinstance(gt, str):
gt, ocr = extended_case_to_text(gt, ocr)
matches = match_with_coefficients(gt, ocr, coef)
matches = match_with_coefficients(coef, gt, ocr)
score = character_accuracy_for_matches(matches)
assert score == pytest.approx(all_line_score, abs=0.001)

Loading…
Cancel
Save