Reformat using black

pull/47/head
Benjamin Rosemann 5 years ago
parent 5277593bdb
commit 2a215a1062

@@ -29,20 +29,16 @@ def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"]
     :return: Score between 0 and 1 and match objects.
     """
-    best_score = -float('inf')
+    best_score = -float("inf")
     best_matches = []
     # TODO: this should be configurable
-    combinations = product(range(15, 31, 5),
-                           range(0, 24, 3),
-                           range(0, 4, 1),
-                           range(0, 6, 1))
+    combinations = product(
+        range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
+    )
     # TODO: place to parallelize the algorithm
     for (edit_dist, length_diff, offset, length) in combinations:
         coef = Coefficients(
-            edit_dist=edit_dist,
-            length_diff=length_diff,
-            offset=offset,
-            length=length
+            edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length
         )
         # Steps 1 - 6 of the flexible character accuracy algorithm.
         matches = match_with_coefficients(gt, ocr, coef)
@@ -79,17 +75,21 @@ def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Ma
     # Step 6 of the flexible character accuracy algorithm.
     # remaining lines are considered as deletes and inserts
-    deletes = [distance(line, Part(text="", line=line.line, start=line.start))
-               for line in gt_lines]
-    inserts = [distance(Part(text="", line=line.line, start=line.start), line)
-               for line in ocr_lines]
+    deletes = [
+        distance(line, Part(text="", line=line.line, start=line.start))
+        for line in gt_lines
+    ]
+    inserts = [
+        distance(Part(text="", line=line.line, start=line.start), line)
+        for line in ocr_lines
+    ]
     return [*matches, *deletes, *inserts]


-def match_longest_gt_lines(gt_lines: List["Part"],
-                           ocr_lines: List["Part"],
-                           coef: "Coefficients") -> Optional["Match"]:
+def match_longest_gt_lines(
+    gt_lines: List["Part"], ocr_lines: List["Part"], coef: "Coefficients"
+) -> Optional["Match"]:
     """Find the best match for the longest line(s) in ground truth.

     The longest lines in ground truth are matched against lines in ocr to find the
@@ -99,7 +99,7 @@ def match_longest_gt_lines(gt_lines: List["Part"],
     :return: Possible match object.
     """
-    best_score, best_match, best_gt, best_ocr = -float('inf'), None, None, None
+    best_score, best_match, best_gt, best_ocr = -float("inf"), None, None, None
     if not ocr_lines:
         return best_match
@@ -126,10 +126,9 @@ def match_longest_gt_lines(gt_lines: List["Part"],
     return best_match


-def match_gt_line(gt_line: "Part",
-                  ocr_lines: List["Part"],
-                  coef: "Coefficients") -> Tuple[Optional["Match"],
-                                                 Optional["Part"]]:
+def match_gt_line(
+    gt_line: "Part", ocr_lines: List["Part"], coef: "Coefficients"
+) -> Tuple[Optional["Match"], Optional["Part"]]:
     """Match the given ground truth line against all the lines in ocr.

     Reference: contains steps 3 of the flexible character accuracy algorithm.
@@ -138,19 +137,18 @@ def match_gt_line(gt_line: "Part",
     :return: Match object and the matched ocr line.
     """
-    min_penalty = float('inf')
+    min_penalty = float("inf")
     best_match, best_ocr = None, None
     for ocr_line in [*ocr_lines]:
         match = match_lines(gt_line, ocr_line)
-        penalty = calculate_penalty(gt_line, ocr_line, match, coef)
-        if penalty < min_penalty:
-            min_penalty, best_match, best_ocr = penalty, match, ocr_line
+        if match:
+            penalty = calculate_penalty(gt_line, ocr_line, match, coef)
+            if penalty < min_penalty:
+                min_penalty, best_match, best_ocr = penalty, match, ocr_line
     return best_match, best_ocr


-def remove_or_split(original: "Part",
-                    match: "Part",
-                    lines: List["Part"]) -> bool:
+def remove_or_split(original: "Part", match: "Part", lines: List["Part"]) -> bool:
     """Removes the matched line or splits it into parts.

     Reference: contains step 4 of the flexible character accuracy algorithm.
@@ -187,17 +185,24 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
     if min_length == 0:
         return best_match
     length_diff = gt_line.length - ocr_line.length
-    min_edit_dist = float('inf')
+    min_edit_dist = float("inf")

-    gt_parts = [(i, gt_line.substring(rel_start=i, rel_end=i + min_length))
-                for i in range(0, max(1, length_diff + 1))]
-    ocr_parts = [(j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
-                for j in range(0, max(1, -1 * length_diff + 1))]
+    gt_parts = [
+        (i, gt_line.substring(rel_start=i, rel_end=i + min_length))
+        for i in range(0, max(1, length_diff + 1))
+    ]
+    ocr_parts = [
+        (j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
+        for j in range(0, max(1, -1 * length_diff + 1))
+    ]

     # add full line and empty line match
     gt_parts = [*gt_parts, (0, gt_line), (0, gt_line)]
-    ocr_parts = [*ocr_parts, (0, ocr_line),
-                 (0, Part(text="", line=gt_line.line, start=gt_line.start))]
+    ocr_parts = [
+        *ocr_parts,
+        (0, ocr_line),
+        (0, Part(text="", line=gt_line.line, start=gt_line.start)),
+    ]

     for i, gt_part in gt_parts:
         for j, ocr_part in ocr_parts:
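The two comprehensions in this hunk enumerate every window of the shorter line's length inside the longer line before the nested loops score each pair. A standalone sketch of that windowing, using plain strings instead of Part objects (illustrative only, not part of the commit):

# Illustrative sketch, not part of the commit: mirrors how match_lines() slides
# the shorter line over the longer one before scoring each window.
gt_text, ocr_text = "abcdefg", "cdefg"          # lengths 7 and 5
min_length = min(len(gt_text), len(ocr_text))   # 5
length_diff = len(gt_text) - len(ocr_text)      # 2
gt_windows = [gt_text[i:i + min_length] for i in range(0, max(1, length_diff + 1))]
ocr_windows = [ocr_text[j:j + min_length] for j in range(0, max(1, -1 * length_diff + 1))]
print(gt_windows)   # ['abcde', 'bcdef', 'cdefg']
print(ocr_windows)  # ['cdefg']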
@@ -211,8 +216,10 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
         part_length = best_match.gt.length
         additional_length = best_match.dist.delete + best_match.dist.replace
         for k in range(part_length + 1, part_length + additional_length + 1):
-            match = distance(gt_line.substring(rel_start=best_i, rel_end=best_i + k),
-                             ocr_line.substring(rel_start=best_j, rel_end=best_j + k))
+            match = distance(
+                gt_line.substring(rel_start=best_i, rel_end=best_i + k),
+                ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
+            )
             edit_dist = score_edit_distance(match)
             if edit_dist < min_edit_dist:
                 min_edit_dist = edit_dist
@@ -247,8 +254,9 @@ def score_edit_distance(match: "Match") -> int:
     return match.dist.delete + match.dist.insert + 2 * match.dist.replace


-def calculate_penalty(gt: "Part", ocr: "Part", match: "Match",
-                      coef: "Coefficients") -> float:
+def calculate_penalty(
+    gt: "Part", ocr: "Part", match: "Match", coef: "Coefficients"
+) -> float:
     """Calculate the penalty for a given match.

     For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003.
@@ -262,10 +270,12 @@ def calculate_penalty(gt: "Part", ocr: "Part", match: "Match",
     if length_diff > 1:
         substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start)
         offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
-    return (min_edit_dist * coef.edit_dist
-            + length_diff * coef.length_diff
-            + offset * coef.offset
-            - substring_length * coef.length)
+    return (
+        min_edit_dist * coef.edit_dist
+        + length_diff * coef.length_diff
+        + offset * coef.offset
+        - substring_length * coef.length
+    )


 def character_accuracy_for_matches(matches: List["Match"]) -> float:
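The return expression in this hunk is the weighted penalty from Section 3 of the cited paper. A quick arithmetic sketch with the Coefficients defaults visible later in this diff (edit_dist=25, length_diff=20, offset=1) and an assumed length weight of 4, since that default is cut off here (illustrative only, not part of the commit):

# Illustrative sketch, not part of the commit: plugs made-up values into the
# penalty formula returned by calculate_penalty().
min_edit_dist, length_diff, offset, substring_length = 1, 2, 1.0, 10
edit_dist_w, length_diff_w, offset_w, length_w = 25, 20, 1, 4  # length_w is assumed
penalty = (
    min_edit_dist * edit_dist_w
    + length_diff * length_diff_w
    + offset * offset_w
    - substring_length * length_w
)
print(penalty)  # 25 + 40 + 1 - 40 = 26.0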
@@ -274,8 +284,9 @@ def character_accuracy_for_matches(matches: List["Match"]) -> float:
     See other `character_accuracy` for details.
     """
-    agg: Counter = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
-                          matches, Counter())
+    agg: Counter = reduce(
+        lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
+    )
     score = character_accuracy(Distance(**agg))
     return score
@@ -299,9 +310,9 @@ def character_accuracy(edits: "Distance") -> float:
     chars = edits.match + edits.replace + edits.delete
     if not chars and not errors:
         # comparison of empty strings is considered a full match
-        score = 1
+        score = 1.0
     else:
-        score = 1 - errors / chars
+        score = 1.0 - errors / chars
     return score
@@ -315,9 +326,11 @@ def initialize_lines(text: str) -> List["Part"]:
     :param text: Text to split into lines.
     :return: List of sorted line objects.
     """
-    lines = [Part(text=line, line=i, start=0)
-             for i, line in enumerate(text.splitlines())
-             if len(line) > 0]
+    lines = [
+        Part(text=line, line=i, start=0)
+        for i, line in enumerate(text.splitlines())
+        if len(line) > 0
+    ]
     lines.sort(key=lambda x: x.length, reverse=True)
     return lines
@@ -348,6 +361,7 @@ class Part(NamedTuple):
     This data object is maintained to be able to reproduce the original text.
     """
+
     text: str = ""
     line: int = 0
     start: int = 0
@@ -392,6 +406,7 @@ class Part(NamedTuple):
 class Distance(NamedTuple):
     """Represent distance between two sequences."""
+
     match: int = 0
     replace: int = 0
     delete: int = 0
@@ -400,6 +415,7 @@ class Distance(NamedTuple):
 class Match(NamedTuple):
     """Represent a calculated match between ground truth and the ocr result."""
+
     gt: "Part"
     ocr: "Part"
     dist: "Distance"
@@ -411,6 +427,7 @@ class Coefficients(NamedTuple):
     See Section 3 in doi:10.1016/j.patrec.2020.02.003
     """
+
     edit_dist: int = 25
     length_diff: int = 20
     offset: int = 1
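For orientation before the test-file hunks below: the entry point reformatted above is flexible_character_accuracy(gt, ocr), which per its docstring returns a score between 0 and 1 together with the match objects it was computed from. A minimal usage sketch (illustrative only; the import path is an assumption, not shown in this diff):

# Illustrative sketch, not part of the commit; the import path is assumed.
from flexible_character_accuracy import flexible_character_accuracy

gt = "Eight happy frogs\nscuba dived"
ocr = "Eight happy fr0gs\nscuba dived"

score, matches = flexible_character_accuracy(gt, ocr)
print(f"score: {score:.3f}")
for m in matches:
    print(repr(m.gt.text), "<->", repr(m.ocr.text), m.dist)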

@@ -35,37 +35,31 @@ COMPLEX_CASES = [
 EXTENDED_CASES = [
     # See figure 4 in 10.1016/j.patrec.2020.02.003
     # A: No errors
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     1, 1),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), 1, 1),
     # B: Different ordering of text blocks
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4),
-     1, 1),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), 1, 1),
     # C: Merge across columns
-    ((0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9),
-     (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9),
-     1, 0.964),
+    (
+        (0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9),
+        (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9),
+        1,
+        0.964,
+    ),
     # D: Over-segmentation
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9),
-     1, 0.966),
+    (
+        (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
+        (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9),
+        1,
+        0.966,
+    ),
     # E: Part missing
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (0, 1, 2, 3, 4),
-     1, 0.50),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4), 1, 0.50),
     # E.2: Part missing
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (5, 6, 7, 8, 9),
-     1, 0.50),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9), 1, 0.50),
     # F: All missing
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (),
-     1, 0),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (), 1, 0),
     # G: Added parts
-    ((0, 1, 2, 3, 4),
-     (0, 1, 2, 3, 4, 11, 5, 6),
-     1, 0.621),
+    ((0, 1, 2, 3, 4), (0, 1, 2, 3, 4, 11, 5, 6), 1, 0.621),
 ]

 EDIT_ARGS = "gt,ocr,expected_dist"
@@ -73,8 +67,11 @@ EDIT_ARGS = "gt,ocr,expected_dist"
 SIMPLE_EDITS = [
     (Part(text="a"), Part(text="a"), Distance(match=1)),
     (Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
-    (Part(text="abcd"), Part(text="beed"),
-     Distance(match=2, replace=1, insert=1, delete=1)),
+    (
+        Part(text="abcd"),
+        Part(text="beed"),
+        Distance(match=2, replace=1, insert=1, delete=1),
+    ),
 ]
@@ -83,9 +80,20 @@ def extended_case_to_text(gt, ocr):
     See figure 4 in 10.1016/j.patrec.2020.02.003
     """
-    sentence = ("Eight", "happy", "frogs", "scuba", "dived",
-                "Jenny", "chick", "flaps", "white", "wings",
-                "", "\n")
+    sentence = (
+        "Eight",
+        "happy",
+        "frogs",
+        "scuba",
+        "dived",
+        "Jenny",
+        "chick",
+        "flaps",
+        "white",
+        "wings",
+        "",
+        "\n",
+    )
     gt_sentence = " ".join(sentence[i] for i in gt).replace(" \n ", "\n")
     ocr_sentence = " ".join(sentence[i] for i in ocr).replace(" \n ", "\n")
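The EXTENDED_CASES tuples earlier in this file are indices into the sentence tuple above, with index 11 standing for a line break. A small sketch of how case A is rendered to text, mirroring extended_case_to_text() (illustrative only, not part of the commit):

# Illustrative sketch, not part of the commit: expands EXTENDED_CASES case A
# the same way extended_case_to_text() does.
sentence = (
    "Eight", "happy", "frogs", "scuba", "dived",
    "Jenny", "chick", "flaps", "white", "wings",
    "", "\n",
)
case_a = (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9)  # "A: No errors"
text = " ".join(sentence[i] for i in case_a).replace(" \n ", "\n")
print(text)
# Eight happy frogs scuba dived
# Jenny chick flaps white wings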
@@ -98,22 +106,35 @@ def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_
     assert score == pytest.approx(all_line_score)


-@pytest.mark.parametrize("config,ocr", [
-    ("Config I",
-     "1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein"
-     ),
-    ("Config II",
-     "1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\""
-     ),
-    ("Config III",
-     "Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\""
-     ),
-])
+@pytest.mark.parametrize(
+    "config,ocr",
+    [
+        (
+            "Config I",
+            "1 hav\nnospecial\ntalents.\n"
+            'I am one\npassionate\ncuriousity."\n'
+            "Alberto\nEmstein",
+        ),
+        (
+            "Config II",
+            '1 hav\nnospecial\ntalents. Alberto\n'
+            'I am one Emstein\npassionate\ncuriousity."',
+        ),
+        (
+            "Config III",
+            'Alberto\nEmstein\n'
+            '1 hav\nnospecial\ntalents.\n'
+            'I am one\npassionate\ncuriousity."',
+        ),
+    ],
+)
 def test_flexible_character_accuracy(config, ocr):
     """Tests from figure 3 in the paper."""
-    gt = "\"I have\nno special\ntalent." \
-         "\nI am only\npassionately\ncurious.\"" \
-         "\nAlbert\nEinstein"
+    gt = (
+        '"I have\nno special\ntalent.\n'
+        'I am only\npassionately\ncurious."\n'
+        "Albert\nEinstein"
+    )
     replacements, inserts, deletes = 3, 5, 7
     chars = len(gt) - gt.count("\n")
     assert chars == 68
@@ -127,21 +148,27 @@ def test_flexible_character_accuracy(config, ocr):
         replacements += 1
         deletes -= 1
-    expected_dist = Distance(match=chars - deletes - replacements, replace=replacements,
-                             insert=inserts, delete=deletes)
+    expected_dist = Distance(
+        match=chars - deletes - replacements,
+        replace=replacements,
+        insert=inserts,
+        delete=deletes,
+    )
     expected_score = character_accuracy(expected_dist)
     result, matches = flexible_character_accuracy(gt, ocr)
-    agg = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
-                 matches, Counter())
+    agg = reduce(
+        lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
+    )
     dist = Distance(**agg)
     assert dist == expected_dist
     assert result == pytest.approx(expected_score, abs=0.0005)


 @pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES)
-def test_flexible_character_accuracy_extended(gt, ocr, first_line_score,
-                                              all_line_score):
+def test_flexible_character_accuracy_extended(
+    gt, ocr, first_line_score, all_line_score
+):
     """Tests from figure 4 in the paper."""
     gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
     result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence)
@@ -170,10 +197,13 @@ def test_match_longest_gt_lines(gt, ocr, first_line_score, all_line_score):
     assert score == pytest.approx(first_line_score)


-@pytest.mark.parametrize(CASE_ARGS, [
-    *SIMPLE_CASES,
-    ("accc", "a\nbb\nccc", 1.0, 1.0),
-])
+@pytest.mark.parametrize(
+    CASE_ARGS,
+    [
+        *SIMPLE_CASES,
+        ("accc", "a\nbb\nccc", 1.0, 1.0),
+    ],
+)
 def test_match_gt_line(gt, ocr, first_line_score, all_line_score):
     coef = Coefficients()
     gt_lines = initialize_lines(gt)
@@ -185,15 +215,22 @@ def test_match_gt_line(gt, ocr, first_line_score, all_line_score):
     assert score == pytest.approx(first_line_score)


-@pytest.mark.parametrize("original,match,expected_lines", [
-    (Part(), Part(), []),
-    (Part(text="abc"), Part(), [Part(text="abc")]),
-    (Part(text="abc"), Part("d"), [Part(text="bc", start=1)]),
-    (Part(text="abc"), Part("a", start=100), [Part(text="abc")]),
-    (Part(text="abc"), Part("a"), [Part(text="bc", start=1)]),
-    (Part(text="abc"), Part("b", start=1), [Part(text="a"), Part(text="c", start=2)]),
-    (Part(text="abc"), Part("c", start=2), [Part(text="ab")]),
-])
+@pytest.mark.parametrize(
+    "original,match,expected_lines",
+    [
+        (Part(), Part(), []),
+        (Part(text="abc"), Part(), [Part(text="abc")]),
+        (Part(text="abc"), Part("d"), [Part(text="bc", start=1)]),
+        (Part(text="abc"), Part("a", start=100), [Part(text="abc")]),
+        (Part(text="abc"), Part("a"), [Part(text="bc", start=1)]),
+        (
+            Part(text="abc"),
+            Part("b", start=1),
+            [Part(text="a"), Part(text="c", start=2)],
+        ),
+        (Part(text="abc"), Part("c", start=2), [Part(text="ab")]),
+    ],
+)
 def test_remove_or_split(original, match, expected_lines):
     lines = [original]
     splitted = remove_or_split(original, match, lines)
@@ -201,7 +238,9 @@ def test_remove_or_split(original, match, expected_lines):
     assert lines == expected_lines


-@pytest.mark.parametrize(EDIT_ARGS, [
+@pytest.mark.parametrize(
+    EDIT_ARGS,
+    [
         *SIMPLE_EDITS,
         (Part(text="a"), Part(text="b"), Distance(delete=1)),
         (Part(text="aaa"), Part(text="bbb"), Distance(delete=3)),
@@ -210,9 +249,18 @@ def test_remove_or_split(original, match, expected_lines):
         (Part(text=""), Part(text=""), None),
         (Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)),
         (Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)),
-    (Part(text="aaabbbaaaddd"), Part(text="aaabcbaaa"), Distance(match=8, replace=1)),
-    (Part(text="aaabbbccc"), Part(text="aaabbbdddccc"), Distance(match=9, insert=3)),
-])
+        (
+            Part(text="aaabbbaaaddd"),
+            Part(text="aaabcbaaa"),
+            Distance(match=8, replace=1),
+        ),
+        (
+            Part(text="aaabbbccc"),
+            Part(text="aaabbbdddccc"),
+            Distance(match=9, insert=3),
+        ),
+    ],
+)
 def test_match_lines(gt, ocr, expected_dist):
     match = match_lines(gt, ocr)
     if not expected_dist:
@@ -223,14 +271,17 @@ def test_match_lines(gt, ocr, expected_dist):
     assert match.dist == expected_dist


-@pytest.mark.parametrize(EDIT_ARGS, [
+@pytest.mark.parametrize(
+    EDIT_ARGS,
+    [
         *SIMPLE_EDITS,
         (Part(text="").substring(), Part(text=""), Distance()),
         (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)),
         (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)),
         (Part(text="a"), Part(text="b"), Distance(replace=1)),
         (Part(text="aaa"), Part(text="bbb"), Distance(replace=3)),
-])
+    ],
+)
 def test_distance(gt, ocr, expected_dist):
     match = distance(gt, ocr)
     assert match.gt == gt
@@ -238,20 +289,37 @@ def test_distance(gt, ocr, expected_dist):
     assert match.dist == expected_dist


-@pytest.mark.parametrize("matches,expected_dist", [
-    ([], 1),
-    ([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1),
-    ([Match(gt=Part(text="abee"), ocr=Part("ac"),
-            dist=Distance(match=1, replace=1, delete=2), ops=[]),
-      Match(gt=Part(text="cd"), ocr=Part("ceff"),
-            dist=Distance(match=1, replace=1, insert=2), ops=[])],
-     1 - 6 / 6),
-])
+@pytest.mark.parametrize(
+    "matches,expected_dist",
+    [
+        ([], 1),
+        ([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1),
+        (
+            [
+                Match(
+                    gt=Part(text="abee"),
+                    ocr=Part("ac"),
+                    dist=Distance(match=1, replace=1, delete=2),
+                    ops=[],
+                ),
+                Match(
+                    gt=Part(text="cd"),
+                    ocr=Part("ceff"),
+                    dist=Distance(match=1, replace=1, insert=2),
+                    ops=[],
+                ),
+            ],
+            1 - 6 / 6,
+        ),
+    ],
+)
 def test_character_accuracy_matches(matches, expected_dist):
     assert character_accuracy_for_matches(matches) == pytest.approx(expected_dist)


-@pytest.mark.parametrize("dist,expected_dist", [
-    (Distance(), 1),
-    (Distance(match=1), 1),
-    (Distance(replace=1), 0),
+@pytest.mark.parametrize(
+    "dist,expected_dist",
+    [
+        (Distance(), 1),
+        (Distance(match=1), 1),
+        (Distance(replace=1), 0),
@@ -259,25 +327,39 @@ def test_character_accuracy_matches(matches, expected_dist):
         (Distance(match=1, insert=2), 1 - 2 / 1),
         (Distance(match=2, insert=1), 0.5),
         (Distance(match=1, delete=1), 0.5),
-])
+    ],
+)
 def test_character_accuracy_dist(dist, expected_dist):
     assert character_accuracy(dist) == pytest.approx(expected_dist)


-@pytest.mark.parametrize("line,subline,expected_rest", [
-    (Part(), Part(), []),
-    (Part("aaa bbb"), Part("aaa bbb"), []),
-    (Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]),
-    (Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]),
-    (Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]),
-    (Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]),
-    (Part("aaa bbb ccc"), Part("bbb", start=4), [Part("aaa "), Part(" ccc", start=7)]),
-    (Part("aaa bbb ccc", start=3), Part("bbb", start=7),
-     [Part("aaa ", start=3), Part(" ccc", start=10)]),
-    (Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]),
-    (Part("aaa bbb", start=3), Part(" ", start=6),
-     [Part("aaa", start=3), Part("bbb", start=7)]),
-])
+@pytest.mark.parametrize(
+    "line,subline,expected_rest",
+    [
+        (Part(), Part(), []),
+        (Part("aaa bbb"), Part("aaa bbb"), []),
+        (Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]),
+        (Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]),
+        (Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]),
+        (Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]),
+        (
+            Part("aaa bbb ccc"),
+            Part("bbb", start=4),
+            [Part("aaa "), Part(" ccc", start=7)],
+        ),
+        (
+            Part("aaa bbb ccc", start=3),
+            Part("bbb", start=7),
+            [Part("aaa ", start=3), Part(" ccc", start=10)],
+        ),
+        (Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]),
+        (
+            Part("aaa bbb", start=3),
+            Part(" ", start=6),
+            [Part("aaa", start=3), Part("bbb", start=7)],
+        ),
+    ],
+)
 def test_split_line(line, subline, expected_rest):
     rest = line.split(subline)
     assert len(rest) == len(expected_rest)
@@ -300,12 +382,15 @@ def test_combine_lines():
     assert False


-@pytest.mark.parametrize("line,start,end,expected", [
+@pytest.mark.parametrize(
+    "line,start,end,expected",
+    [
         (Part(text=""), 0, None, Part(text="")),
         (Part(text="a"), 0, None, Part(text="a")),
         (Part(text="ab"), 0, 1, Part(text="a")),
        (Part(text="abc"), 0, -1, Part(text="ab")),
         (Part(text="ab"), 1, None, Part(text="b", start=1)),
-])
+    ],
+)
 def test_line_substring(line, start, end, expected):
     assert line.substring(rel_start=start, rel_end=end) == expected
