mirror of
https://github.com/qurator-spk/dinglehopper.git
synced 2025-06-08 11:20:26 +02:00
Reformat using black
This commit is contained in:
parent
5277593bdb
commit
2a215a1062
2 changed files with 273 additions and 171 deletions
|
@ -29,20 +29,16 @@ def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"]
|
|||
:return: Score between 0 and 1 and match objects.
|
||||
"""
|
||||
|
||||
best_score = -float('inf')
|
||||
best_score = -float("inf")
|
||||
best_matches = []
|
||||
# TODO: this should be configurable
|
||||
combinations = product(range(15, 31, 5),
|
||||
range(0, 24, 3),
|
||||
range(0, 4, 1),
|
||||
range(0, 6, 1))
|
||||
combinations = product(
|
||||
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
|
||||
)
|
||||
# TODO: place to parallelize the algorithm
|
||||
for (edit_dist, length_diff, offset, length) in combinations:
|
||||
coef = Coefficients(
|
||||
edit_dist=edit_dist,
|
||||
length_diff=length_diff,
|
||||
offset=offset,
|
||||
length=length
|
||||
edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length
|
||||
)
|
||||
# Steps 1 - 6 of the flexible character accuracy algorithm.
|
||||
matches = match_with_coefficients(gt, ocr, coef)
|
||||
|
@ -79,17 +75,21 @@ def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Ma
|
|||
|
||||
# Step 6 of the flexible character accuracy algorithm.
|
||||
# remaining lines are considered as deletes and inserts
|
||||
deletes = [distance(line, Part(text="", line=line.line, start=line.start))
|
||||
for line in gt_lines]
|
||||
inserts = [distance(Part(text="", line=line.line, start=line.start), line)
|
||||
for line in ocr_lines]
|
||||
deletes = [
|
||||
distance(line, Part(text="", line=line.line, start=line.start))
|
||||
for line in gt_lines
|
||||
]
|
||||
inserts = [
|
||||
distance(Part(text="", line=line.line, start=line.start), line)
|
||||
for line in ocr_lines
|
||||
]
|
||||
|
||||
return [*matches, *deletes, *inserts]
|
||||
|
||||
|
||||
def match_longest_gt_lines(gt_lines: List["Part"],
|
||||
ocr_lines: List["Part"],
|
||||
coef: "Coefficients") -> Optional["Match"]:
|
||||
def match_longest_gt_lines(
|
||||
gt_lines: List["Part"], ocr_lines: List["Part"], coef: "Coefficients"
|
||||
) -> Optional["Match"]:
|
||||
"""Find the best match for the longest line(s) in ground truth.
|
||||
|
||||
The longest lines in ground truth are matched against lines in ocr to find the
|
||||
|
@ -99,7 +99,7 @@ def match_longest_gt_lines(gt_lines: List["Part"],
|
|||
|
||||
:return: Possible match object.
|
||||
"""
|
||||
best_score, best_match, best_gt, best_ocr = -float('inf'), None, None, None
|
||||
best_score, best_match, best_gt, best_ocr = -float("inf"), None, None, None
|
||||
if not ocr_lines:
|
||||
return best_match
|
||||
|
||||
|
@ -126,10 +126,9 @@ def match_longest_gt_lines(gt_lines: List["Part"],
|
|||
return best_match
|
||||
|
||||
|
||||
def match_gt_line(gt_line: "Part",
|
||||
ocr_lines: List["Part"],
|
||||
coef: "Coefficients") -> Tuple[Optional["Match"],
|
||||
Optional["Part"]]:
|
||||
def match_gt_line(
|
||||
gt_line: "Part", ocr_lines: List["Part"], coef: "Coefficients"
|
||||
) -> Tuple[Optional["Match"], Optional["Part"]]:
|
||||
"""Match the given ground truth line against all the lines in ocr.
|
||||
|
||||
Reference: contains steps 3 of the flexible character accuracy algorithm.
|
||||
|
@ -138,19 +137,18 @@ def match_gt_line(gt_line: "Part",
|
|||
|
||||
:return: Match object and the matched ocr line.
|
||||
"""
|
||||
min_penalty = float('inf')
|
||||
min_penalty = float("inf")
|
||||
best_match, best_ocr = None, None
|
||||
for ocr_line in [*ocr_lines]:
|
||||
match = match_lines(gt_line, ocr_line)
|
||||
if match:
|
||||
penalty = calculate_penalty(gt_line, ocr_line, match, coef)
|
||||
if penalty < min_penalty:
|
||||
min_penalty, best_match, best_ocr = penalty, match, ocr_line
|
||||
return best_match, best_ocr
|
||||
|
||||
|
||||
def remove_or_split(original: "Part",
|
||||
match: "Part",
|
||||
lines: List["Part"]) -> bool:
|
||||
def remove_or_split(original: "Part", match: "Part", lines: List["Part"]) -> bool:
|
||||
"""Removes the matched line or splits it into parts.
|
||||
|
||||
Reference: contains step 4 of the flexible character accuracy algorithm.
|
||||
|
@ -187,17 +185,24 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
|
|||
if min_length == 0:
|
||||
return best_match
|
||||
length_diff = gt_line.length - ocr_line.length
|
||||
min_edit_dist = float('inf')
|
||||
min_edit_dist = float("inf")
|
||||
|
||||
gt_parts = [(i, gt_line.substring(rel_start=i, rel_end=i + min_length))
|
||||
for i in range(0, max(1, length_diff + 1))]
|
||||
ocr_parts = [(j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
|
||||
for j in range(0, max(1, -1 * length_diff + 1))]
|
||||
gt_parts = [
|
||||
(i, gt_line.substring(rel_start=i, rel_end=i + min_length))
|
||||
for i in range(0, max(1, length_diff + 1))
|
||||
]
|
||||
ocr_parts = [
|
||||
(j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
|
||||
for j in range(0, max(1, -1 * length_diff + 1))
|
||||
]
|
||||
|
||||
# add full line and empty line match
|
||||
gt_parts = [*gt_parts, (0, gt_line), (0, gt_line)]
|
||||
ocr_parts = [*ocr_parts, (0, ocr_line),
|
||||
(0, Part(text="", line=gt_line.line, start=gt_line.start))]
|
||||
ocr_parts = [
|
||||
*ocr_parts,
|
||||
(0, ocr_line),
|
||||
(0, Part(text="", line=gt_line.line, start=gt_line.start)),
|
||||
]
|
||||
|
||||
for i, gt_part in gt_parts:
|
||||
for j, ocr_part in ocr_parts:
|
||||
|
@ -211,8 +216,10 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
|
|||
part_length = best_match.gt.length
|
||||
additional_length = best_match.dist.delete + best_match.dist.replace
|
||||
for k in range(part_length + 1, part_length + additional_length + 1):
|
||||
match = distance(gt_line.substring(rel_start=best_i, rel_end=best_i + k),
|
||||
ocr_line.substring(rel_start=best_j, rel_end=best_j + k))
|
||||
match = distance(
|
||||
gt_line.substring(rel_start=best_i, rel_end=best_i + k),
|
||||
ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
|
||||
)
|
||||
edit_dist = score_edit_distance(match)
|
||||
if edit_dist < min_edit_dist:
|
||||
min_edit_dist = edit_dist
|
||||
|
@ -247,8 +254,9 @@ def score_edit_distance(match: "Match") -> int:
|
|||
return match.dist.delete + match.dist.insert + 2 * match.dist.replace
|
||||
|
||||
|
||||
def calculate_penalty(gt: "Part", ocr: "Part", match: "Match",
|
||||
coef: "Coefficients") -> float:
|
||||
def calculate_penalty(
|
||||
gt: "Part", ocr: "Part", match: "Match", coef: "Coefficients"
|
||||
) -> float:
|
||||
"""Calculate the penalty for a given match.
|
||||
|
||||
For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003.
|
||||
|
@ -262,10 +270,12 @@ def calculate_penalty(gt: "Part", ocr: "Part", match: "Match",
|
|||
if length_diff > 1:
|
||||
substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start)
|
||||
offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
|
||||
return (min_edit_dist * coef.edit_dist
|
||||
return (
|
||||
min_edit_dist * coef.edit_dist
|
||||
+ length_diff * coef.length_diff
|
||||
+ offset * coef.offset
|
||||
- substring_length * coef.length)
|
||||
- substring_length * coef.length
|
||||
)
|
||||
|
||||
|
||||
def character_accuracy_for_matches(matches: List["Match"]) -> float:
|
||||
|
@ -274,8 +284,9 @@ def character_accuracy_for_matches(matches: List["Match"]) -> float:
|
|||
See other `character_accuracy` for details.
|
||||
|
||||
"""
|
||||
agg: Counter = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
|
||||
matches, Counter())
|
||||
agg: Counter = reduce(
|
||||
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
|
||||
)
|
||||
|
||||
score = character_accuracy(Distance(**agg))
|
||||
return score
|
||||
|
@ -299,9 +310,9 @@ def character_accuracy(edits: "Distance") -> float:
|
|||
chars = edits.match + edits.replace + edits.delete
|
||||
if not chars and not errors:
|
||||
# comparison of empty strings is considered a full match
|
||||
score = 1
|
||||
score = 1.0
|
||||
else:
|
||||
score = 1 - errors / chars
|
||||
score = 1.0 - errors / chars
|
||||
return score
|
||||
|
||||
|
||||
|
@ -315,9 +326,11 @@ def initialize_lines(text: str) -> List["Part"]:
|
|||
:param text: Text to split into lines.
|
||||
:return: List of sorted line objects.
|
||||
"""
|
||||
lines = [Part(text=line, line=i, start=0)
|
||||
lines = [
|
||||
Part(text=line, line=i, start=0)
|
||||
for i, line in enumerate(text.splitlines())
|
||||
if len(line) > 0]
|
||||
if len(line) > 0
|
||||
]
|
||||
lines.sort(key=lambda x: x.length, reverse=True)
|
||||
return lines
|
||||
|
||||
|
@ -348,6 +361,7 @@ class Part(NamedTuple):
|
|||
|
||||
This data object is maintained to be able to reproduce the original text.
|
||||
"""
|
||||
|
||||
text: str = ""
|
||||
line: int = 0
|
||||
start: int = 0
|
||||
|
@ -392,6 +406,7 @@ class Part(NamedTuple):
|
|||
|
||||
class Distance(NamedTuple):
|
||||
"""Represent distance between two sequences."""
|
||||
|
||||
match: int = 0
|
||||
replace: int = 0
|
||||
delete: int = 0
|
||||
|
@ -400,6 +415,7 @@ class Distance(NamedTuple):
|
|||
|
||||
class Match(NamedTuple):
|
||||
"""Represent a calculated match between ground truth and the ocr result."""
|
||||
|
||||
gt: "Part"
|
||||
ocr: "Part"
|
||||
dist: "Distance"
|
||||
|
@ -411,6 +427,7 @@ class Coefficients(NamedTuple):
|
|||
|
||||
See Section 3 in doi:10.1016/j.patrec.2020.02.003
|
||||
"""
|
||||
|
||||
edit_dist: int = 25
|
||||
length_diff: int = 20
|
||||
offset: int = 1
|
||||
|
|
|
@ -35,37 +35,31 @@ COMPLEX_CASES = [
|
|||
EXTENDED_CASES = [
|
||||
# See figure 4 in 10.1016/j.patrec.2020.02.003
|
||||
# A: No errors
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
|
||||
(0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
|
||||
1, 1),
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), 1, 1),
|
||||
# B: Different ordering of text blocks
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
|
||||
(5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4),
|
||||
1, 1),
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), 1, 1),
|
||||
# C: Merge across columns
|
||||
((0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9),
|
||||
(
|
||||
(0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9),
|
||||
(0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9),
|
||||
1, 0.964),
|
||||
1,
|
||||
0.964,
|
||||
),
|
||||
# D: Over-segmentation
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
|
||||
(
|
||||
(0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
|
||||
(0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9),
|
||||
1, 0.966),
|
||||
1,
|
||||
0.966,
|
||||
),
|
||||
# E: Part missing
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
|
||||
(0, 1, 2, 3, 4),
|
||||
1, 0.50),
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4), 1, 0.50),
|
||||
# E.2: Part missing
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
|
||||
(5, 6, 7, 8, 9),
|
||||
1, 0.50),
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9), 1, 0.50),
|
||||
# F: All missing
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
|
||||
(),
|
||||
1, 0),
|
||||
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (), 1, 0),
|
||||
# G: Added parts
|
||||
((0, 1, 2, 3, 4),
|
||||
(0, 1, 2, 3, 4, 11, 5, 6),
|
||||
1, 0.621),
|
||||
((0, 1, 2, 3, 4), (0, 1, 2, 3, 4, 11, 5, 6), 1, 0.621),
|
||||
]
|
||||
|
||||
EDIT_ARGS = "gt,ocr,expected_dist"
|
||||
|
@ -73,8 +67,11 @@ EDIT_ARGS = "gt,ocr,expected_dist"
|
|||
SIMPLE_EDITS = [
|
||||
(Part(text="a"), Part(text="a"), Distance(match=1)),
|
||||
(Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
|
||||
(Part(text="abcd"), Part(text="beed"),
|
||||
Distance(match=2, replace=1, insert=1, delete=1)),
|
||||
(
|
||||
Part(text="abcd"),
|
||||
Part(text="beed"),
|
||||
Distance(match=2, replace=1, insert=1, delete=1),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
@ -83,9 +80,20 @@ def extended_case_to_text(gt, ocr):
|
|||
|
||||
See figure 4 in 10.1016/j.patrec.2020.02.003
|
||||
"""
|
||||
sentence = ("Eight", "happy", "frogs", "scuba", "dived",
|
||||
"Jenny", "chick", "flaps", "white", "wings",
|
||||
"", "\n")
|
||||
sentence = (
|
||||
"Eight",
|
||||
"happy",
|
||||
"frogs",
|
||||
"scuba",
|
||||
"dived",
|
||||
"Jenny",
|
||||
"chick",
|
||||
"flaps",
|
||||
"white",
|
||||
"wings",
|
||||
"",
|
||||
"\n",
|
||||
)
|
||||
|
||||
gt_sentence = " ".join(sentence[i] for i in gt).replace(" \n ", "\n")
|
||||
ocr_sentence = " ".join(sentence[i] for i in ocr).replace(" \n ", "\n")
|
||||
|
@ -98,22 +106,35 @@ def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_
|
|||
assert score == pytest.approx(all_line_score)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config,ocr", [
|
||||
("Config I",
|
||||
"1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein"
|
||||
@pytest.mark.parametrize(
|
||||
"config,ocr",
|
||||
[
|
||||
(
|
||||
"Config I",
|
||||
"1 hav\nnospecial\ntalents.\n"
|
||||
'I am one\npassionate\ncuriousity."\n'
|
||||
"Alberto\nEmstein",
|
||||
),
|
||||
("Config II",
|
||||
"1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\""
|
||||
(
|
||||
"Config II",
|
||||
'1 hav\nnospecial\ntalents. Alberto\n'
|
||||
'I am one Emstein\npassionate\ncuriousity."',
|
||||
),
|
||||
("Config III",
|
||||
"Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\""
|
||||
(
|
||||
"Config III",
|
||||
'Alberto\nEmstein\n'
|
||||
'1 hav\nnospecial\ntalents.\n'
|
||||
'I am one\npassionate\ncuriousity."',
|
||||
),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_flexible_character_accuracy(config, ocr):
|
||||
"""Tests from figure 3 in the paper."""
|
||||
gt = "\"I have\nno special\ntalent." \
|
||||
"\nI am only\npassionately\ncurious.\"" \
|
||||
"\nAlbert\nEinstein"
|
||||
gt = (
|
||||
'"I have\nno special\ntalent.\n'
|
||||
'I am only\npassionately\ncurious."\n'
|
||||
"Albert\nEinstein"
|
||||
)
|
||||
replacements, inserts, deletes = 3, 5, 7
|
||||
chars = len(gt) - gt.count("\n")
|
||||
assert chars == 68
|
||||
|
@ -127,21 +148,27 @@ def test_flexible_character_accuracy(config, ocr):
|
|||
replacements += 1
|
||||
deletes -= 1
|
||||
|
||||
expected_dist = Distance(match=chars - deletes - replacements, replace=replacements,
|
||||
insert=inserts, delete=deletes)
|
||||
expected_dist = Distance(
|
||||
match=chars - deletes - replacements,
|
||||
replace=replacements,
|
||||
insert=inserts,
|
||||
delete=deletes,
|
||||
)
|
||||
expected_score = character_accuracy(expected_dist)
|
||||
|
||||
result, matches = flexible_character_accuracy(gt, ocr)
|
||||
agg = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
|
||||
matches, Counter())
|
||||
agg = reduce(
|
||||
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
|
||||
)
|
||||
dist = Distance(**agg)
|
||||
assert dist == expected_dist
|
||||
assert result == pytest.approx(expected_score, abs=0.0005)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES)
|
||||
def test_flexible_character_accuracy_extended(gt, ocr, first_line_score,
|
||||
all_line_score):
|
||||
def test_flexible_character_accuracy_extended(
|
||||
gt, ocr, first_line_score, all_line_score
|
||||
):
|
||||
"""Tests from figure 4 in the paper."""
|
||||
gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
|
||||
result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence)
|
||||
|
@ -170,10 +197,13 @@ def test_match_longest_gt_lines(gt, ocr, first_line_score, all_line_score):
|
|||
assert score == pytest.approx(first_line_score)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(CASE_ARGS, [
|
||||
@pytest.mark.parametrize(
|
||||
CASE_ARGS,
|
||||
[
|
||||
*SIMPLE_CASES,
|
||||
("accc", "a\nbb\nccc", 1.0, 1.0),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_match_gt_line(gt, ocr, first_line_score, all_line_score):
|
||||
coef = Coefficients()
|
||||
gt_lines = initialize_lines(gt)
|
||||
|
@ -185,15 +215,22 @@ def test_match_gt_line(gt, ocr, first_line_score, all_line_score):
|
|||
assert score == pytest.approx(first_line_score)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("original,match,expected_lines", [
|
||||
@pytest.mark.parametrize(
|
||||
"original,match,expected_lines",
|
||||
[
|
||||
(Part(), Part(), []),
|
||||
(Part(text="abc"), Part(), [Part(text="abc")]),
|
||||
(Part(text="abc"), Part("d"), [Part(text="bc", start=1)]),
|
||||
(Part(text="abc"), Part("a", start=100), [Part(text="abc")]),
|
||||
(Part(text="abc"), Part("a"), [Part(text="bc", start=1)]),
|
||||
(Part(text="abc"), Part("b", start=1), [Part(text="a"), Part(text="c", start=2)]),
|
||||
(
|
||||
Part(text="abc"),
|
||||
Part("b", start=1),
|
||||
[Part(text="a"), Part(text="c", start=2)],
|
||||
),
|
||||
(Part(text="abc"), Part("c", start=2), [Part(text="ab")]),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_remove_or_split(original, match, expected_lines):
|
||||
lines = [original]
|
||||
splitted = remove_or_split(original, match, lines)
|
||||
|
@ -201,7 +238,9 @@ def test_remove_or_split(original, match, expected_lines):
|
|||
assert lines == expected_lines
|
||||
|
||||
|
||||
@pytest.mark.parametrize(EDIT_ARGS, [
|
||||
@pytest.mark.parametrize(
|
||||
EDIT_ARGS,
|
||||
[
|
||||
*SIMPLE_EDITS,
|
||||
(Part(text="a"), Part(text="b"), Distance(delete=1)),
|
||||
(Part(text="aaa"), Part(text="bbb"), Distance(delete=3)),
|
||||
|
@ -210,9 +249,18 @@ def test_remove_or_split(original, match, expected_lines):
|
|||
(Part(text=""), Part(text=""), None),
|
||||
(Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)),
|
||||
(Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)),
|
||||
(Part(text="aaabbbaaaddd"), Part(text="aaabcbaaa"), Distance(match=8, replace=1)),
|
||||
(Part(text="aaabbbccc"), Part(text="aaabbbdddccc"), Distance(match=9, insert=3)),
|
||||
])
|
||||
(
|
||||
Part(text="aaabbbaaaddd"),
|
||||
Part(text="aaabcbaaa"),
|
||||
Distance(match=8, replace=1),
|
||||
),
|
||||
(
|
||||
Part(text="aaabbbccc"),
|
||||
Part(text="aaabbbdddccc"),
|
||||
Distance(match=9, insert=3),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_match_lines(gt, ocr, expected_dist):
|
||||
match = match_lines(gt, ocr)
|
||||
if not expected_dist:
|
||||
|
@ -223,14 +271,17 @@ def test_match_lines(gt, ocr, expected_dist):
|
|||
assert match.dist == expected_dist
|
||||
|
||||
|
||||
@pytest.mark.parametrize(EDIT_ARGS, [
|
||||
@pytest.mark.parametrize(
|
||||
EDIT_ARGS,
|
||||
[
|
||||
*SIMPLE_EDITS,
|
||||
(Part(text="").substring(), Part(text=""), Distance()),
|
||||
(Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)),
|
||||
(Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)),
|
||||
(Part(text="a"), Part(text="b"), Distance(replace=1)),
|
||||
(Part(text="aaa"), Part(text="bbb"), Distance(replace=3)),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_distance(gt, ocr, expected_dist):
|
||||
match = distance(gt, ocr)
|
||||
assert match.gt == gt
|
||||
|
@ -238,20 +289,37 @@ def test_distance(gt, ocr, expected_dist):
|
|||
assert match.dist == expected_dist
|
||||
|
||||
|
||||
@pytest.mark.parametrize("matches,expected_dist", [
|
||||
@pytest.mark.parametrize(
|
||||
"matches,expected_dist",
|
||||
[
|
||||
([], 1),
|
||||
([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1),
|
||||
([Match(gt=Part(text="abee"), ocr=Part("ac"),
|
||||
dist=Distance(match=1, replace=1, delete=2), ops=[]),
|
||||
Match(gt=Part(text="cd"), ocr=Part("ceff"),
|
||||
dist=Distance(match=1, replace=1, insert=2), ops=[])],
|
||||
1 - 6 / 6),
|
||||
])
|
||||
(
|
||||
[
|
||||
Match(
|
||||
gt=Part(text="abee"),
|
||||
ocr=Part("ac"),
|
||||
dist=Distance(match=1, replace=1, delete=2),
|
||||
ops=[],
|
||||
),
|
||||
Match(
|
||||
gt=Part(text="cd"),
|
||||
ocr=Part("ceff"),
|
||||
dist=Distance(match=1, replace=1, insert=2),
|
||||
ops=[],
|
||||
),
|
||||
],
|
||||
1 - 6 / 6,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_character_accuracy_matches(matches, expected_dist):
|
||||
assert character_accuracy_for_matches(matches) == pytest.approx(expected_dist)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dist,expected_dist", [
|
||||
@pytest.mark.parametrize(
|
||||
"dist,expected_dist",
|
||||
[
|
||||
(Distance(), 1),
|
||||
(Distance(match=1), 1),
|
||||
(Distance(replace=1), 0),
|
||||
|
@ -259,25 +327,39 @@ def test_character_accuracy_matches(matches, expected_dist):
|
|||
(Distance(match=1, insert=2), 1 - 2 / 1),
|
||||
(Distance(match=2, insert=1), 0.5),
|
||||
(Distance(match=1, delete=1), 0.5),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_character_accuracy_dist(dist, expected_dist):
|
||||
assert character_accuracy(dist) == pytest.approx(expected_dist)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("line,subline,expected_rest", [
|
||||
@pytest.mark.parametrize(
|
||||
"line,subline,expected_rest",
|
||||
[
|
||||
(Part(), Part(), []),
|
||||
(Part("aaa bbb"), Part("aaa bbb"), []),
|
||||
(Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]),
|
||||
(Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]),
|
||||
(Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]),
|
||||
(Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]),
|
||||
(Part("aaa bbb ccc"), Part("bbb", start=4), [Part("aaa "), Part(" ccc", start=7)]),
|
||||
(Part("aaa bbb ccc", start=3), Part("bbb", start=7),
|
||||
[Part("aaa ", start=3), Part(" ccc", start=10)]),
|
||||
(
|
||||
Part("aaa bbb ccc"),
|
||||
Part("bbb", start=4),
|
||||
[Part("aaa "), Part(" ccc", start=7)],
|
||||
),
|
||||
(
|
||||
Part("aaa bbb ccc", start=3),
|
||||
Part("bbb", start=7),
|
||||
[Part("aaa ", start=3), Part(" ccc", start=10)],
|
||||
),
|
||||
(Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]),
|
||||
(Part("aaa bbb", start=3), Part(" ", start=6),
|
||||
[Part("aaa", start=3), Part("bbb", start=7)]),
|
||||
])
|
||||
(
|
||||
Part("aaa bbb", start=3),
|
||||
Part(" ", start=6),
|
||||
[Part("aaa", start=3), Part("bbb", start=7)],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_split_line(line, subline, expected_rest):
|
||||
rest = line.split(subline)
|
||||
assert len(rest) == len(expected_rest)
|
||||
|
@ -300,12 +382,15 @@ def test_combine_lines():
|
|||
assert False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("line,start,end,expected", [
|
||||
@pytest.mark.parametrize(
|
||||
"line,start,end,expected",
|
||||
[
|
||||
(Part(text=""), 0, None, Part(text="")),
|
||||
(Part(text="a"), 0, None, Part(text="a")),
|
||||
(Part(text="ab"), 0, 1, Part(text="a")),
|
||||
(Part(text="abc"), 0, -1, Part(text="ab")),
|
||||
(Part(text="ab"), 1, None, Part(text="b", start=1)),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_line_substring(line, start, end, expected):
|
||||
assert line.substring(rel_start=start, rel_end=end) == expected
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue