Mirror of https://github.com/qurator-spk/dinglehopper.git (synced 2025-10-31 17:34:15 +01:00)
Reformat using black

parent 5277593bdb
commit 2a215a1062

2 changed files with 273 additions and 171 deletions
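
Both files were reformatted mechanically, with no intended behaviour change. As a minimal, illustrative sketch (not part of this commit; it assumes only black's documented Python entry points black.format_str and black.FileMode), the first change in the diff can be reproduced programmatically:

    # Illustrative only: black's string-quote normalisation applied to the
    # first changed line of the diff. Assumes the black package is installed.
    import black

    source = "best_score = -float('inf')\n"
    formatted = black.format_str(source, mode=black.FileMode())
    print(formatted)  # best_score = -float("inf")

In practice such a commit is typically produced by running the black command-line tool over the affected files; the API call above is just the programmatic equivalent.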
				
			
		|  | @ -29,20 +29,16 @@ def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"] | |||
|     :return: Score between 0 and 1 and match objects. | ||||
|     """ | ||||
| 
 | ||||
|     best_score = -float('inf') | ||||
|     best_score = -float("inf") | ||||
|     best_matches = [] | ||||
|     # TODO: this should be configurable | ||||
|     combinations = product(range(15, 31, 5), | ||||
|                            range(0, 24, 3), | ||||
|                            range(0, 4, 1), | ||||
|                            range(0, 6, 1)) | ||||
|     combinations = product( | ||||
|         range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1) | ||||
|     ) | ||||
|     # TODO: place to parallelize the algorithm | ||||
|     for (edit_dist, length_diff, offset, length) in combinations: | ||||
|         coef = Coefficients( | ||||
|             edit_dist=edit_dist, | ||||
|             length_diff=length_diff, | ||||
|             offset=offset, | ||||
|             length=length | ||||
|             edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length | ||||
|         ) | ||||
|         # Steps 1 - 6 of the flexible character accuracy algorithm. | ||||
|         matches = match_with_coefficients(gt, ocr, coef) | ||||
|  | @ -79,17 +75,21 @@ def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Ma | |||
| 
 | ||||
|     # Step 6 of the flexible character accuracy algorithm. | ||||
|     # remaining lines are considered as deletes and inserts | ||||
|     deletes = [distance(line, Part(text="", line=line.line, start=line.start)) | ||||
|                for line in gt_lines] | ||||
|     inserts = [distance(Part(text="", line=line.line, start=line.start), line) | ||||
|                for line in ocr_lines] | ||||
|     deletes = [ | ||||
|         distance(line, Part(text="", line=line.line, start=line.start)) | ||||
|         for line in gt_lines | ||||
|     ] | ||||
|     inserts = [ | ||||
|         distance(Part(text="", line=line.line, start=line.start), line) | ||||
|         for line in ocr_lines | ||||
|     ] | ||||
| 
 | ||||
|     return [*matches, *deletes, *inserts] | ||||
| 
 | ||||
| 
 | ||||
| def match_longest_gt_lines(gt_lines: List["Part"], | ||||
|                            ocr_lines: List["Part"], | ||||
|                            coef: "Coefficients") -> Optional["Match"]: | ||||
| def match_longest_gt_lines( | ||||
|     gt_lines: List["Part"], ocr_lines: List["Part"], coef: "Coefficients" | ||||
| ) -> Optional["Match"]: | ||||
|     """Find the best match for the longest line(s) in ground truth. | ||||
| 
 | ||||
|     The longest lines in ground truth are matched against lines in ocr to find the | ||||
|  | @ -99,7 +99,7 @@ def match_longest_gt_lines(gt_lines: List["Part"], | |||
| 
 | ||||
|     :return: Possible match object. | ||||
|     """ | ||||
|     best_score, best_match, best_gt, best_ocr = -float('inf'), None, None, None | ||||
|     best_score, best_match, best_gt, best_ocr = -float("inf"), None, None, None | ||||
|     if not ocr_lines: | ||||
|         return best_match | ||||
| 
 | ||||
|  | @ -126,10 +126,9 @@ def match_longest_gt_lines(gt_lines: List["Part"], | |||
|     return best_match | ||||
| 
 | ||||
| 
 | ||||
| def match_gt_line(gt_line: "Part", | ||||
|                   ocr_lines: List["Part"], | ||||
|                   coef: "Coefficients") -> Tuple[Optional["Match"], | ||||
|                                                  Optional["Part"]]: | ||||
| def match_gt_line( | ||||
|     gt_line: "Part", ocr_lines: List["Part"], coef: "Coefficients" | ||||
| ) -> Tuple[Optional["Match"], Optional["Part"]]: | ||||
|     """Match the given ground truth line against all the lines in ocr. | ||||
| 
 | ||||
|     Reference: contains steps 3 of the flexible character accuracy algorithm. | ||||
|  | @ -138,19 +137,18 @@ def match_gt_line(gt_line: "Part", | |||
| 
 | ||||
|     :return: Match object and the matched ocr line. | ||||
|     """ | ||||
|     min_penalty = float('inf') | ||||
|     min_penalty = float("inf") | ||||
|     best_match, best_ocr = None, None | ||||
|     for ocr_line in [*ocr_lines]: | ||||
|         match = match_lines(gt_line, ocr_line) | ||||
|         penalty = calculate_penalty(gt_line, ocr_line, match, coef) | ||||
|         if penalty < min_penalty: | ||||
|             min_penalty, best_match, best_ocr = penalty, match, ocr_line | ||||
|         if match: | ||||
|             penalty = calculate_penalty(gt_line, ocr_line, match, coef) | ||||
|             if penalty < min_penalty: | ||||
|                 min_penalty, best_match, best_ocr = penalty, match, ocr_line | ||||
|     return best_match, best_ocr | ||||
| 
 | ||||
| 
 | ||||
| def remove_or_split(original: "Part", | ||||
|                     match: "Part", | ||||
|                     lines: List["Part"]) -> bool: | ||||
| def remove_or_split(original: "Part", match: "Part", lines: List["Part"]) -> bool: | ||||
|     """Removes the matched line or splits it into parts. | ||||
| 
 | ||||
|     Reference: contains step 4 of the flexible character accuracy algorithm. | ||||
|  | @ -187,17 +185,24 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]: | |||
|     if min_length == 0: | ||||
|         return best_match | ||||
|     length_diff = gt_line.length - ocr_line.length | ||||
|     min_edit_dist = float('inf') | ||||
|     min_edit_dist = float("inf") | ||||
| 
 | ||||
|     gt_parts = [(i, gt_line.substring(rel_start=i, rel_end=i + min_length)) | ||||
|                 for i in range(0, max(1, length_diff + 1))] | ||||
|     ocr_parts = [(j, ocr_line.substring(rel_start=j, rel_end=j + min_length)) | ||||
|                  for j in range(0, max(1, -1 * length_diff + 1))] | ||||
|     gt_parts = [ | ||||
|         (i, gt_line.substring(rel_start=i, rel_end=i + min_length)) | ||||
|         for i in range(0, max(1, length_diff + 1)) | ||||
|     ] | ||||
|     ocr_parts = [ | ||||
|         (j, ocr_line.substring(rel_start=j, rel_end=j + min_length)) | ||||
|         for j in range(0, max(1, -1 * length_diff + 1)) | ||||
|     ] | ||||
| 
 | ||||
|     # add full line and empty line match | ||||
|     gt_parts = [*gt_parts, (0, gt_line), (0, gt_line)] | ||||
|     ocr_parts = [*ocr_parts, (0, ocr_line), | ||||
|                  (0, Part(text="", line=gt_line.line, start=gt_line.start))] | ||||
|     ocr_parts = [ | ||||
|         *ocr_parts, | ||||
|         (0, ocr_line), | ||||
|         (0, Part(text="", line=gt_line.line, start=gt_line.start)), | ||||
|     ] | ||||
| 
 | ||||
|     for i, gt_part in gt_parts: | ||||
|         for j, ocr_part in ocr_parts: | ||||
|  | @ -211,8 +216,10 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]: | |||
|         part_length = best_match.gt.length | ||||
|         additional_length = best_match.dist.delete + best_match.dist.replace | ||||
|         for k in range(part_length + 1, part_length + additional_length + 1): | ||||
|             match = distance(gt_line.substring(rel_start=best_i, rel_end=best_i + k), | ||||
|                              ocr_line.substring(rel_start=best_j, rel_end=best_j + k)) | ||||
|             match = distance( | ||||
|                 gt_line.substring(rel_start=best_i, rel_end=best_i + k), | ||||
|                 ocr_line.substring(rel_start=best_j, rel_end=best_j + k), | ||||
|             ) | ||||
|             edit_dist = score_edit_distance(match) | ||||
|             if edit_dist < min_edit_dist: | ||||
|                 min_edit_dist = edit_dist | ||||
|  | @ -247,8 +254,9 @@ def score_edit_distance(match: "Match") -> int: | |||
|     return match.dist.delete + match.dist.insert + 2 * match.dist.replace | ||||
| 
 | ||||
| 
 | ||||
| def calculate_penalty(gt: "Part", ocr: "Part", match: "Match", | ||||
|                       coef: "Coefficients") -> float: | ||||
| def calculate_penalty( | ||||
|     gt: "Part", ocr: "Part", match: "Match", coef: "Coefficients" | ||||
| ) -> float: | ||||
|     """Calculate the penalty for a given match. | ||||
| 
 | ||||
|     For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003. | ||||
|  | @ -262,10 +270,12 @@ def calculate_penalty(gt: "Part", ocr: "Part", match: "Match", | |||
|     if length_diff > 1: | ||||
|         substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start) | ||||
|         offset = length_diff / 2 - abs(substring_pos - length_diff / 2) | ||||
|     return (min_edit_dist * coef.edit_dist | ||||
|             + length_diff * coef.length_diff | ||||
|             + offset * coef.offset | ||||
|             - substring_length * coef.length) | ||||
|     return ( | ||||
|         min_edit_dist * coef.edit_dist | ||||
|         + length_diff * coef.length_diff | ||||
|         + offset * coef.offset | ||||
|         - substring_length * coef.length | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def character_accuracy_for_matches(matches: List["Match"]) -> float: | ||||
|  | @ -274,8 +284,9 @@ def character_accuracy_for_matches(matches: List["Match"]) -> float: | |||
|     See other `character_accuracy` for details. | ||||
| 
 | ||||
|     """ | ||||
|     agg: Counter = reduce(lambda acc, match: acc + Counter(match.dist._asdict()), | ||||
|                           matches, Counter()) | ||||
|     agg: Counter = reduce( | ||||
|         lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter() | ||||
|     ) | ||||
| 
 | ||||
|     score = character_accuracy(Distance(**agg)) | ||||
|     return score | ||||
|  | @ -299,9 +310,9 @@ def character_accuracy(edits: "Distance") -> float: | |||
|     chars = edits.match + edits.replace + edits.delete | ||||
|     if not chars and not errors: | ||||
|         # comparison of empty strings is considered a full match | ||||
|         score = 1 | ||||
|         score = 1.0 | ||||
|     else: | ||||
|         score = 1 - errors / chars | ||||
|         score = 1.0 - errors / chars | ||||
|     return score | ||||
| 
 | ||||
| 
 | ||||
|  | @ -315,9 +326,11 @@ def initialize_lines(text: str) -> List["Part"]: | |||
|     :param text: Text to split into lines. | ||||
|     :return: List of sorted line objects. | ||||
|     """ | ||||
|     lines = [Part(text=line, line=i, start=0) | ||||
|              for i, line in enumerate(text.splitlines()) | ||||
|              if len(line) > 0] | ||||
|     lines = [ | ||||
|         Part(text=line, line=i, start=0) | ||||
|         for i, line in enumerate(text.splitlines()) | ||||
|         if len(line) > 0 | ||||
|     ] | ||||
|     lines.sort(key=lambda x: x.length, reverse=True) | ||||
|     return lines | ||||
| 
 | ||||
|  | @ -348,6 +361,7 @@ class Part(NamedTuple): | |||
| 
 | ||||
|     This data object is maintained to be able to reproduce the original text. | ||||
|     """ | ||||
| 
 | ||||
|     text: str = "" | ||||
|     line: int = 0 | ||||
|     start: int = 0 | ||||
|  | @ -392,6 +406,7 @@ class Part(NamedTuple): | |||
| 
 | ||||
| class Distance(NamedTuple): | ||||
|     """Represent distance between two sequences.""" | ||||
| 
 | ||||
|     match: int = 0 | ||||
|     replace: int = 0 | ||||
|     delete: int = 0 | ||||
|  | @ -400,6 +415,7 @@ class Distance(NamedTuple): | |||
| 
 | ||||
| class Match(NamedTuple): | ||||
|     """Represent a calculated match between ground truth and the ocr result.""" | ||||
| 
 | ||||
|     gt: "Part" | ||||
|     ocr: "Part" | ||||
|     dist: "Distance" | ||||
|  | @ -411,6 +427,7 @@ class Coefficients(NamedTuple): | |||
| 
 | ||||
|     See Section 3 in doi:10.1016/j.patrec.2020.02.003 | ||||
|     """ | ||||
| 
 | ||||
|     edit_dist: int = 25 | ||||
|     length_diff: int = 20 | ||||
|     offset: int = 1 | ||||
|  |  | |||
|  | @ -35,37 +35,31 @@ COMPLEX_CASES = [ | |||
| EXTENDED_CASES = [ | ||||
|     # See figure 4 in 10.1016/j.patrec.2020.02.003 | ||||
|     # A: No errors | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      1, 1), | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), 1, 1), | ||||
|     # B: Different ordering of text blocks | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), | ||||
|      1, 1), | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), 1, 1), | ||||
|     # C: Merge across columns | ||||
|     ((0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9), | ||||
|      (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9), | ||||
|      1, 0.964), | ||||
|     ( | ||||
|         (0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9), | ||||
|         (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9), | ||||
|         1, | ||||
|         0.964, | ||||
|     ), | ||||
|     # D: Over-segmentation | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9), | ||||
|      1, 0.966), | ||||
|     ( | ||||
|         (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|         (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9), | ||||
|         1, | ||||
|         0.966, | ||||
|     ), | ||||
|     # E: Part missing | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (0, 1, 2, 3, 4), | ||||
|      1, 0.50), | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4), 1, 0.50), | ||||
|     # E.2: Part missing | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (5, 6, 7, 8, 9), | ||||
|      1, 0.50), | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9), 1, 0.50), | ||||
|     # F: All missing | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), | ||||
|      (), | ||||
|      1, 0), | ||||
|     ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (), 1, 0), | ||||
|     # G: Added parts | ||||
|     ((0, 1, 2, 3, 4), | ||||
|      (0, 1, 2, 3, 4, 11, 5, 6), | ||||
|      1, 0.621), | ||||
|     ((0, 1, 2, 3, 4), (0, 1, 2, 3, 4, 11, 5, 6), 1, 0.621), | ||||
| ] | ||||
| 
 | ||||
| EDIT_ARGS = "gt,ocr,expected_dist" | ||||
|  | @ -73,8 +67,11 @@ EDIT_ARGS = "gt,ocr,expected_dist" | |||
| SIMPLE_EDITS = [ | ||||
|     (Part(text="a"), Part(text="a"), Distance(match=1)), | ||||
|     (Part(text="aaa"), Part(text="aaa"), Distance(match=3)), | ||||
|     (Part(text="abcd"), Part(text="beed"), | ||||
|      Distance(match=2, replace=1, insert=1, delete=1)), | ||||
|     ( | ||||
|         Part(text="abcd"), | ||||
|         Part(text="beed"), | ||||
|         Distance(match=2, replace=1, insert=1, delete=1), | ||||
|     ), | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -83,9 +80,20 @@ def extended_case_to_text(gt, ocr): | |||
| 
 | ||||
|     See figure 4 in 10.1016/j.patrec.2020.02.003 | ||||
|     """ | ||||
|     sentence = ("Eight", "happy", "frogs", "scuba", "dived", | ||||
|                 "Jenny", "chick", "flaps", "white", "wings", | ||||
|                 "", "\n") | ||||
|     sentence = ( | ||||
|         "Eight", | ||||
|         "happy", | ||||
|         "frogs", | ||||
|         "scuba", | ||||
|         "dived", | ||||
|         "Jenny", | ||||
|         "chick", | ||||
|         "flaps", | ||||
|         "white", | ||||
|         "wings", | ||||
|         "", | ||||
|         "\n", | ||||
|     ) | ||||
| 
 | ||||
|     gt_sentence = " ".join(sentence[i] for i in gt).replace(" \n ", "\n") | ||||
|     ocr_sentence = " ".join(sentence[i] for i in ocr).replace(" \n ", "\n") | ||||
|  | @ -98,22 +106,35 @@ def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_ | |||
|     assert score == pytest.approx(all_line_score) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("config,ocr", [ | ||||
|     ("Config I", | ||||
|      "1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein" | ||||
|      ), | ||||
|     ("Config II", | ||||
|      "1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\"" | ||||
|      ), | ||||
|     ("Config III", | ||||
|      "Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"" | ||||
|      ), | ||||
| ]) | ||||
| @pytest.mark.parametrize( | ||||
|     "config,ocr", | ||||
|     [ | ||||
|         ( | ||||
|             "Config I", | ||||
|             "1 hav\nnospecial\ntalents.\n" | ||||
|             'I am one\npassionate\ncuriousity."\n' | ||||
|             "Alberto\nEmstein", | ||||
|         ), | ||||
|         ( | ||||
|             "Config II", | ||||
|             '1 hav\nnospecial\ntalents. Alberto\n' | ||||
|             'I am one Emstein\npassionate\ncuriousity."', | ||||
|         ), | ||||
|         ( | ||||
|             "Config III", | ||||
|             'Alberto\nEmstein\n' | ||||
|             '1 hav\nnospecial\ntalents.\n' | ||||
|             'I am one\npassionate\ncuriousity."', | ||||
|         ), | ||||
|     ], | ||||
| ) | ||||
| def test_flexible_character_accuracy(config, ocr): | ||||
|     """Tests from figure 3 in the paper.""" | ||||
|     gt = "\"I have\nno special\ntalent." \ | ||||
|          "\nI am only\npassionately\ncurious.\"" \ | ||||
|          "\nAlbert\nEinstein" | ||||
|     gt = ( | ||||
|         '"I have\nno special\ntalent.\n' | ||||
|         'I am only\npassionately\ncurious."\n' | ||||
|         "Albert\nEinstein" | ||||
|     ) | ||||
|     replacements, inserts, deletes = 3, 5, 7 | ||||
|     chars = len(gt) - gt.count("\n") | ||||
|     assert chars == 68 | ||||
|  | @ -127,21 +148,27 @@ def test_flexible_character_accuracy(config, ocr): | |||
|         replacements += 1 | ||||
|         deletes -= 1 | ||||
| 
 | ||||
|     expected_dist = Distance(match=chars - deletes - replacements, replace=replacements, | ||||
|                              insert=inserts, delete=deletes) | ||||
|     expected_dist = Distance( | ||||
|         match=chars - deletes - replacements, | ||||
|         replace=replacements, | ||||
|         insert=inserts, | ||||
|         delete=deletes, | ||||
|     ) | ||||
|     expected_score = character_accuracy(expected_dist) | ||||
| 
 | ||||
|     result, matches = flexible_character_accuracy(gt, ocr) | ||||
|     agg = reduce(lambda acc, match: acc + Counter(match.dist._asdict()), | ||||
|                  matches, Counter()) | ||||
|     agg = reduce( | ||||
|         lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter() | ||||
|     ) | ||||
|     dist = Distance(**agg) | ||||
|     assert dist == expected_dist | ||||
|     assert result == pytest.approx(expected_score, abs=0.0005) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES) | ||||
| def test_flexible_character_accuracy_extended(gt, ocr, first_line_score, | ||||
|                                               all_line_score): | ||||
| def test_flexible_character_accuracy_extended( | ||||
|     gt, ocr, first_line_score, all_line_score | ||||
| ): | ||||
|     """Tests from figure 4 in the paper.""" | ||||
|     gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr) | ||||
|     result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence) | ||||
|  | @ -170,10 +197,13 @@ def test_match_longest_gt_lines(gt, ocr, first_line_score, all_line_score): | |||
|     assert score == pytest.approx(first_line_score) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(CASE_ARGS, [ | ||||
|     *SIMPLE_CASES, | ||||
|     ("accc", "a\nbb\nccc", 1.0, 1.0), | ||||
| ]) | ||||
| @pytest.mark.parametrize( | ||||
|     CASE_ARGS, | ||||
|     [ | ||||
|         *SIMPLE_CASES, | ||||
|         ("accc", "a\nbb\nccc", 1.0, 1.0), | ||||
|     ], | ||||
| ) | ||||
| def test_match_gt_line(gt, ocr, first_line_score, all_line_score): | ||||
|     coef = Coefficients() | ||||
|     gt_lines = initialize_lines(gt) | ||||
|  | @ -185,15 +215,22 @@ def test_match_gt_line(gt, ocr, first_line_score, all_line_score): | |||
|     assert score == pytest.approx(first_line_score) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("original,match,expected_lines", [ | ||||
|     (Part(), Part(), []), | ||||
|     (Part(text="abc"), Part(), [Part(text="abc")]), | ||||
|     (Part(text="abc"), Part("d"), [Part(text="bc", start=1)]), | ||||
|     (Part(text="abc"), Part("a", start=100), [Part(text="abc")]), | ||||
|     (Part(text="abc"), Part("a"), [Part(text="bc", start=1)]), | ||||
|     (Part(text="abc"), Part("b", start=1), [Part(text="a"), Part(text="c", start=2)]), | ||||
|     (Part(text="abc"), Part("c", start=2), [Part(text="ab")]), | ||||
| ]) | ||||
| @pytest.mark.parametrize( | ||||
|     "original,match,expected_lines", | ||||
|     [ | ||||
|         (Part(), Part(), []), | ||||
|         (Part(text="abc"), Part(), [Part(text="abc")]), | ||||
|         (Part(text="abc"), Part("d"), [Part(text="bc", start=1)]), | ||||
|         (Part(text="abc"), Part("a", start=100), [Part(text="abc")]), | ||||
|         (Part(text="abc"), Part("a"), [Part(text="bc", start=1)]), | ||||
|         ( | ||||
|             Part(text="abc"), | ||||
|             Part("b", start=1), | ||||
|             [Part(text="a"), Part(text="c", start=2)], | ||||
|         ), | ||||
|         (Part(text="abc"), Part("c", start=2), [Part(text="ab")]), | ||||
|     ], | ||||
| ) | ||||
| def test_remove_or_split(original, match, expected_lines): | ||||
|     lines = [original] | ||||
|     splitted = remove_or_split(original, match, lines) | ||||
|  | @ -201,18 +238,29 @@ def test_remove_or_split(original, match, expected_lines): | |||
|     assert lines == expected_lines | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(EDIT_ARGS, [ | ||||
|     *SIMPLE_EDITS, | ||||
|     (Part(text="a"), Part(text="b"), Distance(delete=1)), | ||||
|     (Part(text="aaa"), Part(text="bbb"), Distance(delete=3)), | ||||
|     (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)), | ||||
|     (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)), | ||||
|     (Part(text=""), Part(text=""), None), | ||||
|     (Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)), | ||||
|     (Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)), | ||||
|     (Part(text="aaabbbaaaddd"), Part(text="aaabcbaaa"), Distance(match=8, replace=1)), | ||||
|     (Part(text="aaabbbccc"), Part(text="aaabbbdddccc"), Distance(match=9, insert=3)), | ||||
| ]) | ||||
| @pytest.mark.parametrize( | ||||
|     EDIT_ARGS, | ||||
|     [ | ||||
|         *SIMPLE_EDITS, | ||||
|         (Part(text="a"), Part(text="b"), Distance(delete=1)), | ||||
|         (Part(text="aaa"), Part(text="bbb"), Distance(delete=3)), | ||||
|         (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)), | ||||
|         (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)), | ||||
|         (Part(text=""), Part(text=""), None), | ||||
|         (Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)), | ||||
|         (Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)), | ||||
|         ( | ||||
|             Part(text="aaabbbaaaddd"), | ||||
|             Part(text="aaabcbaaa"), | ||||
|             Distance(match=8, replace=1), | ||||
|         ), | ||||
|         ( | ||||
|             Part(text="aaabbbccc"), | ||||
|             Part(text="aaabbbdddccc"), | ||||
|             Distance(match=9, insert=3), | ||||
|         ), | ||||
|     ], | ||||
| ) | ||||
| def test_match_lines(gt, ocr, expected_dist): | ||||
|     match = match_lines(gt, ocr) | ||||
|     if not expected_dist: | ||||
|  | @ -223,14 +271,17 @@ def test_match_lines(gt, ocr, expected_dist): | |||
|         assert match.dist == expected_dist | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize(EDIT_ARGS, [ | ||||
|     *SIMPLE_EDITS, | ||||
|     (Part(text="").substring(), Part(text=""), Distance()), | ||||
|     (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)), | ||||
|     (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)), | ||||
|     (Part(text="a"), Part(text="b"), Distance(replace=1)), | ||||
|     (Part(text="aaa"), Part(text="bbb"), Distance(replace=3)), | ||||
| ]) | ||||
| @pytest.mark.parametrize( | ||||
|     EDIT_ARGS, | ||||
|     [ | ||||
|         *SIMPLE_EDITS, | ||||
|         (Part(text="").substring(), Part(text=""), Distance()), | ||||
|         (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)), | ||||
|         (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)), | ||||
|         (Part(text="a"), Part(text="b"), Distance(replace=1)), | ||||
|         (Part(text="aaa"), Part(text="bbb"), Distance(replace=3)), | ||||
|     ], | ||||
| ) | ||||
| def test_distance(gt, ocr, expected_dist): | ||||
|     match = distance(gt, ocr) | ||||
|     assert match.gt == gt | ||||
|  | @ -238,46 +289,77 @@ def test_distance(gt, ocr, expected_dist): | |||
|     assert match.dist == expected_dist | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("matches,expected_dist", [ | ||||
|     ([], 1), | ||||
|     ([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1), | ||||
|     ([Match(gt=Part(text="abee"), ocr=Part("ac"), | ||||
|             dist=Distance(match=1, replace=1, delete=2), ops=[]), | ||||
|       Match(gt=Part(text="cd"), ocr=Part("ceff"), | ||||
|             dist=Distance(match=1, replace=1, insert=2), ops=[])], | ||||
|      1 - 6 / 6), | ||||
| ]) | ||||
| @pytest.mark.parametrize( | ||||
|     "matches,expected_dist", | ||||
|     [ | ||||
|         ([], 1), | ||||
|         ([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1), | ||||
|         ( | ||||
|             [ | ||||
|                 Match( | ||||
|                     gt=Part(text="abee"), | ||||
|                     ocr=Part("ac"), | ||||
|                     dist=Distance(match=1, replace=1, delete=2), | ||||
|                     ops=[], | ||||
|                 ), | ||||
|                 Match( | ||||
|                     gt=Part(text="cd"), | ||||
|                     ocr=Part("ceff"), | ||||
|                     dist=Distance(match=1, replace=1, insert=2), | ||||
|                     ops=[], | ||||
|                 ), | ||||
|             ], | ||||
|             1 - 6 / 6, | ||||
|         ), | ||||
|     ], | ||||
| ) | ||||
| def test_character_accuracy_matches(matches, expected_dist): | ||||
|     assert character_accuracy_for_matches(matches) == pytest.approx(expected_dist) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("dist,expected_dist", [ | ||||
|     (Distance(), 1), | ||||
|     (Distance(match=1), 1), | ||||
|     (Distance(replace=1), 0), | ||||
|     (Distance(match=1, insert=1), 0), | ||||
|     (Distance(match=1, insert=2), 1 - 2 / 1), | ||||
|     (Distance(match=2, insert=1), 0.5), | ||||
|     (Distance(match=1, delete=1), 0.5), | ||||
| ]) | ||||
| @pytest.mark.parametrize( | ||||
|     "dist,expected_dist", | ||||
|     [ | ||||
|         (Distance(), 1), | ||||
|         (Distance(match=1), 1), | ||||
|         (Distance(replace=1), 0), | ||||
|         (Distance(match=1, insert=1), 0), | ||||
|         (Distance(match=1, insert=2), 1 - 2 / 1), | ||||
|         (Distance(match=2, insert=1), 0.5), | ||||
|         (Distance(match=1, delete=1), 0.5), | ||||
|     ], | ||||
| ) | ||||
| def test_character_accuracy_dist(dist, expected_dist): | ||||
|     assert character_accuracy(dist) == pytest.approx(expected_dist) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("line,subline,expected_rest", [ | ||||
|     (Part(), Part(), []), | ||||
|     (Part("aaa bbb"), Part("aaa bbb"), []), | ||||
|     (Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]), | ||||
|     (Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]), | ||||
|     (Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]), | ||||
|     (Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]), | ||||
|     (Part("aaa bbb ccc"), Part("bbb", start=4), [Part("aaa "), Part(" ccc", start=7)]), | ||||
|     (Part("aaa bbb ccc", start=3), Part("bbb", start=7), | ||||
|      [Part("aaa ", start=3), Part(" ccc", start=10)]), | ||||
|     (Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]), | ||||
|     (Part("aaa bbb", start=3), Part(" ", start=6), | ||||
|      [Part("aaa", start=3), Part("bbb", start=7)]), | ||||
| ]) | ||||
| @pytest.mark.parametrize( | ||||
|     "line,subline,expected_rest", | ||||
|     [ | ||||
|         (Part(), Part(), []), | ||||
|         (Part("aaa bbb"), Part("aaa bbb"), []), | ||||
|         (Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]), | ||||
|         (Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]), | ||||
|         (Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]), | ||||
|         (Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]), | ||||
|         ( | ||||
|             Part("aaa bbb ccc"), | ||||
|             Part("bbb", start=4), | ||||
|             [Part("aaa "), Part(" ccc", start=7)], | ||||
|         ), | ||||
|         ( | ||||
|             Part("aaa bbb ccc", start=3), | ||||
|             Part("bbb", start=7), | ||||
|             [Part("aaa ", start=3), Part(" ccc", start=10)], | ||||
|         ), | ||||
|         (Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]), | ||||
|         ( | ||||
|             Part("aaa bbb", start=3), | ||||
|             Part(" ", start=6), | ||||
|             [Part("aaa", start=3), Part("bbb", start=7)], | ||||
|         ), | ||||
|     ], | ||||
| ) | ||||
| def test_split_line(line, subline, expected_rest): | ||||
|     rest = line.split(subline) | ||||
|     assert len(rest) == len(expected_rest) | ||||
|  | @ -300,12 +382,15 @@ def test_combine_lines(): | |||
|     assert False | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("line,start,end,expected", [ | ||||
|     (Part(text=""), 0, None, Part(text="")), | ||||
|     (Part(text="a"), 0, None, Part(text="a")), | ||||
|     (Part(text="ab"), 0, 1, Part(text="a")), | ||||
|     (Part(text="abc"), 0, -1, Part(text="ab")), | ||||
|     (Part(text="ab"), 1, None, Part(text="b", start=1)), | ||||
| ]) | ||||
| @pytest.mark.parametrize( | ||||
|     "line,start,end,expected", | ||||
|     [ | ||||
|         (Part(text=""), 0, None, Part(text="")), | ||||
|         (Part(text="a"), 0, None, Part(text="a")), | ||||
|         (Part(text="ab"), 0, 1, Part(text="a")), | ||||
|         (Part(text="abc"), 0, -1, Part(text="ab")), | ||||
|         (Part(text="ab"), 1, None, Part(text="b", start=1)), | ||||
|     ], | ||||
| ) | ||||
| def test_line_substring(line, start, end, expected): | ||||
|     assert line.substring(rel_start=start, rel_end=end) == expected | ||||
|  |  | |||
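
For readers skimming the diff rather than the full module: the metric whose formatting changes above reduces to a simple ratio, 1 - errors / chars. Below is a standalone sketch of character_accuracy, duplicating the definitions visible in the diff; the definition of errors as replace + delete + insert sits just above the shown hunk and is assumed here, consistent with the test table above.

    # Standalone illustration of the character accuracy used in the diff.
    # The definition of `errors` is an assumption taken from context.
    from typing import NamedTuple

    class Distance(NamedTuple):
        match: int = 0
        replace: int = 0
        delete: int = 0
        insert: int = 0

    def character_accuracy(edits: Distance) -> float:
        errors = edits.replace + edits.delete + edits.insert
        chars = edits.match + edits.replace + edits.delete
        if not chars and not errors:
            return 1.0  # comparing two empty strings counts as a full match
        return 1.0 - errors / chars

    print(character_accuracy(Distance(match=2, insert=1)))  # 0.5, as in the test table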