Fix some special cases

pull/47/head
Benjamin Rosemann 5 years ago
parent d7a74fa58b
commit 5277593bdb

@@ -117,6 +117,7 @@ def match_longest_gt_lines(gt_lines: List["Part"],
    if best_match and best_gt:
        splitted = remove_or_split(best_gt, best_match.gt, gt_lines)
        if splitted:
            # according to the paper the match is not put back, we deviate...
            gt_lines.append(best_match.gt)
            best_match = None
    if best_match and best_ocr:
@@ -134,13 +135,12 @@ def match_gt_line(gt_line: "Part",
    Reference: contains step 3 of the flexible character accuracy algorithm.
    TODO: Make penalty function configurable?
    TODO: Add empty ocr line to avoid having nonsense one character alignments?
    :return: Match object and the matched ocr line.
    """
    min_penalty = float('inf')
    best_match, best_ocr = None, None
    for ocr_line in ocr_lines:
    for ocr_line in [*ocr_lines]:
        match = match_lines(gt_line, ocr_line)
        penalty = calculate_penalty(gt_line, ocr_line, match, coef)
        if penalty < min_penalty:
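Note on the `[*ocr_lines]` change: it iterates over a shallow copy of the list, presumably so the loop still visits every original line even if `ocr_lines` is shrunk while matches are processed. A minimal standalone sketch (plain strings and a toy removal, both illustrative assumptions rather than the project's actual mutation pattern):

lines = ["line 1", "line 2", "line 3"]
visited = []
for line in [*lines]:            # iterate a shallow copy, as in `[*ocr_lines]`
    visited.append(line)
    if line == "line 1":
        lines.remove("line 1")   # mutate the underlying list mid-iteration
# the copy keeps the iteration stable ...
assert visited == ["line 1", "line 2", "line 3"]
# ... while iterating `lines` directly would have skipped "line 2"
assert lines == ["line 2", "line 3"]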
@@ -177,20 +177,42 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
    Reference: see figure 2 in the paper.
    TODO: make distance function configurable?
    TODO: rethink @lru_cache
    :return: Match object if one is found.
    """
    min_length = min(gt_line.length, ocr_line.length)
    best_match = None
    best_i, best_j = 0, 0
    if min_length == 0:
        return best_match
    length_diff = gt_line.length - ocr_line.length
    min_edit_dist = float('inf')
    # TODO: handle deletes and replacements by extending the length.
    for i in range(0, max(1, length_diff + 1)):
        for j in range(0, max(1, -1 * length_diff + 1)):
            match = distance(gt_line.substring(rel_start=i, rel_end=i + min_length),
                             ocr_line.substring(rel_start=j, rel_end=j + min_length))
    gt_parts = [(i, gt_line.substring(rel_start=i, rel_end=i + min_length))
                for i in range(0, max(1, length_diff + 1))]
    ocr_parts = [(j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
                 for j in range(0, max(1, -1 * length_diff + 1))]
    # add full line and empty line match
    gt_parts = [*gt_parts, (0, gt_line), (0, gt_line)]
    ocr_parts = [*ocr_parts, (0, ocr_line),
                 (0, Part(text="", line=gt_line.line, start=gt_line.start))]
    for i, gt_part in gt_parts:
        for j, ocr_part in ocr_parts:
            match = distance(gt_part, ocr_part)
            edit_dist = score_edit_distance(match)
            if edit_dist < min_edit_dist:
                min_edit_dist = edit_dist
                best_match = match
                best_i, best_j = i, j
    if best_match and (best_match.dist.delete or best_match.dist.replace):
        part_length = best_match.gt.length
        additional_length = best_match.dist.delete + best_match.dist.replace
        for k in range(part_length + 1, part_length + additional_length + 1):
            match = distance(gt_line.substring(rel_start=best_i, rel_end=best_i + k),
                             ocr_line.substring(rel_start=best_j, rel_end=best_j + k))
            edit_dist = score_edit_distance(match)
            if edit_dist < min_edit_dist:
                min_edit_dist = edit_dist
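As a worked illustration of the candidate windows compared above, here is a minimal sketch with plain strings and slicing standing in for Part.substring() (that equivalence is an assumption made for illustration only):

# Sliding-window candidates in match_lines: gt is two characters longer than
# ocr here, so only the gt window can shift.
gt, ocr = "abcdef", "bcde"
min_length = min(len(gt), len(ocr))      # 4
length_diff = len(gt) - len(ocr)         # 2

gt_parts = [(i, gt[i:i + min_length]) for i in range(0, max(1, length_diff + 1))]
ocr_parts = [(j, ocr[j:j + min_length]) for j in range(0, max(1, -1 * length_diff + 1))]

assert gt_parts == [(0, "abcd"), (1, "bcde"), (2, "cdef")]
assert ocr_parts == [(0, "bcde")]
# The commit additionally appends the full lines and an empty ocr part, so a
# "match nothing" alignment can win over a forced one-character alignment.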
@@ -205,6 +227,7 @@ def distance(gt: "Part", ocr: "Part") -> "Match":
    Using the already available `editops()` function with the Levenshtein distance.
    TODO: replace with @cache annotation in Python 3.9
    TODO: rethink @lru_cache
    :return: Match object containing the lines and the editing operations.
    """

@@ -33,6 +33,7 @@ COMPLEX_CASES = [
]
EXTENDED_CASES = [
    # See figure 4 in 10.1016/j.patrec.2020.02.003
    # A: No errors
    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
     (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
@@ -70,14 +71,18 @@ EXTENDED_CASES = [
EDIT_ARGS = "gt,ocr,expected_dist"
SIMPLE_EDITS = [
    (Part(text="a").substring(), Part(text="a"), Distance(match=1)),
    (Part(text="a").substring(), Part(text="b"), Distance(replace=1)),
    (Part(text="abcd").substring(), Part(text="beed"),
    (Part(text="a"), Part(text="a"), Distance(match=1)),
    (Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
    (Part(text="abcd"), Part(text="beed"),
     Distance(match=2, replace=1, insert=1, delete=1)),
]
def extended_case_to_text(gt, ocr):
    """Generate sentence from reading order encoding.
    See figure 4 in 10.1016/j.patrec.2020.02.003
    """
    sentence = ("Eight", "happy", "frogs", "scuba", "dived",
                "Jenny", "chick", "flaps", "white", "wings",
                "", "\n")
@@ -93,38 +98,45 @@ def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_
    assert score == pytest.approx(all_line_score)
@pytest.mark.xfail
@pytest.mark.parametrize("ocr", [
    "1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein",
    "1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\"",
    "Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\""
@pytest.mark.parametrize("config,ocr", [
    ("Config I",
     "1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein"
     ),
    ("Config II",
     "1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\""
     ),
    ("Config III",
     "Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\""
     ),
])
def test_flexible_character_accuracy(ocr):
    """Tests from figure 3 in the paper.
    TODO: We have a 2 percent deviation from the original because of redistributed
    one character alignments (e.g. the y-insert replaces the y-delete).
    """
    gt = """"I have
no special
talent.
I am only
passionately
curious."
Albert
Einstein
"""
    replacements = 3
    inserts = 5
    deletes = 7
def test_flexible_character_accuracy(config, ocr):
    """Tests from figure 3 in the paper."""
    gt = "\"I have\nno special\ntalent." \
         "\nI am only\npassionately\ncurious.\"" \
         "\nAlbert\nEinstein"
    replacements, inserts, deletes = 3, 5, 7
    chars = len(gt) - gt.count("\n")
    assert replacements + inserts + deletes == 15
    edits = Distance(match=chars - deletes - replacements, replace=replacements,
                     insert=inserts, delete=deletes)
    expected = character_accuracy(edits)
    assert expected == pytest.approx(0.779, abs=0.0005)
    assert chars == 68
    # We consider whitespace as errors, and in Config II two additional
    # whitespaces have been introduced. One will be counted as an insert.
    # The other whitespace will be counted as a replacement,
    # additionally reducing the number of deletes.
    if config == "Config II":
        inserts += 1
        replacements += 1
        deletes -= 1
    expected_dist = Distance(match=chars - deletes - replacements, replace=replacements,
                             insert=inserts, delete=deletes)
    expected_score = character_accuracy(expected_dist)
    result, matches = flexible_character_accuracy(gt, ocr)
    assert result == pytest.approx(expected, abs=0.0005)
    agg = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
                 matches, Counter())
    dist = Distance(**agg)
    assert dist == expected_dist
    assert result == pytest.approx(expected_score, abs=0.0005)
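The expected scores can be checked by hand, assuming character_accuracy() computes (gt characters - errors) / gt characters as in the paper:

# Config I/III: 68 gt characters and 3 + 5 + 7 = 15 errors
chars, replacements, inserts, deletes = 68, 3, 5, 7
errors = replacements + inserts + deletes                # 15
assert round((chars - errors) / chars, 3) == 0.779       # cf. approx(0.779)
# Config II (see the comment in the test): inserts += 1, replacements += 1, deletes -= 1
errors_ii = (replacements + 1) + (inserts + 1) + (deletes - 1)   # 16
assert round((chars - errors_ii) / chars, 4) == 0.7647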
@pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES)
@@ -191,9 +203,15 @@ def test_remove_or_split(original, match, expected_lines):
@pytest.mark.parametrize(EDIT_ARGS, [
    *SIMPLE_EDITS,
    (Part(text="a"), Part(text="b"), Distance(delete=1)),
    (Part(text="aaa"), Part(text="bbb"), Distance(delete=3)),
    (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)),
    (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)),
    (Part(text=""), Part(text=""), None)
    (Part(text=""), Part(text=""), None),
    (Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)),
    (Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)),
    (Part(text="aaabbbaaaddd"), Part(text="aaabcbaaa"), Distance(match=8, replace=1)),
    (Part(text="aaabbbccc"), Part(text="aaabbbdddccc"), Distance(match=9, insert=3)),
])
def test_match_lines(gt, ocr, expected_dist):
    match = match_lines(gt, ocr)
@@ -210,6 +228,8 @@ def test_match_lines(gt, ocr, expected_dist):
    (Part(text="").substring(), Part(text=""), Distance()),
    (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)),
    (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)),
    (Part(text="a"), Part(text="b"), Distance(replace=1)),
    (Part(text="aaa"), Part(text="bbb"), Distance(replace=3)),
])
def test_distance(gt, ocr, expected_dist):
    match = distance(gt, ocr)
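Note the contrast encoded in the new cases: for the same pair ("a", "b"), distance() expects Distance(replace=1) while match_lines() expects Distance(delete=1), i.e. with the new empty-ocr-line candidate a mismatching character is left unmatched instead of being forced into a replacement. A hedged check (the import path below is a placeholder assumption; the expectations themselves come from the parametrizations above):

# Placeholder import path, adjust to the package layout of the repository.
from qurator.dinglehopper.flexible_character_accuracy import (
    Part, Distance, distance, match_lines,
)

gt, ocr = Part(text="a"), Part(text="b")
assert distance(gt, ocr).dist == Distance(replace=1)     # forced alignment
assert match_lines(gt, ocr).dist == Distance(delete=1)   # character left unmatched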
