Mirror of https://github.com/qurator-spk/dinglehopper.git, synced 2025-06-08 11:20:26 +02:00
Reformat using black

commit 2a215a1062, parent 5277593bdb
2 changed files with 273 additions and 171 deletions
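The diff below is what running black over the two changed modules produces: single-quoted strings become double-quoted, hanging indents are collapsed or exploded to fit the line length, and exploded calls get trailing commas. As a quick illustration (not part of the commit), the quote normalisation seen throughout can be reproduced with black's Python API; a minimal sketch, assuming a reasonably recent black release is installed:

import black

# One of the one-line changes from the diff below: black normalises
# string quotes to double quotes by default.
SRC = "best_score = -float('inf')\n"

formatted = black.format_str(SRC, mode=black.FileMode())
print(formatted, end="")  # best_score = -float("inf")
assert formatted == 'best_score = -float("inf")\n'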
@@ -29,20 +29,16 @@ def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"]
     :return: Score between 0 and 1 and match objects.
     """

-    best_score = -float('inf')
+    best_score = -float("inf")
     best_matches = []
     # TODO: this should be configurable
-    combinations = product(range(15, 31, 5),
-                           range(0, 24, 3),
-                           range(0, 4, 1),
-                           range(0, 6, 1))
+    combinations = product(
+        range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
+    )
     # TODO: place to parallelize the algorithm
     for (edit_dist, length_diff, offset, length) in combinations:
         coef = Coefficients(
-            edit_dist=edit_dist,
-            length_diff=length_diff,
-            offset=offset,
-            length=length
+            edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length
         )
         # Steps 1 - 6 of the flexible character accuracy algorithm.
         matches = match_with_coefficients(gt, ocr, coef)
@@ -79,17 +75,21 @@ def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Ma

     # Step 6 of the flexible character accuracy algorithm.
     # remaining lines are considered as deletes and inserts
-    deletes = [distance(line, Part(text="", line=line.line, start=line.start))
-               for line in gt_lines]
-    inserts = [distance(Part(text="", line=line.line, start=line.start), line)
-               for line in ocr_lines]
+    deletes = [
+        distance(line, Part(text="", line=line.line, start=line.start))
+        for line in gt_lines
+    ]
+    inserts = [
+        distance(Part(text="", line=line.line, start=line.start), line)
+        for line in ocr_lines
+    ]

     return [*matches, *deletes, *inserts]


-def match_longest_gt_lines(gt_lines: List["Part"],
-                           ocr_lines: List["Part"],
-                           coef: "Coefficients") -> Optional["Match"]:
+def match_longest_gt_lines(
+    gt_lines: List["Part"], ocr_lines: List["Part"], coef: "Coefficients"
+) -> Optional["Match"]:
     """Find the best match for the longest line(s) in ground truth.

     The longest lines in ground truth are matched against lines in ocr to find the
@@ -99,7 +99,7 @@ def match_longest_gt_lines(gt_lines: List["Part"],

     :return: Possible match object.
     """
-    best_score, best_match, best_gt, best_ocr = -float('inf'), None, None, None
+    best_score, best_match, best_gt, best_ocr = -float("inf"), None, None, None
     if not ocr_lines:
         return best_match

@@ -126,10 +126,9 @@ def match_longest_gt_lines(gt_lines: List["Part"],
     return best_match


-def match_gt_line(gt_line: "Part",
-                  ocr_lines: List["Part"],
-                  coef: "Coefficients") -> Tuple[Optional["Match"],
-                                                 Optional["Part"]]:
+def match_gt_line(
+    gt_line: "Part", ocr_lines: List["Part"], coef: "Coefficients"
+) -> Tuple[Optional["Match"], Optional["Part"]]:
     """Match the given ground truth line against all the lines in ocr.

     Reference: contains steps 3 of the flexible character accuracy algorithm.
@@ -138,19 +137,18 @@ def match_gt_line(gt_line: "Part",

     :return: Match object and the matched ocr line.
     """
-    min_penalty = float('inf')
+    min_penalty = float("inf")
     best_match, best_ocr = None, None
     for ocr_line in [*ocr_lines]:
         match = match_lines(gt_line, ocr_line)
+        if match:
             penalty = calculate_penalty(gt_line, ocr_line, match, coef)
             if penalty < min_penalty:
                 min_penalty, best_match, best_ocr = penalty, match, ocr_line
     return best_match, best_ocr


-def remove_or_split(original: "Part",
-                    match: "Part",
-                    lines: List["Part"]) -> bool:
+def remove_or_split(original: "Part", match: "Part", lines: List["Part"]) -> bool:
     """Removes the matched line or splits it into parts.

     Reference: contains step 4 of the flexible character accuracy algorithm.
@@ -187,17 +185,24 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
     if min_length == 0:
         return best_match
     length_diff = gt_line.length - ocr_line.length
-    min_edit_dist = float('inf')
+    min_edit_dist = float("inf")

-    gt_parts = [(i, gt_line.substring(rel_start=i, rel_end=i + min_length))
-                for i in range(0, max(1, length_diff + 1))]
-    ocr_parts = [(j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
-                 for j in range(0, max(1, -1 * length_diff + 1))]
+    gt_parts = [
+        (i, gt_line.substring(rel_start=i, rel_end=i + min_length))
+        for i in range(0, max(1, length_diff + 1))
+    ]
+    ocr_parts = [
+        (j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
+        for j in range(0, max(1, -1 * length_diff + 1))
+    ]

     # add full line and empty line match
     gt_parts = [*gt_parts, (0, gt_line), (0, gt_line)]
-    ocr_parts = [*ocr_parts, (0, ocr_line),
-                 (0, Part(text="", line=gt_line.line, start=gt_line.start))]
+    ocr_parts = [
+        *ocr_parts,
+        (0, ocr_line),
+        (0, Part(text="", line=gt_line.line, start=gt_line.start)),
+    ]

     for i, gt_part in gt_parts:
         for j, ocr_part in ocr_parts:
@@ -211,8 +216,10 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
         part_length = best_match.gt.length
         additional_length = best_match.dist.delete + best_match.dist.replace
         for k in range(part_length + 1, part_length + additional_length + 1):
-            match = distance(gt_line.substring(rel_start=best_i, rel_end=best_i + k),
-                             ocr_line.substring(rel_start=best_j, rel_end=best_j + k))
+            match = distance(
+                gt_line.substring(rel_start=best_i, rel_end=best_i + k),
+                ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
+            )
             edit_dist = score_edit_distance(match)
             if edit_dist < min_edit_dist:
                 min_edit_dist = edit_dist
@@ -247,8 +254,9 @@ def score_edit_distance(match: "Match") -> int:
     return match.dist.delete + match.dist.insert + 2 * match.dist.replace


-def calculate_penalty(gt: "Part", ocr: "Part", match: "Match",
-                      coef: "Coefficients") -> float:
+def calculate_penalty(
+    gt: "Part", ocr: "Part", match: "Match", coef: "Coefficients"
+) -> float:
     """Calculate the penalty for a given match.

     For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003.
@@ -262,10 +270,12 @@ def calculate_penalty(gt: "Part", ocr: "Part", match: "Match",
     if length_diff > 1:
         substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start)
         offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
-    return (min_edit_dist * coef.edit_dist
+    return (
+        min_edit_dist * coef.edit_dist
         + length_diff * coef.length_diff
         + offset * coef.offset
-        - substring_length * coef.length)
+        - substring_length * coef.length
+    )


 def character_accuracy_for_matches(matches: List["Match"]) -> float:
@@ -274,8 +284,9 @@ def character_accuracy_for_matches(matches: List["Match"]) -> float:
     See other `character_accuracy` for details.

     """
-    agg: Counter = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
-                          matches, Counter())
+    agg: Counter = reduce(
+        lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
+    )

     score = character_accuracy(Distance(**agg))
     return score
@@ -299,9 +310,9 @@ def character_accuracy(edits: "Distance") -> float:
     chars = edits.match + edits.replace + edits.delete
     if not chars and not errors:
         # comparison of empty strings is considered a full match
-        score = 1
+        score = 1.0
     else:
-        score = 1 - errors / chars
+        score = 1.0 - errors / chars
     return score


@@ -315,9 +326,11 @@ def initialize_lines(text: str) -> List["Part"]:
     :param text: Text to split into lines.
     :return: List of sorted line objects.
     """
-    lines = [Part(text=line, line=i, start=0)
+    lines = [
+        Part(text=line, line=i, start=0)
         for i, line in enumerate(text.splitlines())
-        if len(line) > 0]
+        if len(line) > 0
+    ]
     lines.sort(key=lambda x: x.length, reverse=True)
     return lines

@@ -348,6 +361,7 @@ class Part(NamedTuple):

     This data object is maintained to be able to reproduce the original text.
     """
+
     text: str = ""
     line: int = 0
     start: int = 0
@@ -392,6 +406,7 @@ class Part(NamedTuple):

 class Distance(NamedTuple):
     """Represent distance between two sequences."""
+
     match: int = 0
     replace: int = 0
     delete: int = 0
@@ -400,6 +415,7 @@ class Distance(NamedTuple):

 class Match(NamedTuple):
     """Represent a calculated match between ground truth and the ocr result."""
+
     gt: "Part"
     ocr: "Part"
     dist: "Distance"
@@ -411,6 +427,7 @@ class Coefficients(NamedTuple):

     See Section 3 in doi:10.1016/j.patrec.2020.02.003
     """
+
     edit_dist: int = 25
     length_diff: int = 20
     offset: int = 1
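The hunks above cover the implementation module; the hunks below come from the accompanying pytest module. One detail worth pinning down before reading the test expectations is the character accuracy formula the reformat touches (`score = 1.0 - errors / chars`, with two empty strings counting as a full match). The sketch below re-derives it standalone so expectations such as `(Distance(match=1, insert=2), 1 - 2 / 1)` can be checked by hand; the local `Distance` namedtuple is a stand-in defined here for illustration, not an import from the project.

from collections import namedtuple

# Stand-in for the Distance NamedTuple from the diff above (illustration only).
Distance = namedtuple(
    "Distance", ["match", "replace", "delete", "insert"], defaults=(0, 0, 0, 0)
)


def character_accuracy(edits: Distance) -> float:
    """Same arithmetic as the reformatted function above."""
    errors = edits.replace + edits.delete + edits.insert
    chars = edits.match + edits.replace + edits.delete
    if not chars and not errors:
        # comparison of empty strings is considered a full match
        return 1.0
    return 1.0 - errors / chars


print(character_accuracy(Distance()))                   # 1.0 (empty vs. empty)
print(character_accuracy(Distance(match=2, insert=1)))  # 0.5
print(character_accuracy(Distance(match=1, insert=2)))  # -1.0, i.e. 1 - 2 / 1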
@@ -35,37 +35,31 @@ COMPLEX_CASES = [
 EXTENDED_CASES = [
     # See figure 4 in 10.1016/j.patrec.2020.02.003
     # A: No errors
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     1, 1),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), 1, 1),
     # B: Different ordering of text blocks
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4),
-     1, 1),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), 1, 1),
     # C: Merge across columns
-    ((0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9),
+    (
+        (0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9),
         (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9),
-     1, 0.964),
+        1,
+        0.964,
+    ),
     # D: Over-segmentation
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
+    (
+        (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
         (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9),
-     1, 0.966),
+        1,
+        0.966,
+    ),
     # E: Part missing
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (0, 1, 2, 3, 4),
-     1, 0.50),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4), 1, 0.50),
     # E.2: Part missing
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (5, 6, 7, 8, 9),
-     1, 0.50),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9), 1, 0.50),
     # F: All missing
-    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
-     (),
-     1, 0),
+    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (), 1, 0),
     # G: Added parts
-    ((0, 1, 2, 3, 4),
-     (0, 1, 2, 3, 4, 11, 5, 6),
-     1, 0.621),
+    ((0, 1, 2, 3, 4), (0, 1, 2, 3, 4, 11, 5, 6), 1, 0.621),
 ]

 EDIT_ARGS = "gt,ocr,expected_dist"
@@ -73,8 +67,11 @@ EDIT_ARGS = "gt,ocr,expected_dist"
 SIMPLE_EDITS = [
     (Part(text="a"), Part(text="a"), Distance(match=1)),
     (Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
-    (Part(text="abcd"), Part(text="beed"),
-     Distance(match=2, replace=1, insert=1, delete=1)),
+    (
+        Part(text="abcd"),
+        Part(text="beed"),
+        Distance(match=2, replace=1, insert=1, delete=1),
+    ),
 ]


@@ -83,9 +80,20 @@ def extended_case_to_text(gt, ocr):

     See figure 4 in 10.1016/j.patrec.2020.02.003
     """
-    sentence = ("Eight", "happy", "frogs", "scuba", "dived",
-                "Jenny", "chick", "flaps", "white", "wings",
-                "", "\n")
+    sentence = (
+        "Eight",
+        "happy",
+        "frogs",
+        "scuba",
+        "dived",
+        "Jenny",
+        "chick",
+        "flaps",
+        "white",
+        "wings",
+        "",
+        "\n",
+    )

     gt_sentence = " ".join(sentence[i] for i in gt).replace(" \n ", "\n")
     ocr_sentence = " ".join(sentence[i] for i in ocr).replace(" \n ", "\n")
@@ -98,22 +106,35 @@ def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_
     assert score == pytest.approx(all_line_score)


-@pytest.mark.parametrize("config,ocr", [
-    ("Config I",
-     "1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein"
+@pytest.mark.parametrize(
+    "config,ocr",
+    [
+        (
+            "Config I",
+            "1 hav\nnospecial\ntalents.\n"
+            'I am one\npassionate\ncuriousity."\n'
+            "Alberto\nEmstein",
         ),
-    ("Config II",
-     "1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\""
+        (
+            "Config II",
+            '1 hav\nnospecial\ntalents. Alberto\n'
+            'I am one Emstein\npassionate\ncuriousity."',
        ),
-    ("Config III",
-     "Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\""
+        (
+            "Config III",
+            'Alberto\nEmstein\n'
+            '1 hav\nnospecial\ntalents.\n'
+            'I am one\npassionate\ncuriousity."',
         ),
-])
+    ],
+)
 def test_flexible_character_accuracy(config, ocr):
     """Tests from figure 3 in the paper."""
-    gt = "\"I have\nno special\ntalent." \
-         "\nI am only\npassionately\ncurious.\"" \
-         "\nAlbert\nEinstein"
+    gt = (
+        '"I have\nno special\ntalent.\n'
+        'I am only\npassionately\ncurious."\n'
+        "Albert\nEinstein"
+    )
     replacements, inserts, deletes = 3, 5, 7
     chars = len(gt) - gt.count("\n")
     assert chars == 68
@@ -127,21 +148,27 @@ def test_flexible_character_accuracy(config, ocr):
         replacements += 1
         deletes -= 1

-    expected_dist = Distance(match=chars - deletes - replacements, replace=replacements,
-                             insert=inserts, delete=deletes)
+    expected_dist = Distance(
+        match=chars - deletes - replacements,
+        replace=replacements,
+        insert=inserts,
+        delete=deletes,
+    )
     expected_score = character_accuracy(expected_dist)

     result, matches = flexible_character_accuracy(gt, ocr)
-    agg = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
-                 matches, Counter())
+    agg = reduce(
+        lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
+    )
     dist = Distance(**agg)
     assert dist == expected_dist
     assert result == pytest.approx(expected_score, abs=0.0005)


 @pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES)
-def test_flexible_character_accuracy_extended(gt, ocr, first_line_score,
-                                              all_line_score):
+def test_flexible_character_accuracy_extended(
+    gt, ocr, first_line_score, all_line_score
+):
     """Tests from figure 4 in the paper."""
     gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
     result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence)
@@ -170,10 +197,13 @@ def test_match_longest_gt_lines(gt, ocr, first_line_score, all_line_score):
     assert score == pytest.approx(first_line_score)


-@pytest.mark.parametrize(CASE_ARGS, [
+@pytest.mark.parametrize(
+    CASE_ARGS,
+    [
         *SIMPLE_CASES,
         ("accc", "a\nbb\nccc", 1.0, 1.0),
-])
+    ],
+)
 def test_match_gt_line(gt, ocr, first_line_score, all_line_score):
     coef = Coefficients()
     gt_lines = initialize_lines(gt)
@@ -185,15 +215,22 @@ def test_match_gt_line(gt, ocr, first_line_score, all_line_score):
     assert score == pytest.approx(first_line_score)


-@pytest.mark.parametrize("original,match,expected_lines", [
+@pytest.mark.parametrize(
+    "original,match,expected_lines",
+    [
         (Part(), Part(), []),
         (Part(text="abc"), Part(), [Part(text="abc")]),
         (Part(text="abc"), Part("d"), [Part(text="bc", start=1)]),
         (Part(text="abc"), Part("a", start=100), [Part(text="abc")]),
         (Part(text="abc"), Part("a"), [Part(text="bc", start=1)]),
-    (Part(text="abc"), Part("b", start=1), [Part(text="a"), Part(text="c", start=2)]),
+        (
+            Part(text="abc"),
+            Part("b", start=1),
+            [Part(text="a"), Part(text="c", start=2)],
+        ),
         (Part(text="abc"), Part("c", start=2), [Part(text="ab")]),
-])
+    ],
+)
 def test_remove_or_split(original, match, expected_lines):
     lines = [original]
     splitted = remove_or_split(original, match, lines)
@@ -201,7 +238,9 @@ def test_remove_or_split(original, match, expected_lines):
     assert lines == expected_lines


-@pytest.mark.parametrize(EDIT_ARGS, [
+@pytest.mark.parametrize(
+    EDIT_ARGS,
+    [
         *SIMPLE_EDITS,
         (Part(text="a"), Part(text="b"), Distance(delete=1)),
         (Part(text="aaa"), Part(text="bbb"), Distance(delete=3)),
@@ -210,9 +249,18 @@ def test_remove_or_split(original, match, expected_lines):
         (Part(text=""), Part(text=""), None),
         (Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)),
         (Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)),
-    (Part(text="aaabbbaaaddd"), Part(text="aaabcbaaa"), Distance(match=8, replace=1)),
-    (Part(text="aaabbbccc"), Part(text="aaabbbdddccc"), Distance(match=9, insert=3)),
-])
+        (
+            Part(text="aaabbbaaaddd"),
+            Part(text="aaabcbaaa"),
+            Distance(match=8, replace=1),
+        ),
+        (
+            Part(text="aaabbbccc"),
+            Part(text="aaabbbdddccc"),
+            Distance(match=9, insert=3),
+        ),
+    ],
+)
 def test_match_lines(gt, ocr, expected_dist):
     match = match_lines(gt, ocr)
     if not expected_dist:
@@ -223,14 +271,17 @@ def test_match_lines(gt, ocr, expected_dist):
     assert match.dist == expected_dist


-@pytest.mark.parametrize(EDIT_ARGS, [
+@pytest.mark.parametrize(
+    EDIT_ARGS,
+    [
         *SIMPLE_EDITS,
         (Part(text="").substring(), Part(text=""), Distance()),
         (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)),
         (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)),
         (Part(text="a"), Part(text="b"), Distance(replace=1)),
         (Part(text="aaa"), Part(text="bbb"), Distance(replace=3)),
-])
+    ],
+)
 def test_distance(gt, ocr, expected_dist):
     match = distance(gt, ocr)
     assert match.gt == gt
@@ -238,20 +289,37 @@ def test_distance(gt, ocr, expected_dist):
     assert match.dist == expected_dist


-@pytest.mark.parametrize("matches,expected_dist", [
+@pytest.mark.parametrize(
+    "matches,expected_dist",
+    [
         ([], 1),
         ([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1),
-    ([Match(gt=Part(text="abee"), ocr=Part("ac"),
-            dist=Distance(match=1, replace=1, delete=2), ops=[]),
-      Match(gt=Part(text="cd"), ocr=Part("ceff"),
-            dist=Distance(match=1, replace=1, insert=2), ops=[])],
-     1 - 6 / 6),
-])
+        (
+            [
+                Match(
+                    gt=Part(text="abee"),
+                    ocr=Part("ac"),
+                    dist=Distance(match=1, replace=1, delete=2),
+                    ops=[],
+                ),
+                Match(
+                    gt=Part(text="cd"),
+                    ocr=Part("ceff"),
+                    dist=Distance(match=1, replace=1, insert=2),
+                    ops=[],
+                ),
+            ],
+            1 - 6 / 6,
+        ),
+    ],
+)
 def test_character_accuracy_matches(matches, expected_dist):
     assert character_accuracy_for_matches(matches) == pytest.approx(expected_dist)


-@pytest.mark.parametrize("dist,expected_dist", [
+@pytest.mark.parametrize(
+    "dist,expected_dist",
+    [
         (Distance(), 1),
         (Distance(match=1), 1),
         (Distance(replace=1), 0),
@@ -259,25 +327,39 @@ def test_character_accuracy_matches(matches, expected_dist):
         (Distance(match=1, insert=2), 1 - 2 / 1),
         (Distance(match=2, insert=1), 0.5),
         (Distance(match=1, delete=1), 0.5),
-])
+    ],
+)
 def test_character_accuracy_dist(dist, expected_dist):
     assert character_accuracy(dist) == pytest.approx(expected_dist)


-@pytest.mark.parametrize("line,subline,expected_rest", [
+@pytest.mark.parametrize(
+    "line,subline,expected_rest",
+    [
         (Part(), Part(), []),
         (Part("aaa bbb"), Part("aaa bbb"), []),
         (Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]),
         (Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]),
         (Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]),
         (Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]),
-    (Part("aaa bbb ccc"), Part("bbb", start=4), [Part("aaa "), Part(" ccc", start=7)]),
-    (Part("aaa bbb ccc", start=3), Part("bbb", start=7),
-     [Part("aaa ", start=3), Part(" ccc", start=10)]),
+        (
+            Part("aaa bbb ccc"),
+            Part("bbb", start=4),
+            [Part("aaa "), Part(" ccc", start=7)],
+        ),
+        (
+            Part("aaa bbb ccc", start=3),
+            Part("bbb", start=7),
+            [Part("aaa ", start=3), Part(" ccc", start=10)],
+        ),
         (Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]),
-    (Part("aaa bbb", start=3), Part(" ", start=6),
-     [Part("aaa", start=3), Part("bbb", start=7)]),
-])
+        (
+            Part("aaa bbb", start=3),
+            Part(" ", start=6),
+            [Part("aaa", start=3), Part("bbb", start=7)],
+        ),
+    ],
+)
 def test_split_line(line, subline, expected_rest):
     rest = line.split(subline)
     assert len(rest) == len(expected_rest)
@@ -300,12 +382,15 @@ def test_combine_lines():
     assert False


-@pytest.mark.parametrize("line,start,end,expected", [
+@pytest.mark.parametrize(
+    "line,start,end,expected",
+    [
         (Part(text=""), 0, None, Part(text="")),
         (Part(text="a"), 0, None, Part(text="a")),
         (Part(text="ab"), 0, 1, Part(text="a")),
         (Part(text="abc"), 0, -1, Part(text="ab")),
         (Part(text="ab"), 1, None, Part(text="b", start=1)),
-])
+    ],
+)
 def test_line_substring(line, start, end, expected):
     assert line.substring(rel_start=start, rel_end=end) == expected
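Finally, for readers who want to see the reformatted code in action: a minimal usage sketch of flexible_character_accuracy, using the figure 3 ground truth / OCR pair from the tests above. It assumes dinglehopper is installed; the import path below reflects the package layout assumed here and may need adjusting.

# Usage sketch; the import path is an assumption, adjust to your installed layout.
from qurator.dinglehopper.flexible_character_accuracy import flexible_character_accuracy

gt = '"I have\nno special\ntalent.\nI am only\npassionately\ncurious."\nAlbert\nEinstein'
ocr = '1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity."\nAlberto\nEmstein'

score, matches = flexible_character_accuracy(gt, ocr)
print(f"flexible character accuracy: {score:.3f}")  # a score between 0 and 1
for match in matches:
    print(match.gt.text, "->", match.ocr.text, match.dist)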