Reformat using black

pull/47/head
Benjamin Rosemann 5 years ago
parent 5277593bdb
commit 2a215a1062

@ -29,20 +29,16 @@ def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"]
:return: Score between 0 and 1 and match objects.
"""
best_score = -float('inf')
best_score = -float("inf")
best_matches = []
# TODO: this should be configurable
combinations = product(range(15, 31, 5),
range(0, 24, 3),
range(0, 4, 1),
range(0, 6, 1))
combinations = product(
range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1)
)
# TODO: place to parallelize the algorithm
for (edit_dist, length_diff, offset, length) in combinations:
coef = Coefficients(
edit_dist=edit_dist,
length_diff=length_diff,
offset=offset,
length=length
edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length
)
# Steps 1 - 6 of the flexible character accuracy algorithm.
matches = match_with_coefficients(gt, ocr, coef)
@ -79,17 +75,21 @@ def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Ma
# Step 6 of the flexible character accuracy algorithm.
# remaining lines are considered as deletes and inserts
deletes = [distance(line, Part(text="", line=line.line, start=line.start))
for line in gt_lines]
inserts = [distance(Part(text="", line=line.line, start=line.start), line)
for line in ocr_lines]
deletes = [
distance(line, Part(text="", line=line.line, start=line.start))
for line in gt_lines
]
inserts = [
distance(Part(text="", line=line.line, start=line.start), line)
for line in ocr_lines
]
return [*matches, *deletes, *inserts]
def match_longest_gt_lines(gt_lines: List["Part"],
ocr_lines: List["Part"],
coef: "Coefficients") -> Optional["Match"]:
def match_longest_gt_lines(
gt_lines: List["Part"], ocr_lines: List["Part"], coef: "Coefficients"
) -> Optional["Match"]:
"""Find the best match for the longest line(s) in ground truth.
The longest lines in ground truth are matched against lines in ocr to find the
@ -99,7 +99,7 @@ def match_longest_gt_lines(gt_lines: List["Part"],
:return: Possible match object.
"""
best_score, best_match, best_gt, best_ocr = -float('inf'), None, None, None
best_score, best_match, best_gt, best_ocr = -float("inf"), None, None, None
if not ocr_lines:
return best_match
@ -126,10 +126,9 @@ def match_longest_gt_lines(gt_lines: List["Part"],
return best_match
def match_gt_line(gt_line: "Part",
ocr_lines: List["Part"],
coef: "Coefficients") -> Tuple[Optional["Match"],
Optional["Part"]]:
def match_gt_line(
gt_line: "Part", ocr_lines: List["Part"], coef: "Coefficients"
) -> Tuple[Optional["Match"], Optional["Part"]]:
"""Match the given ground truth line against all the lines in ocr.
Reference: contains step 3 of the flexible character accuracy algorithm.
@ -138,19 +137,18 @@ def match_gt_line(gt_line: "Part",
:return: Match object and the matched ocr line.
"""
min_penalty = float('inf')
min_penalty = float("inf")
best_match, best_ocr = None, None
for ocr_line in [*ocr_lines]:
match = match_lines(gt_line, ocr_line)
if match:
penalty = calculate_penalty(gt_line, ocr_line, match, coef)
if penalty < min_penalty:
min_penalty, best_match, best_ocr = penalty, match, ocr_line
return best_match, best_ocr
def remove_or_split(original: "Part",
match: "Part",
lines: List["Part"]) -> bool:
def remove_or_split(original: "Part", match: "Part", lines: List["Part"]) -> bool:
"""Removes the matched line or splits it into parts.
Reference: contains step 4 of the flexible character accuracy algorithm.
@ -187,17 +185,24 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
if min_length == 0:
return best_match
length_diff = gt_line.length - ocr_line.length
min_edit_dist = float('inf')
min_edit_dist = float("inf")
gt_parts = [(i, gt_line.substring(rel_start=i, rel_end=i + min_length))
for i in range(0, max(1, length_diff + 1))]
ocr_parts = [(j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
for j in range(0, max(1, -1 * length_diff + 1))]
gt_parts = [
(i, gt_line.substring(rel_start=i, rel_end=i + min_length))
for i in range(0, max(1, length_diff + 1))
]
ocr_parts = [
(j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
for j in range(0, max(1, -1 * length_diff + 1))
]
# add full line and empty line match
gt_parts = [*gt_parts, (0, gt_line), (0, gt_line)]
ocr_parts = [*ocr_parts, (0, ocr_line),
(0, Part(text="", line=gt_line.line, start=gt_line.start))]
ocr_parts = [
*ocr_parts,
(0, ocr_line),
(0, Part(text="", line=gt_line.line, start=gt_line.start)),
]
for i, gt_part in gt_parts:
for j, ocr_part in ocr_parts:
@ -211,8 +216,10 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
part_length = best_match.gt.length
additional_length = best_match.dist.delete + best_match.dist.replace
for k in range(part_length + 1, part_length + additional_length + 1):
match = distance(gt_line.substring(rel_start=best_i, rel_end=best_i + k),
ocr_line.substring(rel_start=best_j, rel_end=best_j + k))
match = distance(
gt_line.substring(rel_start=best_i, rel_end=best_i + k),
ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
)
edit_dist = score_edit_distance(match)
if edit_dist < min_edit_dist:
min_edit_dist = edit_dist
@ -247,8 +254,9 @@ def score_edit_distance(match: "Match") -> int:
return match.dist.delete + match.dist.insert + 2 * match.dist.replace
def calculate_penalty(gt: "Part", ocr: "Part", match: "Match",
coef: "Coefficients") -> float:
def calculate_penalty(
gt: "Part", ocr: "Part", match: "Match", coef: "Coefficients"
) -> float:
"""Calculate the penalty for a given match.
For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003.
@ -262,10 +270,12 @@ def calculate_penalty(gt: "Part", ocr: "Part", match: "Match",
if length_diff > 1:
substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start)
offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
return (min_edit_dist * coef.edit_dist
return (
min_edit_dist * coef.edit_dist
+ length_diff * coef.length_diff
+ offset * coef.offset
- substring_length * coef.length)
- substring_length * coef.length
)
def character_accuracy_for_matches(matches: List["Match"]) -> float:
@ -274,8 +284,9 @@ def character_accuracy_for_matches(matches: List["Match"]) -> float:
See other `character_accuracy` for details.
"""
agg: Counter = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
matches, Counter())
agg: Counter = reduce(
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
)
score = character_accuracy(Distance(**agg))
return score
@ -299,9 +310,9 @@ def character_accuracy(edits: "Distance") -> float:
chars = edits.match + edits.replace + edits.delete
if not chars and not errors:
# comparison of empty strings is considered a full match
score = 1
score = 1.0
else:
score = 1 - errors / chars
score = 1.0 - errors / chars
return score
@ -315,9 +326,11 @@ def initialize_lines(text: str) -> List["Part"]:
:param text: Text to split into lines.
:return: List of sorted line objects.
"""
lines = [Part(text=line, line=i, start=0)
lines = [
Part(text=line, line=i, start=0)
for i, line in enumerate(text.splitlines())
if len(line) > 0]
if len(line) > 0
]
lines.sort(key=lambda x: x.length, reverse=True)
return lines
@ -348,6 +361,7 @@ class Part(NamedTuple):
This data object is maintained to be able to reproduce the original text.
"""
text: str = ""
line: int = 0
start: int = 0
@ -392,6 +406,7 @@ class Part(NamedTuple):
class Distance(NamedTuple):
"""Represent distance between two sequences."""
match: int = 0
replace: int = 0
delete: int = 0
@ -400,6 +415,7 @@ class Distance(NamedTuple):
class Match(NamedTuple):
"""Represent a calculated match between ground truth and the ocr result."""
gt: "Part"
ocr: "Part"
dist: "Distance"
@ -411,6 +427,7 @@ class Coefficients(NamedTuple):
See Section 3 in doi:10.1016/j.patrec.2020.02.003
"""
edit_dist: int = 25
length_diff: int = 20
offset: int = 1

@ -35,37 +35,31 @@ COMPLEX_CASES = [
EXTENDED_CASES = [
# See figure 4 in 10.1016/j.patrec.2020.02.003
# A: No errors
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
(0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
1, 1),
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), 1, 1),
# B: Different ordering of text blocks
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
(5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4),
1, 1),
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), 1, 1),
# C: Merge across columns
((0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9),
(
(0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9),
(0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9),
1, 0.964),
1,
0.964,
),
# D: Over-segmentation
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
(
(0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
(0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9),
1, 0.966),
1,
0.966,
),
# E: Part missing
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
(0, 1, 2, 3, 4),
1, 0.50),
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4), 1, 0.50),
# E.2: Part missing
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
(5, 6, 7, 8, 9),
1, 0.50),
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9), 1, 0.50),
# F: All missing
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
(),
1, 0),
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (), 1, 0),
# G: Added parts
((0, 1, 2, 3, 4),
(0, 1, 2, 3, 4, 11, 5, 6),
1, 0.621),
((0, 1, 2, 3, 4), (0, 1, 2, 3, 4, 11, 5, 6), 1, 0.621),
]
EDIT_ARGS = "gt,ocr,expected_dist"
@ -73,8 +67,11 @@ EDIT_ARGS = "gt,ocr,expected_dist"
SIMPLE_EDITS = [
(Part(text="a"), Part(text="a"), Distance(match=1)),
(Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
(Part(text="abcd"), Part(text="beed"),
Distance(match=2, replace=1, insert=1, delete=1)),
(
Part(text="abcd"),
Part(text="beed"),
Distance(match=2, replace=1, insert=1, delete=1),
),
]
@ -83,9 +80,20 @@ def extended_case_to_text(gt, ocr):
See figure 4 in 10.1016/j.patrec.2020.02.003
"""
sentence = ("Eight", "happy", "frogs", "scuba", "dived",
"Jenny", "chick", "flaps", "white", "wings",
"", "\n")
sentence = (
"Eight",
"happy",
"frogs",
"scuba",
"dived",
"Jenny",
"chick",
"flaps",
"white",
"wings",
"",
"\n",
)
gt_sentence = " ".join(sentence[i] for i in gt).replace(" \n ", "\n")
ocr_sentence = " ".join(sentence[i] for i in ocr).replace(" \n ", "\n")
@ -98,22 +106,35 @@ def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_
assert score == pytest.approx(all_line_score)
@pytest.mark.parametrize("config,ocr", [
("Config I",
"1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein"
@pytest.mark.parametrize(
"config,ocr",
[
(
"Config I",
"1 hav\nnospecial\ntalents.\n"
'I am one\npassionate\ncuriousity."\n'
"Alberto\nEmstein",
),
("Config II",
"1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\""
(
"Config II",
'1 hav\nnospecial\ntalents. Alberto\n'
'I am one Emstein\npassionate\ncuriousity."',
),
("Config III",
"Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\""
(
"Config III",
'Alberto\nEmstein\n'
'1 hav\nnospecial\ntalents.\n'
'I am one\npassionate\ncuriousity."',
),
])
],
)
def test_flexible_character_accuracy(config, ocr):
"""Tests from figure 3 in the paper."""
gt = "\"I have\nno special\ntalent." \
"\nI am only\npassionately\ncurious.\"" \
"\nAlbert\nEinstein"
gt = (
'"I have\nno special\ntalent.\n'
'I am only\npassionately\ncurious."\n'
"Albert\nEinstein"
)
replacements, inserts, deletes = 3, 5, 7
chars = len(gt) - gt.count("\n")
assert chars == 68
@ -127,21 +148,27 @@ def test_flexible_character_accuracy(config, ocr):
replacements += 1
deletes -= 1
expected_dist = Distance(match=chars - deletes - replacements, replace=replacements,
insert=inserts, delete=deletes)
expected_dist = Distance(
match=chars - deletes - replacements,
replace=replacements,
insert=inserts,
delete=deletes,
)
expected_score = character_accuracy(expected_dist)
result, matches = flexible_character_accuracy(gt, ocr)
agg = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
matches, Counter())
agg = reduce(
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
)
dist = Distance(**agg)
assert dist == expected_dist
assert result == pytest.approx(expected_score, abs=0.0005)
@pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES)
def test_flexible_character_accuracy_extended(gt, ocr, first_line_score,
all_line_score):
def test_flexible_character_accuracy_extended(
gt, ocr, first_line_score, all_line_score
):
"""Tests from figure 4 in the paper."""
gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence)
@ -170,10 +197,13 @@ def test_match_longest_gt_lines(gt, ocr, first_line_score, all_line_score):
assert score == pytest.approx(first_line_score)
@pytest.mark.parametrize(CASE_ARGS, [
@pytest.mark.parametrize(
CASE_ARGS,
[
*SIMPLE_CASES,
("accc", "a\nbb\nccc", 1.0, 1.0),
])
],
)
def test_match_gt_line(gt, ocr, first_line_score, all_line_score):
coef = Coefficients()
gt_lines = initialize_lines(gt)
@ -185,15 +215,22 @@ def test_match_gt_line(gt, ocr, first_line_score, all_line_score):
assert score == pytest.approx(first_line_score)
@pytest.mark.parametrize("original,match,expected_lines", [
@pytest.mark.parametrize(
"original,match,expected_lines",
[
(Part(), Part(), []),
(Part(text="abc"), Part(), [Part(text="abc")]),
(Part(text="abc"), Part("d"), [Part(text="bc", start=1)]),
(Part(text="abc"), Part("a", start=100), [Part(text="abc")]),
(Part(text="abc"), Part("a"), [Part(text="bc", start=1)]),
(Part(text="abc"), Part("b", start=1), [Part(text="a"), Part(text="c", start=2)]),
(
Part(text="abc"),
Part("b", start=1),
[Part(text="a"), Part(text="c", start=2)],
),
(Part(text="abc"), Part("c", start=2), [Part(text="ab")]),
])
],
)
def test_remove_or_split(original, match, expected_lines):
lines = [original]
splitted = remove_or_split(original, match, lines)
@ -201,7 +238,9 @@ def test_remove_or_split(original, match, expected_lines):
assert lines == expected_lines
@pytest.mark.parametrize(EDIT_ARGS, [
@pytest.mark.parametrize(
EDIT_ARGS,
[
*SIMPLE_EDITS,
(Part(text="a"), Part(text="b"), Distance(delete=1)),
(Part(text="aaa"), Part(text="bbb"), Distance(delete=3)),
@ -210,9 +249,18 @@ def test_remove_or_split(original, match, expected_lines):
(Part(text=""), Part(text=""), None),
(Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)),
(Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)),
(Part(text="aaabbbaaaddd"), Part(text="aaabcbaaa"), Distance(match=8, replace=1)),
(Part(text="aaabbbccc"), Part(text="aaabbbdddccc"), Distance(match=9, insert=3)),
])
(
Part(text="aaabbbaaaddd"),
Part(text="aaabcbaaa"),
Distance(match=8, replace=1),
),
(
Part(text="aaabbbccc"),
Part(text="aaabbbdddccc"),
Distance(match=9, insert=3),
),
],
)
def test_match_lines(gt, ocr, expected_dist):
match = match_lines(gt, ocr)
if not expected_dist:
@ -223,14 +271,17 @@ def test_match_lines(gt, ocr, expected_dist):
assert match.dist == expected_dist
@pytest.mark.parametrize(EDIT_ARGS, [
@pytest.mark.parametrize(
EDIT_ARGS,
[
*SIMPLE_EDITS,
(Part(text="").substring(), Part(text=""), Distance()),
(Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)),
(Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)),
(Part(text="a"), Part(text="b"), Distance(replace=1)),
(Part(text="aaa"), Part(text="bbb"), Distance(replace=3)),
])
],
)
def test_distance(gt, ocr, expected_dist):
match = distance(gt, ocr)
assert match.gt == gt
@ -238,20 +289,37 @@ def test_distance(gt, ocr, expected_dist):
assert match.dist == expected_dist
@pytest.mark.parametrize("matches,expected_dist", [
@pytest.mark.parametrize(
"matches,expected_dist",
[
([], 1),
([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1),
([Match(gt=Part(text="abee"), ocr=Part("ac"),
dist=Distance(match=1, replace=1, delete=2), ops=[]),
Match(gt=Part(text="cd"), ocr=Part("ceff"),
dist=Distance(match=1, replace=1, insert=2), ops=[])],
1 - 6 / 6),
])
(
[
Match(
gt=Part(text="abee"),
ocr=Part("ac"),
dist=Distance(match=1, replace=1, delete=2),
ops=[],
),
Match(
gt=Part(text="cd"),
ocr=Part("ceff"),
dist=Distance(match=1, replace=1, insert=2),
ops=[],
),
],
1 - 6 / 6,
),
],
)
def test_character_accuracy_matches(matches, expected_dist):
assert character_accuracy_for_matches(matches) == pytest.approx(expected_dist)
@pytest.mark.parametrize("dist,expected_dist", [
@pytest.mark.parametrize(
"dist,expected_dist",
[
(Distance(), 1),
(Distance(match=1), 1),
(Distance(replace=1), 0),
@ -259,25 +327,39 @@ def test_character_accuracy_matches(matches, expected_dist):
(Distance(match=1, insert=2), 1 - 2 / 1),
(Distance(match=2, insert=1), 0.5),
(Distance(match=1, delete=1), 0.5),
])
],
)
def test_character_accuracy_dist(dist, expected_dist):
assert character_accuracy(dist) == pytest.approx(expected_dist)
@pytest.mark.parametrize("line,subline,expected_rest", [
@pytest.mark.parametrize(
"line,subline,expected_rest",
[
(Part(), Part(), []),
(Part("aaa bbb"), Part("aaa bbb"), []),
(Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]),
(Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]),
(Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]),
(Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]),
(Part("aaa bbb ccc"), Part("bbb", start=4), [Part("aaa "), Part(" ccc", start=7)]),
(Part("aaa bbb ccc", start=3), Part("bbb", start=7),
[Part("aaa ", start=3), Part(" ccc", start=10)]),
(
Part("aaa bbb ccc"),
Part("bbb", start=4),
[Part("aaa "), Part(" ccc", start=7)],
),
(
Part("aaa bbb ccc", start=3),
Part("bbb", start=7),
[Part("aaa ", start=3), Part(" ccc", start=10)],
),
(Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]),
(Part("aaa bbb", start=3), Part(" ", start=6),
[Part("aaa", start=3), Part("bbb", start=7)]),
])
(
Part("aaa bbb", start=3),
Part(" ", start=6),
[Part("aaa", start=3), Part("bbb", start=7)],
),
],
)
def test_split_line(line, subline, expected_rest):
rest = line.split(subline)
assert len(rest) == len(expected_rest)
@ -300,12 +382,15 @@ def test_combine_lines():
assert False
@pytest.mark.parametrize("line,start,end,expected", [
@pytest.mark.parametrize(
"line,start,end,expected",
[
(Part(text=""), 0, None, Part(text="")),
(Part(text="a"), 0, None, Part(text="a")),
(Part(text="ab"), 0, 1, Part(text="a")),
(Part(text="abc"), 0, -1, Part(text="ab")),
(Part(text="ab"), 1, None, Part(text="b", start=1)),
])
],
)
def test_line_substring(line, start, end, expected):
assert line.substring(rel_start=start, rel_end=end) == expected

Loading…
Cancel
Save