Fix some special cases

pull/47/head
Benjamin Rosemann 5 years ago
parent d7a74fa58b
commit 5277593bdb

@ -117,6 +117,7 @@ def match_longest_gt_lines(gt_lines: List["Part"],
if best_match and best_gt: if best_match and best_gt:
splitted = remove_or_split(best_gt, best_match.gt, gt_lines) splitted = remove_or_split(best_gt, best_match.gt, gt_lines)
if splitted: if splitted:
# according to the paper the match is not put back, we deviate...
gt_lines.append(best_match.gt) gt_lines.append(best_match.gt)
best_match = None best_match = None
if best_match and best_ocr: if best_match and best_ocr:
@ -134,13 +135,12 @@ def match_gt_line(gt_line: "Part",
Reference: contains steps 3 of the flexible character accuracy algorithm. Reference: contains steps 3 of the flexible character accuracy algorithm.
TODO: Make penalty function configurable? TODO: Make penalty function configurable?
TODO: Add empty ocr line to avoid having nonsense one-character alignments?
:return: Match object and the matched ocr line. :return: Match object and the matched ocr line.
""" """
min_penalty = float('inf') min_penalty = float('inf')
best_match, best_ocr = None, None best_match, best_ocr = None, None
for ocr_line in ocr_lines: for ocr_line in [*ocr_lines]:
match = match_lines(gt_line, ocr_line) match = match_lines(gt_line, ocr_line)
penalty = calculate_penalty(gt_line, ocr_line, match, coef) penalty = calculate_penalty(gt_line, ocr_line, match, coef)
if penalty < min_penalty: if penalty < min_penalty:
@ -177,20 +177,42 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
Reference: see figure 2 in the paper. Reference: see figure 2 in the paper.
TODO: make distance function configurable? TODO: make distance function configurable?
TODO: rethink @lru_cache
:return: Match object if one is found. :return: Match object if one is found.
""" """
min_length = min(gt_line.length, ocr_line.length) min_length = min(gt_line.length, ocr_line.length)
best_match = None best_match = None
best_i, best_j = 0, 0
if min_length == 0: if min_length == 0:
return best_match return best_match
length_diff = gt_line.length - ocr_line.length length_diff = gt_line.length - ocr_line.length
min_edit_dist = float('inf') min_edit_dist = float('inf')
# TODO: handle deletes and replacements by extending the length.
for i in range(0, max(1, length_diff + 1)): gt_parts = [(i, gt_line.substring(rel_start=i, rel_end=i + min_length))
for j in range(0, max(1, -1 * length_diff + 1)): for i in range(0, max(1, length_diff + 1))]
match = distance(gt_line.substring(rel_start=i, rel_end=i + min_length), ocr_parts = [(j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
ocr_line.substring(rel_start=j, rel_end=j + min_length)) for j in range(0, max(1, -1 * length_diff + 1))]
# add full line and empty line match
gt_parts = [*gt_parts, (0, gt_line), (0, gt_line)]
ocr_parts = [*ocr_parts, (0, ocr_line),
(0, Part(text="", line=gt_line.line, start=gt_line.start))]
for i, gt_part in gt_parts:
for j, ocr_part in ocr_parts:
match = distance(gt_part, ocr_part)
edit_dist = score_edit_distance(match)
if edit_dist < min_edit_dist:
min_edit_dist = edit_dist
best_match = match
best_i, best_j = i, j
if best_match and (best_match.dist.delete or best_match.dist.replace):
part_length = best_match.gt.length
additional_length = best_match.dist.delete + best_match.dist.replace
for k in range(part_length + 1, part_length + additional_length + 1):
match = distance(gt_line.substring(rel_start=best_i, rel_end=best_i + k),
ocr_line.substring(rel_start=best_j, rel_end=best_j + k))
edit_dist = score_edit_distance(match) edit_dist = score_edit_distance(match)
if edit_dist < min_edit_dist: if edit_dist < min_edit_dist:
min_edit_dist = edit_dist min_edit_dist = edit_dist
@ -205,6 +227,7 @@ def distance(gt: "Part", ocr: "Part") -> "Match":
Using the already available `editops()` function with the Levenshtein distance. Using the already available `editops()` function with the Levenshtein distance.
TODO: replace with @cache annotation in Python 3.9 TODO: replace with @cache annotation in Python 3.9
TODO: rethink @lru_cache
:return: Match object containing the lines and the editing operations. :return: Match object containing the lines and the editing operations.
""" """

@ -33,6 +33,7 @@ COMPLEX_CASES = [
] ]
EXTENDED_CASES = [ EXTENDED_CASES = [
# See figure 4 in 10.1016/j.patrec.2020.02.003
# A: No errors # A: No errors
((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
(0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
@ -70,14 +71,18 @@ EXTENDED_CASES = [
EDIT_ARGS = "gt,ocr,expected_dist" EDIT_ARGS = "gt,ocr,expected_dist"
SIMPLE_EDITS = [ SIMPLE_EDITS = [
(Part(text="a").substring(), Part(text="a"), Distance(match=1)), (Part(text="a"), Part(text="a"), Distance(match=1)),
(Part(text="a").substring(), Part(text="b"), Distance(replace=1)), (Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
(Part(text="abcd").substring(), Part(text="beed"), (Part(text="abcd"), Part(text="beed"),
Distance(match=2, replace=1, insert=1, delete=1)), Distance(match=2, replace=1, insert=1, delete=1)),
] ]
def extended_case_to_text(gt, ocr): def extended_case_to_text(gt, ocr):
"""Generate sentence from reading order encoding.
See figure 4 in 10.1016/j.patrec.2020.02.003
"""
sentence = ("Eight", "happy", "frogs", "scuba", "dived", sentence = ("Eight", "happy", "frogs", "scuba", "dived",
"Jenny", "chick", "flaps", "white", "wings", "Jenny", "chick", "flaps", "white", "wings",
"", "\n") "", "\n")
@ -93,38 +98,45 @@ def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_
assert score == pytest.approx(all_line_score) assert score == pytest.approx(all_line_score)
@pytest.mark.xfail @pytest.mark.parametrize("config,ocr", [
@pytest.mark.parametrize("ocr", [ ("Config I",
"1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein", "1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein"
"1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\"", ),
("Config II",
"1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\""
),
("Config III",
"Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"" "Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\""
),
]) ])
def test_flexible_character_accuracy(ocr): def test_flexible_character_accuracy(config, ocr):
"""Tests from figure 3 in the paper. """Tests from figure 3 in the paper."""
gt = "\"I have\nno special\ntalent." \
TODO: We have a 2 percent deviation from the original because of redistributed "\nI am only\npassionately\ncurious.\"" \
one character alignments (e.g. the y-insert replaces the y-delete). "\nAlbert\nEinstein"
""" replacements, inserts, deletes = 3, 5, 7
gt = """"I have
no special
talent.
I am only
passionately
curious."
Albert
Einstein
"""
replacements = 3
inserts = 5
deletes = 7
chars = len(gt) - gt.count("\n") chars = len(gt) - gt.count("\n")
assert replacements + inserts + deletes == 15 assert chars == 68
edits = Distance(match=chars - deletes - replacements, replace=replacements,
# Whitespace counts as an error, and Config II introduces two additional
# whitespace characters. One of them is counted as an insert.
# The other one is counted as a replacement,
# which in turn reduces the number of deletes by one.
if config == "Config II":
inserts += 1
replacements += 1
deletes -= 1
expected_dist = Distance(match=chars - deletes - replacements, replace=replacements,
insert=inserts, delete=deletes) insert=inserts, delete=deletes)
expected = character_accuracy(edits) expected_score = character_accuracy(expected_dist)
assert expected == pytest.approx(0.779, abs=0.0005)
result, matches = flexible_character_accuracy(gt, ocr) result, matches = flexible_character_accuracy(gt, ocr)
assert result == pytest.approx(expected, abs=0.0005) agg = reduce(lambda acc, match: acc + Counter(match.dist._asdict()),
matches, Counter())
dist = Distance(**agg)
assert dist == expected_dist
assert result == pytest.approx(expected_score, abs=0.0005)
@pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES) @pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES)
@ -191,9 +203,15 @@ def test_remove_or_split(original, match, expected_lines):
@pytest.mark.parametrize(EDIT_ARGS, [ @pytest.mark.parametrize(EDIT_ARGS, [
*SIMPLE_EDITS, *SIMPLE_EDITS,
(Part(text="a"), Part(text="b"), Distance(delete=1)),
(Part(text="aaa"), Part(text="bbb"), Distance(delete=3)),
(Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)), (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)),
(Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)), (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)),
(Part(text=""), Part(text=""), None) (Part(text=""), Part(text=""), None),
(Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)),
(Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)),
(Part(text="aaabbbaaaddd"), Part(text="aaabcbaaa"), Distance(match=8, replace=1)),
(Part(text="aaabbbccc"), Part(text="aaabbbdddccc"), Distance(match=9, insert=3)),
]) ])
def test_match_lines(gt, ocr, expected_dist): def test_match_lines(gt, ocr, expected_dist):
match = match_lines(gt, ocr) match = match_lines(gt, ocr)
@ -210,6 +228,8 @@ def test_match_lines(gt, ocr, expected_dist):
(Part(text="").substring(), Part(text=""), Distance()), (Part(text="").substring(), Part(text=""), Distance()),
(Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)), (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)),
(Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)), (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)),
(Part(text="a"), Part(text="b"), Distance(replace=1)),
(Part(text="aaa"), Part(text="bbb"), Distance(replace=3)),
]) ])
def test_distance(gt, ocr, expected_dist): def test_distance(gt, ocr, expected_dist):
match = distance(gt, ocr) match = distance(gt, ocr)

Loading…
Cancel
Save