diff --git a/qurator/dinglehopper/flexible_character_accuracy.py b/qurator/dinglehopper/flexible_character_accuracy.py index 2b9a56f..e81ef54 100644 --- a/qurator/dinglehopper/flexible_character_accuracy.py +++ b/qurator/dinglehopper/flexible_character_accuracy.py @@ -29,20 +29,16 @@ def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"] :return: Score between 0 and 1 and match objects. """ - best_score = -float('inf') + best_score = -float("inf") best_matches = [] # TODO: this should be configurable - combinations = product(range(15, 31, 5), - range(0, 24, 3), - range(0, 4, 1), - range(0, 6, 1)) + combinations = product( + range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1) + ) # TODO: place to parallelize the algorithm for (edit_dist, length_diff, offset, length) in combinations: coef = Coefficients( - edit_dist=edit_dist, - length_diff=length_diff, - offset=offset, - length=length + edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length ) # Steps 1 - 6 of the flexible character accuracy algorithm. matches = match_with_coefficients(gt, ocr, coef) @@ -79,17 +75,21 @@ def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Ma # Step 6 of the flexible character accuracy algorithm. # remaining lines are considered as deletes and inserts - deletes = [distance(line, Part(text="", line=line.line, start=line.start)) - for line in gt_lines] - inserts = [distance(Part(text="", line=line.line, start=line.start), line) - for line in ocr_lines] + deletes = [ + distance(line, Part(text="", line=line.line, start=line.start)) + for line in gt_lines + ] + inserts = [ + distance(Part(text="", line=line.line, start=line.start), line) + for line in ocr_lines + ] return [*matches, *deletes, *inserts] -def match_longest_gt_lines(gt_lines: List["Part"], - ocr_lines: List["Part"], - coef: "Coefficients") -> Optional["Match"]: +def match_longest_gt_lines( + gt_lines: List["Part"], ocr_lines: List["Part"], coef: "Coefficients" +) -> Optional["Match"]: """Find the best match for the longest line(s) in ground truth. The longest lines in ground truth are matched against lines in ocr to find the @@ -99,7 +99,7 @@ def match_longest_gt_lines(gt_lines: List["Part"], :return: Possible match object. """ - best_score, best_match, best_gt, best_ocr = -float('inf'), None, None, None + best_score, best_match, best_gt, best_ocr = -float("inf"), None, None, None if not ocr_lines: return best_match @@ -126,10 +126,9 @@ def match_longest_gt_lines(gt_lines: List["Part"], return best_match -def match_gt_line(gt_line: "Part", - ocr_lines: List["Part"], - coef: "Coefficients") -> Tuple[Optional["Match"], - Optional["Part"]]: +def match_gt_line( + gt_line: "Part", ocr_lines: List["Part"], coef: "Coefficients" +) -> Tuple[Optional["Match"], Optional["Part"]]: """Match the given ground truth line against all the lines in ocr. Reference: contains steps 3 of the flexible character accuracy algorithm. @@ -138,19 +137,18 @@ def match_gt_line(gt_line: "Part", :return: Match object and the matched ocr line. """ - min_penalty = float('inf') + min_penalty = float("inf") best_match, best_ocr = None, None for ocr_line in [*ocr_lines]: match = match_lines(gt_line, ocr_line) - penalty = calculate_penalty(gt_line, ocr_line, match, coef) - if penalty < min_penalty: - min_penalty, best_match, best_ocr = penalty, match, ocr_line + if match: + penalty = calculate_penalty(gt_line, ocr_line, match, coef) + if penalty < min_penalty: + min_penalty, best_match, best_ocr = penalty, match, ocr_line return best_match, best_ocr -def remove_or_split(original: "Part", - match: "Part", - lines: List["Part"]) -> bool: +def remove_or_split(original: "Part", match: "Part", lines: List["Part"]) -> bool: """Removes the matched line or splits it into parts. Reference: contains step 4 of the flexible character accuracy algorithm. @@ -187,17 +185,24 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]: if min_length == 0: return best_match length_diff = gt_line.length - ocr_line.length - min_edit_dist = float('inf') + min_edit_dist = float("inf") - gt_parts = [(i, gt_line.substring(rel_start=i, rel_end=i + min_length)) - for i in range(0, max(1, length_diff + 1))] - ocr_parts = [(j, ocr_line.substring(rel_start=j, rel_end=j + min_length)) - for j in range(0, max(1, -1 * length_diff + 1))] + gt_parts = [ + (i, gt_line.substring(rel_start=i, rel_end=i + min_length)) + for i in range(0, max(1, length_diff + 1)) + ] + ocr_parts = [ + (j, ocr_line.substring(rel_start=j, rel_end=j + min_length)) + for j in range(0, max(1, -1 * length_diff + 1)) + ] # add full line and empty line match gt_parts = [*gt_parts, (0, gt_line), (0, gt_line)] - ocr_parts = [*ocr_parts, (0, ocr_line), - (0, Part(text="", line=gt_line.line, start=gt_line.start))] + ocr_parts = [ + *ocr_parts, + (0, ocr_line), + (0, Part(text="", line=gt_line.line, start=gt_line.start)), + ] for i, gt_part in gt_parts: for j, ocr_part in ocr_parts: @@ -211,8 +216,10 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]: part_length = best_match.gt.length additional_length = best_match.dist.delete + best_match.dist.replace for k in range(part_length + 1, part_length + additional_length + 1): - match = distance(gt_line.substring(rel_start=best_i, rel_end=best_i + k), - ocr_line.substring(rel_start=best_j, rel_end=best_j + k)) + match = distance( + gt_line.substring(rel_start=best_i, rel_end=best_i + k), + ocr_line.substring(rel_start=best_j, rel_end=best_j + k), + ) edit_dist = score_edit_distance(match) if edit_dist < min_edit_dist: min_edit_dist = edit_dist @@ -247,8 +254,9 @@ def score_edit_distance(match: "Match") -> int: return match.dist.delete + match.dist.insert + 2 * match.dist.replace -def calculate_penalty(gt: "Part", ocr: "Part", match: "Match", - coef: "Coefficients") -> float: +def calculate_penalty( + gt: "Part", ocr: "Part", match: "Match", coef: "Coefficients" +) -> float: """Calculate the penalty for a given match. For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003. @@ -262,10 +270,12 @@ def calculate_penalty(gt: "Part", ocr: "Part", match: "Match", if length_diff > 1: substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start) offset = length_diff / 2 - abs(substring_pos - length_diff / 2) - return (min_edit_dist * coef.edit_dist - + length_diff * coef.length_diff - + offset * coef.offset - - substring_length * coef.length) + return ( + min_edit_dist * coef.edit_dist + + length_diff * coef.length_diff + + offset * coef.offset + - substring_length * coef.length + ) def character_accuracy_for_matches(matches: List["Match"]) -> float: @@ -274,8 +284,9 @@ def character_accuracy_for_matches(matches: List["Match"]) -> float: See other `character_accuracy` for details. """ - agg: Counter = reduce(lambda acc, match: acc + Counter(match.dist._asdict()), - matches, Counter()) + agg: Counter = reduce( + lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter() + ) score = character_accuracy(Distance(**agg)) return score @@ -299,9 +310,9 @@ def character_accuracy(edits: "Distance") -> float: chars = edits.match + edits.replace + edits.delete if not chars and not errors: # comparison of empty strings is considered a full match - score = 1 + score = 1.0 else: - score = 1 - errors / chars + score = 1.0 - errors / chars return score @@ -315,9 +326,11 @@ def initialize_lines(text: str) -> List["Part"]: :param text: Text to split into lines. :return: List of sorted line objects. """ - lines = [Part(text=line, line=i, start=0) - for i, line in enumerate(text.splitlines()) - if len(line) > 0] + lines = [ + Part(text=line, line=i, start=0) + for i, line in enumerate(text.splitlines()) + if len(line) > 0 + ] lines.sort(key=lambda x: x.length, reverse=True) return lines @@ -348,6 +361,7 @@ class Part(NamedTuple): This data object is maintained to be able to reproduce the original text. """ + text: str = "" line: int = 0 start: int = 0 @@ -392,6 +406,7 @@ class Part(NamedTuple): class Distance(NamedTuple): """Represent distance between two sequences.""" + match: int = 0 replace: int = 0 delete: int = 0 @@ -400,6 +415,7 @@ class Distance(NamedTuple): class Match(NamedTuple): """Represent a calculated match between ground truth and the ocr result.""" + gt: "Part" ocr: "Part" dist: "Distance" @@ -411,6 +427,7 @@ class Coefficients(NamedTuple): See Section 3 in doi:10.1016/j.patrec.2020.02.003 """ + edit_dist: int = 25 length_diff: int = 20 offset: int = 1 diff --git a/qurator/dinglehopper/tests/test_flexible_character_accuracy.py b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py index 0126696..dfcb1f7 100644 --- a/qurator/dinglehopper/tests/test_flexible_character_accuracy.py +++ b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py @@ -35,37 +35,31 @@ COMPLEX_CASES = [ EXTENDED_CASES = [ # See figure 4 in 10.1016/j.patrec.2020.02.003 # A: No errors - ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), - (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), - 1, 1), + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), 1, 1), # B: Different ordering of text blocks - ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), - (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), - 1, 1), + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), 1, 1), # C: Merge across columns - ((0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9), - (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9), - 1, 0.964), + ( + (0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9), + (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9), + 1, + 0.964, + ), # D: Over-segmentation - ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), - (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9), - 1, 0.966), + ( + (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), + (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9), + 1, + 0.966, + ), # E: Part missing - ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), - (0, 1, 2, 3, 4), - 1, 0.50), + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4), 1, 0.50), # E.2: Part missing - ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), - (5, 6, 7, 8, 9), - 1, 0.50), + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9), 1, 0.50), # F: All missing - ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), - (), - 1, 0), + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (), 1, 0), # G: Added parts - ((0, 1, 2, 3, 4), - (0, 1, 2, 3, 4, 11, 5, 6), - 1, 0.621), + ((0, 1, 2, 3, 4), (0, 1, 2, 3, 4, 11, 5, 6), 1, 0.621), ] EDIT_ARGS = "gt,ocr,expected_dist" @@ -73,8 +67,11 @@ EDIT_ARGS = "gt,ocr,expected_dist" SIMPLE_EDITS = [ (Part(text="a"), Part(text="a"), Distance(match=1)), (Part(text="aaa"), Part(text="aaa"), Distance(match=3)), - (Part(text="abcd"), Part(text="beed"), - Distance(match=2, replace=1, insert=1, delete=1)), + ( + Part(text="abcd"), + Part(text="beed"), + Distance(match=2, replace=1, insert=1, delete=1), + ), ] @@ -83,9 +80,20 @@ def extended_case_to_text(gt, ocr): See figure 4 in 10.1016/j.patrec.2020.02.003 """ - sentence = ("Eight", "happy", "frogs", "scuba", "dived", - "Jenny", "chick", "flaps", "white", "wings", - "", "\n") + sentence = ( + "Eight", + "happy", + "frogs", + "scuba", + "dived", + "Jenny", + "chick", + "flaps", + "white", + "wings", + "", + "\n", + ) gt_sentence = " ".join(sentence[i] for i in gt).replace(" \n ", "\n") ocr_sentence = " ".join(sentence[i] for i in ocr).replace(" \n ", "\n") @@ -98,22 +106,35 @@ def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_ assert score == pytest.approx(all_line_score) -@pytest.mark.parametrize("config,ocr", [ - ("Config I", - "1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein" - ), - ("Config II", - "1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\"" - ), - ("Config III", - "Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"" - ), -]) +@pytest.mark.parametrize( + "config,ocr", + [ + ( + "Config I", + "1 hav\nnospecial\ntalents.\n" + 'I am one\npassionate\ncuriousity."\n' + "Alberto\nEmstein", + ), + ( + "Config II", + '1 hav\nnospecial\ntalents. Alberto\n' + 'I am one Emstein\npassionate\ncuriousity."', + ), + ( + "Config III", + 'Alberto\nEmstein\n' + '1 hav\nnospecial\ntalents.\n' + 'I am one\npassionate\ncuriousity."', + ), + ], +) def test_flexible_character_accuracy(config, ocr): """Tests from figure 3 in the paper.""" - gt = "\"I have\nno special\ntalent." \ - "\nI am only\npassionately\ncurious.\"" \ - "\nAlbert\nEinstein" + gt = ( + '"I have\nno special\ntalent.\n' + 'I am only\npassionately\ncurious."\n' + "Albert\nEinstein" + ) replacements, inserts, deletes = 3, 5, 7 chars = len(gt) - gt.count("\n") assert chars == 68 @@ -127,21 +148,27 @@ def test_flexible_character_accuracy(config, ocr): replacements += 1 deletes -= 1 - expected_dist = Distance(match=chars - deletes - replacements, replace=replacements, - insert=inserts, delete=deletes) + expected_dist = Distance( + match=chars - deletes - replacements, + replace=replacements, + insert=inserts, + delete=deletes, + ) expected_score = character_accuracy(expected_dist) result, matches = flexible_character_accuracy(gt, ocr) - agg = reduce(lambda acc, match: acc + Counter(match.dist._asdict()), - matches, Counter()) + agg = reduce( + lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter() + ) dist = Distance(**agg) assert dist == expected_dist assert result == pytest.approx(expected_score, abs=0.0005) @pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES) -def test_flexible_character_accuracy_extended(gt, ocr, first_line_score, - all_line_score): +def test_flexible_character_accuracy_extended( + gt, ocr, first_line_score, all_line_score +): """Tests from figure 4 in the paper.""" gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr) result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence) @@ -170,10 +197,13 @@ def test_match_longest_gt_lines(gt, ocr, first_line_score, all_line_score): assert score == pytest.approx(first_line_score) -@pytest.mark.parametrize(CASE_ARGS, [ - *SIMPLE_CASES, - ("accc", "a\nbb\nccc", 1.0, 1.0), -]) +@pytest.mark.parametrize( + CASE_ARGS, + [ + *SIMPLE_CASES, + ("accc", "a\nbb\nccc", 1.0, 1.0), + ], +) def test_match_gt_line(gt, ocr, first_line_score, all_line_score): coef = Coefficients() gt_lines = initialize_lines(gt) @@ -185,15 +215,22 @@ def test_match_gt_line(gt, ocr, first_line_score, all_line_score): assert score == pytest.approx(first_line_score) -@pytest.mark.parametrize("original,match,expected_lines", [ - (Part(), Part(), []), - (Part(text="abc"), Part(), [Part(text="abc")]), - (Part(text="abc"), Part("d"), [Part(text="bc", start=1)]), - (Part(text="abc"), Part("a", start=100), [Part(text="abc")]), - (Part(text="abc"), Part("a"), [Part(text="bc", start=1)]), - (Part(text="abc"), Part("b", start=1), [Part(text="a"), Part(text="c", start=2)]), - (Part(text="abc"), Part("c", start=2), [Part(text="ab")]), -]) +@pytest.mark.parametrize( + "original,match,expected_lines", + [ + (Part(), Part(), []), + (Part(text="abc"), Part(), [Part(text="abc")]), + (Part(text="abc"), Part("d"), [Part(text="bc", start=1)]), + (Part(text="abc"), Part("a", start=100), [Part(text="abc")]), + (Part(text="abc"), Part("a"), [Part(text="bc", start=1)]), + ( + Part(text="abc"), + Part("b", start=1), + [Part(text="a"), Part(text="c", start=2)], + ), + (Part(text="abc"), Part("c", start=2), [Part(text="ab")]), + ], +) def test_remove_or_split(original, match, expected_lines): lines = [original] splitted = remove_or_split(original, match, lines) @@ -201,18 +238,29 @@ def test_remove_or_split(original, match, expected_lines): assert lines == expected_lines -@pytest.mark.parametrize(EDIT_ARGS, [ - *SIMPLE_EDITS, - (Part(text="a"), Part(text="b"), Distance(delete=1)), - (Part(text="aaa"), Part(text="bbb"), Distance(delete=3)), - (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)), - (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)), - (Part(text=""), Part(text=""), None), - (Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)), - (Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)), - (Part(text="aaabbbaaaddd"), Part(text="aaabcbaaa"), Distance(match=8, replace=1)), - (Part(text="aaabbbccc"), Part(text="aaabbbdddccc"), Distance(match=9, insert=3)), -]) +@pytest.mark.parametrize( + EDIT_ARGS, + [ + *SIMPLE_EDITS, + (Part(text="a"), Part(text="b"), Distance(delete=1)), + (Part(text="aaa"), Part(text="bbb"), Distance(delete=3)), + (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)), + (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)), + (Part(text=""), Part(text=""), None), + (Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)), + (Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)), + ( + Part(text="aaabbbaaaddd"), + Part(text="aaabcbaaa"), + Distance(match=8, replace=1), + ), + ( + Part(text="aaabbbccc"), + Part(text="aaabbbdddccc"), + Distance(match=9, insert=3), + ), + ], +) def test_match_lines(gt, ocr, expected_dist): match = match_lines(gt, ocr) if not expected_dist: @@ -223,14 +271,17 @@ def test_match_lines(gt, ocr, expected_dist): assert match.dist == expected_dist -@pytest.mark.parametrize(EDIT_ARGS, [ - *SIMPLE_EDITS, - (Part(text="").substring(), Part(text=""), Distance()), - (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)), - (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)), - (Part(text="a"), Part(text="b"), Distance(replace=1)), - (Part(text="aaa"), Part(text="bbb"), Distance(replace=3)), -]) +@pytest.mark.parametrize( + EDIT_ARGS, + [ + *SIMPLE_EDITS, + (Part(text="").substring(), Part(text=""), Distance()), + (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)), + (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)), + (Part(text="a"), Part(text="b"), Distance(replace=1)), + (Part(text="aaa"), Part(text="bbb"), Distance(replace=3)), + ], +) def test_distance(gt, ocr, expected_dist): match = distance(gt, ocr) assert match.gt == gt @@ -238,46 +289,77 @@ def test_distance(gt, ocr, expected_dist): assert match.dist == expected_dist -@pytest.mark.parametrize("matches,expected_dist", [ - ([], 1), - ([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1), - ([Match(gt=Part(text="abee"), ocr=Part("ac"), - dist=Distance(match=1, replace=1, delete=2), ops=[]), - Match(gt=Part(text="cd"), ocr=Part("ceff"), - dist=Distance(match=1, replace=1, insert=2), ops=[])], - 1 - 6 / 6), -]) +@pytest.mark.parametrize( + "matches,expected_dist", + [ + ([], 1), + ([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1), + ( + [ + Match( + gt=Part(text="abee"), + ocr=Part("ac"), + dist=Distance(match=1, replace=1, delete=2), + ops=[], + ), + Match( + gt=Part(text="cd"), + ocr=Part("ceff"), + dist=Distance(match=1, replace=1, insert=2), + ops=[], + ), + ], + 1 - 6 / 6, + ), + ], +) def test_character_accuracy_matches(matches, expected_dist): assert character_accuracy_for_matches(matches) == pytest.approx(expected_dist) -@pytest.mark.parametrize("dist,expected_dist", [ - (Distance(), 1), - (Distance(match=1), 1), - (Distance(replace=1), 0), - (Distance(match=1, insert=1), 0), - (Distance(match=1, insert=2), 1 - 2 / 1), - (Distance(match=2, insert=1), 0.5), - (Distance(match=1, delete=1), 0.5), -]) +@pytest.mark.parametrize( + "dist,expected_dist", + [ + (Distance(), 1), + (Distance(match=1), 1), + (Distance(replace=1), 0), + (Distance(match=1, insert=1), 0), + (Distance(match=1, insert=2), 1 - 2 / 1), + (Distance(match=2, insert=1), 0.5), + (Distance(match=1, delete=1), 0.5), + ], +) def test_character_accuracy_dist(dist, expected_dist): assert character_accuracy(dist) == pytest.approx(expected_dist) -@pytest.mark.parametrize("line,subline,expected_rest", [ - (Part(), Part(), []), - (Part("aaa bbb"), Part("aaa bbb"), []), - (Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]), - (Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]), - (Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]), - (Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]), - (Part("aaa bbb ccc"), Part("bbb", start=4), [Part("aaa "), Part(" ccc", start=7)]), - (Part("aaa bbb ccc", start=3), Part("bbb", start=7), - [Part("aaa ", start=3), Part(" ccc", start=10)]), - (Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]), - (Part("aaa bbb", start=3), Part(" ", start=6), - [Part("aaa", start=3), Part("bbb", start=7)]), -]) +@pytest.mark.parametrize( + "line,subline,expected_rest", + [ + (Part(), Part(), []), + (Part("aaa bbb"), Part("aaa bbb"), []), + (Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]), + (Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]), + (Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]), + (Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]), + ( + Part("aaa bbb ccc"), + Part("bbb", start=4), + [Part("aaa "), Part(" ccc", start=7)], + ), + ( + Part("aaa bbb ccc", start=3), + Part("bbb", start=7), + [Part("aaa ", start=3), Part(" ccc", start=10)], + ), + (Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]), + ( + Part("aaa bbb", start=3), + Part(" ", start=6), + [Part("aaa", start=3), Part("bbb", start=7)], + ), + ], +) def test_split_line(line, subline, expected_rest): rest = line.split(subline) assert len(rest) == len(expected_rest) @@ -300,12 +382,15 @@ def test_combine_lines(): assert False -@pytest.mark.parametrize("line,start,end,expected", [ - (Part(text=""), 0, None, Part(text="")), - (Part(text="a"), 0, None, Part(text="a")), - (Part(text="ab"), 0, 1, Part(text="a")), - (Part(text="abc"), 0, -1, Part(text="ab")), - (Part(text="ab"), 1, None, Part(text="b", start=1)), -]) +@pytest.mark.parametrize( + "line,start,end,expected", + [ + (Part(text=""), 0, None, Part(text="")), + (Part(text="a"), 0, None, Part(text="a")), + (Part(text="ab"), 0, 1, Part(text="a")), + (Part(text="abc"), 0, -1, Part(text="ab")), + (Part(text="ab"), 1, None, Part(text="b", start=1)), + ], +) def test_line_substring(line, start, end, expected): assert line.substring(rel_start=start, rel_end=end) == expected