mirror of
				https://github.com/qurator-spk/dinglehopper.git
				synced 2025-10-31 09:24:15 +01:00 
			
		
		
		
	Add multiprocessing to flexible_character_accuracy
This commit is contained in:
		
							parent
							
								
									c4f75d5264
								
							
						
					
					
						commit
						b9259b9d01
					
				
					 2 changed files with 39 additions and 39 deletions
				
			
		|  | @ -13,12 +13,12 @@ Note that we deviated from the original algorithm at some places. | |||
| 
 | ||||
| import sys | ||||
| from collections import Counter | ||||
| from functools import lru_cache, reduce | ||||
| from functools import lru_cache, reduce, partial | ||||
| from itertools import product, takewhile | ||||
| from typing import List, Tuple, Optional | ||||
| from multiprocessing import cpu_count, get_context | ||||
| from typing import List, Tuple, Optional, Union | ||||
| 
 | ||||
| from Levenshtein import editops | ||||
| from multimethod import multimethod | ||||
| 
 | ||||
| from . import ExtractedText | ||||
| 
 | ||||
|  | @ -38,57 +38,57 @@ else: | |||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @multimethod | ||||
| def flexible_character_accuracy( | ||||
|     gt: ExtractedText, ocr: ExtractedText | ||||
|     gt: Union[str, ExtractedText], | ||||
|     ocr: Union[str, ExtractedText], | ||||
|     n_cpu: int = cpu_count(), | ||||
| ) -> Tuple[float, List[Match]]: | ||||
|     """Calculate the flexible character accuracy. | ||||
| 
 | ||||
|     Reference: contains steps 1-7 of the flexible character accuracy algorithm. | ||||
| 
 | ||||
|     :param gt: The ground truth ExtractedText object. | ||||
|     :param ocr: The ExtractedText object to compare the ground truth with. | ||||
|     :return: Score between 0 and 1 and match objects. | ||||
|     """ | ||||
|     return flexible_character_accuracy(gt.text, ocr.text) | ||||
| 
 | ||||
| 
 | ||||
| @multimethod | ||||
| def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]: | ||||
|     """Calculate the flexible character accuracy. | ||||
| 
 | ||||
|     Reference: contains steps 1-7 of the flexible character accuracy algorithm. | ||||
| 
 | ||||
|     :param gt: The ground truth text. | ||||
|     :param ocr: The text to compare the ground truth with. | ||||
|     :param n_cpu: number of CPUs to use for multiprocessing. | ||||
|     :return: Score between 0 and 1 and match objects. | ||||
|     """ | ||||
| 
 | ||||
|     if isinstance(gt, ExtractedText): | ||||
|         gt = gt.text | ||||
|     if isinstance(ocr, ExtractedText): | ||||
|         ocr = ocr.text | ||||
| 
 | ||||
|     best_score = -sys.maxsize | ||||
|     best_matches = [] | ||||
|     # TODO: should this be configurable? | ||||
|     combinations = product( | ||||
|         range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1) | ||||
|     ) | ||||
|     # TODO: place to parallelize the algorithm? | ||||
|     for (edit_dist, length_diff, offset, length) in combinations: | ||||
|         coef = Coefficients( | ||||
|     coeffs = ( | ||||
|         Coefficients( | ||||
|             edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length | ||||
|         ) | ||||
|         for edit_dist, length_diff, offset, length in product( | ||||
|             range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1) | ||||
|         ) | ||||
|     ) | ||||
|     with get_context("spawn").Pool(processes=n_cpu) as pool: | ||||
|         # Steps 1 - 6 of the flexible character accuracy algorithm. | ||||
|         matches = match_with_coefficients(gt, ocr, coef) | ||||
|         # Step 7 of the flexible character accuracy algorithm. | ||||
|         score = character_accuracy_for_matches(matches) | ||||
|         if score > best_score: | ||||
|             best_score = score | ||||
|             best_matches = matches | ||||
|         # early breaking: we only need one perfect fit | ||||
|         if best_score >= 1: | ||||
|             break | ||||
|         # We only use multiprocessing if we have more than 2 cpus available. | ||||
|         # Otherwise the overhead for creating processes and filling caches is too big. | ||||
|         map_fun = partial(pool.imap_unordered, chunksize=10) if n_cpu > 2 else map | ||||
|         for matches in map_fun( | ||||
|             partial(match_with_coefficients, gt=gt, ocr=ocr), coeffs | ||||
|         ): | ||||
|             # Step 7 of the flexible character accuracy algorithm. | ||||
|             score = character_accuracy_for_matches(matches) | ||||
|             if score > best_score: | ||||
|                 best_score = score | ||||
|                 best_matches = matches | ||||
|             # early breaking: we only need one perfect fit | ||||
|             if best_score >= 1: | ||||
|                 break | ||||
|     return best_score, best_matches | ||||
| 
 | ||||
| 
 | ||||
| def match_with_coefficients(gt: str, ocr: str, coef: Coefficients) -> List[Match]: | ||||
| def match_with_coefficients(coef: Coefficients, gt: str, ocr: str) -> List[Match]: | ||||
|     """Match ground truth with ocr and consider a given set of coefficients. | ||||
| 
 | ||||
|     Reference: contains steps 1 - 6 of the flexible character accuracy algorithm. | ||||
|  |  | |||
|  | @ -104,7 +104,7 @@ def extended_case_to_text(gt, ocr): | |||
| 
 | ||||
| @pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) | ||||
| def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score): | ||||
|     score, _ = flexible_character_accuracy(gt, ocr) | ||||
|     score, _ = flexible_character_accuracy(gt, ocr, 1) | ||||
|     assert score == pytest.approx(all_line_score) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -132,7 +132,7 @@ def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_sco | |||
| 
 | ||||
|     gt_text = get_extracted_text(gt) | ||||
|     ocr_text = get_extracted_text(ocr) | ||||
|     score, _ = flexible_character_accuracy(gt_text, ocr_text) | ||||
|     score, _ = flexible_character_accuracy(gt_text, ocr_text, 1) | ||||
|     assert score == pytest.approx(all_line_score) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -186,7 +186,7 @@ def test_flexible_character_accuracy(config, ocr): | |||
|     ) | ||||
|     expected_score = character_accuracy(expected_dist) | ||||
| 
 | ||||
|     result, matches = flexible_character_accuracy(gt, ocr) | ||||
|     result, matches = flexible_character_accuracy(gt, ocr, 1) | ||||
|     agg = reduce( | ||||
|         lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter() | ||||
|     ) | ||||
|  | @ -201,7 +201,7 @@ def test_flexible_character_accuracy_extended( | |||
| ): | ||||
|     """Tests from figure 4 of the paper with DOI 10.1016/j.patrec.2020.02.003.""" | ||||
|     gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr) | ||||
|     result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence) | ||||
|     result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence, 1) | ||||
|     assert result == pytest.approx(all_line_score, abs=0.001) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -210,7 +210,7 @@ def test_match_with_coefficients(gt, ocr, first_line_score, all_line_score): | |||
|     coef = Coefficients() | ||||
|     if not isinstance(gt, str): | ||||
|         gt, ocr = extended_case_to_text(gt, ocr) | ||||
|     matches = match_with_coefficients(gt, ocr, coef) | ||||
|     matches = match_with_coefficients(coef, gt, ocr) | ||||
|     score = character_accuracy_for_matches(matches) | ||||
|     assert score == pytest.approx(all_line_score, abs=0.001) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue