Reduce number of splits for short (one char) elements

pull/47/head
Benjamin Rosemann 4 years ago
parent c9219cbacd
commit 0ef7810dd0

@ -145,9 +145,11 @@ def match_longest_gt_lines(
score = 0 if not match else character_accuracy(match.dist) score = 0 if not match else character_accuracy(match.dist)
if score > best_score: if score > best_score:
best_score, best_match, best_gt, best_ocr = score, match, gt_line, ocr_line best_score, best_match, best_gt, best_ocr = score, match, gt_line, ocr_line
# early breaking: we only need one perfect fit
if best_score >= 1:
break
# Step 4 of the flexible character accuracy algorithm. # Step 4 of the flexible character accuracy algorithm.
# Remove on full match or split.
if best_match: if best_match:
remove_or_split(best_gt, best_match.gt, gt_lines) remove_or_split(best_gt, best_match.gt, gt_lines)
remove_or_split(best_ocr, best_match.ocr, ocr_lines) remove_or_split(best_ocr, best_match.ocr, ocr_lines)
@ -168,7 +170,7 @@ def match_gt_line(
""" """
min_penalty = float("inf") min_penalty = float("inf")
best_match, best_ocr = None, None best_match, best_ocr = None, None
for ocr_line in [*ocr_lines]: for ocr_line in ocr_lines:
match = match_lines(gt_line, ocr_line) match = match_lines(gt_line, ocr_line)
if match: if match:
penalty = calculate_penalty(gt_line, ocr_line, match, coef) penalty = calculate_penalty(gt_line, ocr_line, match, coef)
@ -233,7 +235,7 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
for j, ocr_part in ocr_parts: for j, ocr_part in ocr_parts:
match = distance(gt_part, ocr_part) match = distance(gt_part, ocr_part)
edit_dist = score_edit_distance(match) edit_dist = score_edit_distance(match)
if edit_dist < min_edit_dist: if edit_dist < min_edit_dist and match.dist.replace < min_length:
min_edit_dist = edit_dist min_edit_dist = edit_dist
best_match = match best_match = match
best_i, best_j = i, j best_i, best_j = i, j
@ -247,7 +249,7 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
ocr_line.substring(rel_start=best_j, rel_end=best_j + k), ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
) )
edit_dist = score_edit_distance(match) edit_dist = score_edit_distance(match)
if edit_dist < min_edit_dist: if edit_dist < min_edit_dist and match.dist.replace < min_length:
min_edit_dist = edit_dist min_edit_dist = edit_dist
best_match = match best_match = match
# is delete a better option? # is delete a better option?

@ -26,6 +26,7 @@ SIMPLE_CASES = [
("bbb", "aaa\nbbb\nccc", 1, 1 - 6 / 3), ("bbb", "aaa\nbbb\nccc", 1, 1 - 6 / 3),
("a", "a\nbb\nccc", 1, 1 - 5 / 1), ("a", "a\nbb\nccc", 1, 1 - 5 / 1),
("bb", "a\nbb\nccc", 1, 1 - 4 / 2), ("bb", "a\nbb\nccc", 1, 1 - 4 / 2),
("abcd", "ab\ne", 1, 1 - 3 / 4),
] ]
COMPLEX_CASES = [ COMPLEX_CASES = [
@ -135,7 +136,6 @@ def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_sco
assert score == pytest.approx(all_line_score) assert score == pytest.approx(all_line_score)
@pytest.mark.xfail(reason="Need to adapt performance details.")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"config,ocr", "config,ocr",
[ [
@ -273,6 +273,8 @@ def test_remove_or_split(original, match, expected_lines):
[ [
*SIMPLE_EDITS, *SIMPLE_EDITS,
(Part(text="a"), Part(text="b"), Distance(delete=1)), (Part(text="a"), Part(text="b"), Distance(delete=1)),
(Part(text="ab"), Part(text="c"), Distance(delete=2)),
(Part(text="abc"), Part(text="d"), Distance(delete=3)),
(Part(text="aaa"), Part(text="bbb"), Distance(delete=3)), (Part(text="aaa"), Part(text="bbb"), Distance(delete=3)),
(Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)), (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)),
(Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)), (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)),

Loading…
Cancel
Save