From b9259b9d01eae9b34c7d214326a0c078a4a44e72 Mon Sep 17 00:00:00 2001 From: Benjamin Rosemann Date: Thu, 26 Nov 2020 09:58:40 +0100 Subject: [PATCH] Add multiprocessing to flexible_character_accuracy --- .../flexible_character_accuracy.py | 68 +++++++++---------- .../tests/test_flexible_character_accuracy.py | 10 +-- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/qurator/dinglehopper/flexible_character_accuracy.py b/qurator/dinglehopper/flexible_character_accuracy.py index 241ef4a..4ace63c 100644 --- a/qurator/dinglehopper/flexible_character_accuracy.py +++ b/qurator/dinglehopper/flexible_character_accuracy.py @@ -13,12 +13,12 @@ Note that we deviated from the original algorithm at some places. import sys from collections import Counter -from functools import lru_cache, reduce +from functools import lru_cache, reduce, partial from itertools import product, takewhile -from typing import List, Tuple, Optional +from multiprocessing import cpu_count, get_context +from typing import List, Tuple, Optional, Union from Levenshtein import editops -from multimethod import multimethod from . import ExtractedText @@ -38,57 +38,57 @@ else: ) -@multimethod def flexible_character_accuracy( - gt: ExtractedText, ocr: ExtractedText + gt: Union[str, ExtractedText], + ocr: Union[str, ExtractedText], + n_cpu: int = cpu_count(), ) -> Tuple[float, List[Match]]: """Calculate the flexible character accuracy. Reference: contains steps 1-7 of the flexible character accuracy algorithm. - :param gt: The ground truth ExtractedText object. - :param ocr: The ExtractedText object to compare the ground truth with. - :return: Score between 0 and 1 and match objects. - """ - return flexible_character_accuracy(gt.text, ocr.text) - - -@multimethod -def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]: - """Calculate the flexible character accuracy. - - Reference: contains steps 1-7 of the flexible character accuracy algorithm. 
- :param gt: The ground truth text. :param ocr: The text to compare the ground truth with. + :param n_cpu: number of CPUs to use for multiprocessing. :return: Score between 0 and 1 and match objects. """ + if isinstance(gt, ExtractedText): + gt = gt.text + if isinstance(ocr, ExtractedText): + ocr = ocr.text + best_score = -sys.maxsize best_matches = [] # TODO: should this be configurable? - combinations = product( - range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1) - ) - # TODO: place to parallelize the algorithm? - for (edit_dist, length_diff, offset, length) in combinations: - coef = Coefficients( + coeffs = ( + Coefficients( edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length ) + for edit_dist, length_diff, offset, length in product( + range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1) + ) + ) + with get_context("spawn").Pool(processes=n_cpu) as pool: # Steps 1 - 6 of the flexible character accuracy algorithm. - matches = match_with_coefficients(gt, ocr, coef) - # Step 7 of the flexible character accuracy algorithm. - score = character_accuracy_for_matches(matches) - if score > best_score: - best_score = score - best_matches = matches - # early breaking: we only need one perfect fit - if best_score >= 1: - break + # We only use multiprocessing if we have more than 2 cpus available. + # Otherwise the overhead for creating processes and filling caches is too big. + map_fun = partial(pool.imap_unordered, chunksize=10) if n_cpu > 2 else map + for matches in map_fun( + partial(match_with_coefficients, gt=gt, ocr=ocr), coeffs + ): + # Step 7 of the flexible character accuracy algorithm. 
+ score = character_accuracy_for_matches(matches) + if score > best_score: + best_score = score + best_matches = matches + # early breaking: we only need one perfect fit + if best_score >= 1: + break return best_score, best_matches -def match_with_coefficients(gt: str, ocr: str, coef: Coefficients) -> List[Match]: +def match_with_coefficients(coef: Coefficients, gt: str, ocr: str) -> List[Match]: """Match ground truth with ocr and consider a given set of coefficients. Reference: contains steps 1 - 6 of the flexible character accuracy algorithm. diff --git a/qurator/dinglehopper/tests/test_flexible_character_accuracy.py b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py index ad62798..6ef316b 100644 --- a/qurator/dinglehopper/tests/test_flexible_character_accuracy.py +++ b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py @@ -104,7 +104,7 @@ def extended_case_to_text(gt, ocr): @pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score): - score, _ = flexible_character_accuracy(gt, ocr) + score, _ = flexible_character_accuracy(gt, ocr, 1) assert score == pytest.approx(all_line_score) @@ -132,7 +132,7 @@ def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_sco gt_text = get_extracted_text(gt) ocr_text = get_extracted_text(ocr) - score, _ = flexible_character_accuracy(gt_text, ocr_text) + score, _ = flexible_character_accuracy(gt_text, ocr_text, 1) assert score == pytest.approx(all_line_score) @@ -186,7 +186,7 @@ def test_flexible_character_accuracy(config, ocr): ) expected_score = character_accuracy(expected_dist) - result, matches = flexible_character_accuracy(gt, ocr) + result, matches = flexible_character_accuracy(gt, ocr, 1) agg = reduce( lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter() ) @@ -201,7 +201,7 @@ def test_flexible_character_accuracy_extended( ): """Tests from figure 4 in the 
10.1016/j.patrec.2020.02.003.""" gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr) - result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence) + result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence, 1) assert result == pytest.approx(all_line_score, abs=0.001) @@ -210,7 +210,7 @@ def test_match_with_coefficients(gt, ocr, first_line_score, all_line_score): coef = Coefficients() if not isinstance(gt, str): gt, ocr = extended_case_to_text(gt, ocr) - matches = match_with_coefficients(gt, ocr, coef) + matches = match_with_coefficients(coef, gt, ocr) score = character_accuracy_for_matches(matches) assert score == pytest.approx(all_line_score, abs=0.001)