First draft of flexible character accuracy

pull/47/head
Benjamin Rosemann 5 years ago
parent bd324331e6
commit d7a74fa58b

@ -1,4 +0,0 @@
[pytest]
markers =
integration: integration tests
serial

@ -0,0 +1,394 @@
"""
Implementation of the flexible character accuracy
Citation:
Flexible character accuracy measure for reading-order-independent evaluation
C. Clausner, S. Pletschacher, A. Antonacopoulos
Pattern Recognition Letters, Volume 131, March 2020, Pages 390-397
Link: http://www.primaresearch.org/publications/PRL_Clausner_FlexibleCharacterAccuracy
DOI: https://doi.org/10.1016/j.patrec.2020.02.003
Note that we deviated from the original algorithm at some places.
"""
from collections import Counter
from functools import lru_cache, reduce
from itertools import product, takewhile
from typing import List, NamedTuple, Tuple, Optional
from . import editops
def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"]]:
    """Calculate the flexible character accuracy.

    Reference: contains steps 1-7 of the flexible character accuracy algorithm.

    :param gt: The ground truth text.
    :param ocr: The text to compare the ground truth with.
    :return: Score between 0 and 1 and match objects.
    """
    best_score, best_matches = -float('inf'), []
    # Grid of coefficient values to try, taken from the paper.
    # TODO: this should be configurable
    param_grid = product(range(15, 31, 5),
                         range(0, 24, 3),
                         range(0, 4, 1),
                         range(0, 6, 1))
    # TODO: place to parallelize the algorithm
    for edit_dist, length_diff, offset, length in param_grid:
        coef = Coefficients(edit_dist=edit_dist,
                            length_diff=length_diff,
                            offset=offset,
                            length=length)
        # Steps 1 - 6 of the flexible character accuracy algorithm.
        matches = match_with_coefficients(gt, ocr, coef)
        # Step 7 of the flexible character accuracy algorithm.
        score = character_accuracy_for_matches(matches)
        if score > best_score:
            best_score, best_matches = score, matches
        # early breaking: we only need one perfect fit
        if best_score >= 1:
            break
    return best_score, best_matches
def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Match"]:
    """Match ground truth with ocr and considers a given set of coefficients.

    Reference: contains steps 1 - 6 of the flexible character accuracy algorithm.

    :return: A list of match objects to score and align the texts.
    """
    # Steps 1 and 2: split both texts into length-sorted line objects.
    ocr_lines = initialize_lines(ocr)
    gt_lines = initialize_lines(gt)
    # Step 5: keep matching until one side has no lines left.
    matches = []
    while gt_lines and ocr_lines:
        # Steps 3 and 4 of the flexible character accuracy algorithm.
        match = match_longest_gt_lines(gt_lines, ocr_lines, coef)
        if match:
            matches.append(match)
    # Step 6: remaining lines are considered as deletes and inserts.
    for line in gt_lines:
        matches.append(
            distance(line, Part(text="", line=line.line, start=line.start)))
    for line in ocr_lines:
        matches.append(
            distance(Part(text="", line=line.line, start=line.start), line))
    return matches
def match_longest_gt_lines(gt_lines: List["Part"],
                           ocr_lines: List["Part"],
                           coef: "Coefficients") -> Optional["Match"]:
    """Find the best match for the longest line(s) in ground truth.

    The longest lines in ground truth are matched against lines in ocr to find the
    best matching pair. This pair is then either considered a match on a full line
    or the line is split and the unmatched remainder is queued again.

    NOTE: mutates both `gt_lines` and `ocr_lines` in place (matched lines are
    removed or replaced by their remaining parts via `remove_or_split`).

    Reference: contains steps 3 and 4 of the flexible character accuracy algorithm.
    :return: Possible match object.
    """
    best_score, best_match, best_gt, best_ocr = -float('inf'), None, None, None
    if not ocr_lines:
        return best_match
    # Step 3 of the flexible character accuracy algorithm (variation).
    # Instead of the longest line we take all longest lines with equal length.
    length = min(gt_lines[0].length, ocr_lines[0].length)
    for gt_line in takewhile(lambda line: line.length >= length, gt_lines):
        match, ocr_line = match_gt_line(gt_line, ocr_lines, coef)
        # An unmatched line scores 0 so it can still win over -inf.
        score = 0 if not match else character_accuracy(match.dist)
        if score > best_score:
            best_score, best_match, best_gt, best_ocr = score, match, gt_line, ocr_line
    # Step 4 of the flexible character accuracy algorithm.
    # Remove on full match or split.
    if best_match and best_gt:
        splitted = remove_or_split(best_gt, best_match.gt, gt_lines)
        if splitted:
            # Only part of the gt line matched: put the matched piece back so it
            # competes again in a later round and report no match for this round.
            gt_lines.append(best_match.gt)
            best_match = None
    if best_match and best_ocr:
        remove_or_split(best_ocr, best_match.ocr, ocr_lines)
    return best_match
def match_gt_line(gt_line: "Part",
                  ocr_lines: List["Part"],
                  coef: "Coefficients") -> Tuple[Optional["Match"],
                                                 Optional["Part"]]:
    """Match the given ground truth line against all the lines in ocr.

    The ocr line with the lowest penalty wins; ties keep the earlier line.

    Reference: contains steps 3 of the flexible character accuracy algorithm.

    TODO: Make penalty function configurable?
    TODO: Add empty ocr line to avoid having nonsense one character alignments?

    :return: Match object and the matched ocr line.
    """
    best_match, best_ocr = None, None
    min_penalty = float('inf')
    for candidate_line in ocr_lines:
        candidate_match = match_lines(gt_line, candidate_line)
        candidate_penalty = calculate_penalty(gt_line, candidate_line,
                                              candidate_match, coef)
        if candidate_penalty < min_penalty:
            min_penalty = candidate_penalty
            best_match, best_ocr = candidate_match, candidate_line
    return best_match, best_ocr
def remove_or_split(original: "Part",
                    match: "Part",
                    lines: List["Part"]) -> bool:
    """Removes the matched line or splits it into parts.

    Reference: contains step 4 of the flexible character accuracy algorithm.

    :return: True if line was splitted.
    """
    lines.remove(original)
    if match.length >= original.length:
        # The match covers the whole line; nothing has to be put back.
        return False
    # Partial match: keep the unmatched pieces around the match.
    lines.extend(original.split(match))
    # sorting for ocr is not mentioned in the paper, but is used as tie breaking =)
    lines.sort(key=lambda x: x.length, reverse=True)
    return True
@lru_cache(maxsize=1000000)
def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
    """Matches two lines searching for a local alignment.

    The shorter line is moved along the longer line
    until the editing distance is minimized.

    Reference: see figure 2 in the paper.

    TODO: make distance function configurable?

    :return: Match object if one is found.
    """
    window = min(gt_line.length, ocr_line.length)
    if window == 0:
        # At least one line is empty; there is nothing to align.
        return None
    length_diff = gt_line.length - ocr_line.length
    best_match, min_edit_dist = None, float('inf')
    # Slide a window of the shorter line over the longer one. Exactly one of
    # the two ranges below has more than a single position.
    # TODO: handle deletes and replacements by extending the length.
    for gt_offset in range(max(1, length_diff + 1)):
        for ocr_offset in range(max(1, -length_diff + 1)):
            candidate = distance(
                gt_line.substring(rel_start=gt_offset, rel_end=gt_offset + window),
                ocr_line.substring(rel_start=ocr_offset, rel_end=ocr_offset + window))
            edit_dist = score_edit_distance(candidate)
            if edit_dist < min_edit_dist:
                min_edit_dist, best_match = edit_dist, candidate
    return best_match
@lru_cache(maxsize=1000000)
def distance(gt: "Part", ocr: "Part") -> "Match":
    """Calculate the editing distance between the two lines.

    Using the already available `editops()` function with the Levenshtein distance.

    TODO: replace with @cache annotation in Python 3.9

    :return: Match object containing the lines and the editing operations.
    """
    ops = editops(gt.text, ocr.text)
    # Tally the edit operation kinds, then derive how many gt characters
    # survived untouched (everything not deleted or replaced).
    edits = Counter(op[0] for op in ops)
    edits["match"] = gt.length - edits["delete"] - edits["replace"]
    return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops)
def score_edit_distance(match: "Match") -> int:
    """Calculate edit distance for a match.

    Formula: $deletes + inserts + 2 * replacements$

    :return: Sum of deletes, inserts and replacements.
    """
    dist = match.dist
    # A replacement counts double: it is a delete plus an insert.
    return dist.delete + dist.insert + 2 * dist.replace
def calculate_penalty(gt: "Part", ocr: "Part", match: "Match",
                      coef: "Coefficients") -> float:
    """Calculate the penalty for a given match.

    For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003.

    :return: Penalty for the given match.
    """
    edit_term = score_edit_distance(match) * coef.edit_dist
    length_diff = abs(gt.length - ocr.length)
    offset = 0.0
    if length_diff > 1:
        # How far the matched substring sits from the center of the longer line.
        substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start)
        offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
    return (edit_term
            + length_diff * coef.length_diff
            + offset * coef.offset
            - min(gt.length, ocr.length) * coef.length)
def character_accuracy_for_matches(matches: List["Match"]) -> float:
    """Character accuracy of a full text represented by a list of matches.

    See other `character_accuracy` for details.
    """
    # Sum the edit operation counts over all matches before scoring.
    # (Counter addition drops non-positive entries, which is fine because
    # Distance defaults every missing field to 0.)
    total: Counter = Counter()
    for match in matches:
        total += Counter(match.dist._asdict())
    return character_accuracy(Distance(**total))
def character_accuracy(edits: "Distance") -> float:
    """Character accuracy calculated by necessary edit operations.

    Edit operations are needed edits to transform one text into another.
    The character accuracy is given by $1 - errors / characters$.

    Errors are replacements, deletes and inserts.
    Characters are counted from the ground truth side (matches, replacements
    and deletes).

    Note that it is possible to have more errors than characters in which case the
    character accuracy turns negative.

    Comparing two empty strings (having no edits) results in a character accuracy of 1.
    """
    errors = edits.replace + edits.delete + edits.insert
    chars = edits.match + edits.replace + edits.delete
    if not chars and not errors:
        # comparison of empty strings is considered a full match
        score = 1
    elif not chars:
        # Fix: an empty ground truth with only inserts previously raised a
        # ZeroDivisionError; count every inserted character as a full error.
        score = -errors
    else:
        score = 1 - errors / chars
    return score
def initialize_lines(text: str) -> List["Part"]:
    """Splits a text into lines and converts them to our line data object.

    Empty lines are skipped. The line objects are sorted by their length
    descending (the sort is stable, so equal lengths keep text order).

    Reference: contains steps 1 and 2 of the flexible character accuracy algorithm.

    :param text: Text to split into lines.
    :return: List of sorted line objects.
    """
    parts = [Part(text=content, line=number, start=0)
             for number, content in enumerate(text.splitlines())
             if content]
    return sorted(parts, key=lambda part: part.length, reverse=True)
def combine_lines(matches: List["Match"]) -> Tuple[str, str]:
    """Combines the matches to aligned texts.

    TODO: needs tests and refinement. Also missing insert/delete marking.

    :param matches: List of match objects.
    :return: the aligned ground truth and ocr as texts.
    """
    # Order matches by line, then by position within the line.
    # NOTE(review): the /10000 trick assumes starts stay below 10000 — confirm.
    matches.sort(key=lambda x: x.gt.line + x.gt.start / 10000)
    line = 0
    gt, ocr = "", ""
    for match in matches:
        # Fix: emit one newline per advanced line. The previous version added
        # at most a single newline per match, silently merging lines whenever
        # the ground truth line number jumped by more than one.
        while match.gt.line > line:
            gt += "\n"
            ocr += "\n"
            line += 1
        gt += match.gt.text
        ocr += match.ocr.text
    return gt, ocr
class Part(NamedTuple):
    """Represent a line or part of a line.

    This data object is maintained to be able to reproduce the original text.
    """
    # Text content of this (partial) line.
    text: str = ""
    # Index of the line in the original text.
    line: int = 0
    # Offset of this part within the original line.
    start: int = 0

    @property
    def end(self) -> int:
        """Exclusive end offset of this part within its line."""
        return self.start + self.length

    @property
    def length(self) -> int:
        """Number of characters in this part."""
        return len(self.text)

    def split(self, split: "Part") -> List["Part"]:
        """Split the line part by another and returns the remaining parts.

        `abc.split("b")` will return ´["a", "c"]`.

        :param split: The line part we want to use to split.
        :return: The parts before and after the split.
        """
        remainder = []
        if split.start > self.start:
            # Keep everything in front of the splitting part.
            remainder.append(self.substring(rel_end=split.start - self.start))
        if self.end > split.end:
            # Keep everything behind the splitting part.
            remainder.append(self.substring(rel_start=split.end - self.start))
        return remainder

    def substring(self, rel_start: int = 0, rel_end: Optional[int] = None) -> "Part":
        """Get part of the given line.

        Automatically handles the offset of the line.
        Therefore `substring(rel_start=2)` will return `Part[start+rel_start:]`.

        :param rel_start: start relative to the part of the line.
        :param rel_end: end relative to the part of the line.
        :return: Extracted part of the given part of the line.
        """
        return Part(text=self.text[rel_start:rel_end],
                    line=self.line,
                    start=self.start + rel_start)
class Distance(NamedTuple):
    """Represent distance between two sequences."""
    # Characters identical in both sequences.
    match: int = 0
    # Characters replaced by a different character.
    replace: int = 0
    # Characters present only in the first sequence.
    delete: int = 0
    # Characters present only in the second sequence.
    insert: int = 0
class Match(NamedTuple):
    """Represent a calculated match between ground truth and the ocr result."""
    # Matched part of the ground truth text.
    gt: "Part"
    # Matched part of the ocr text.
    ocr: "Part"
    # Edit operation counts between the two parts.
    dist: "Distance"
    # Raw edit operations as returned by `editops()`.
    ops: List
class Coefficients(NamedTuple):
    """Coefficients to calculate penalty for substrings.

    See Section 3 in doi:10.1016/j.patrec.2020.02.003
    """
    # Weight for the edit distance of a match.
    edit_dist: int = 25
    # Weight for the length difference of the two lines.
    length_diff: int = 20
    # Weight for the offset of the matched substring within the longer line.
    offset: int = 1
    # Weight (reward) for the length of the matched substring.
    length: int = 4

@ -0,0 +1,291 @@
"""
Tests for the implementation of the flexible character accuracy
Citation:
Flexible character accuracy measure for reading-order-independent evaluation
C. Clausner, S. Pletschacher, A. Antonacopoulos
Pattern Recognition Letters, Volume 131, March 2020, Pages 390-397
Link: http://www.primaresearch.org/publications/PRL_Clausner_FlexibleCharacterAccuracy
DOI: 10.1016/j.patrec.2020.02.003
"""
import pytest
from ..flexible_character_accuracy import *
CASE_ARGS = "gt,ocr,first_line_score,all_line_score"
# Each case: (ground truth, ocr result, expected score for the best
# first-line match, expected score over the whole text).
SIMPLE_CASES = [
    ("a", "", 0, 0),
    ("a", "a", 1, 1),
    ("a\nb", "a\nb", 1, 1),
    ("a\nb", "b\na", 1, 1),
    ("aaa\nbbb\nccc", "ccc\naaa\nbbb", 1, 1),
    ("aaa\nbbb\nccc", "aaa\nbbb", 1, 1 - 3 / 9),
    ("bbb", "aaa\nbbb\nccc", 1, 1 - 6 / 3),
    ("a", "a\nbb\nccc", 1, 1 - 5 / 1),
    ("bb", "a\nbb\nccc", 1, 1 - 4 / 2),
]
COMPLEX_CASES = [
    ("accc", "a\nbb\nccc", 0, 1 - 2 / 4),
    ("aaa\nbbb\nccc", "bbb", 1, 1 - 6 / 9),
]
# Cases from figure 4 of the paper. The texts are encoded as indices into the
# sentence tuple used by `extended_case_to_text` (11 encodes a line break).
EXTENDED_CASES = [
    # A: No errors
    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
     (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
     1, 1),
    # B: Different ordering of text blocks
    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
     (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4),
     1, 1),
    # C: Merge across columns
    ((0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9),
     (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9),
     1, 0.964),
    # D: Over-segmentation
    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
     (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9),
     1, 0.966),
    # E: Part missing
    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
     (0, 1, 2, 3, 4),
     1, 0.50),
    # E.2: Part missing
    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
     (5, 6, 7, 8, 9),
     1, 0.50),
    # F: All missing
    ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9),
     (),
     1, 0),
    # G: Added parts
    ((0, 1, 2, 3, 4),
     (0, 1, 2, 3, 4, 11, 5, 6),
     1, 0.621),
]
EDIT_ARGS = "gt,ocr,expected_dist"
# Each case: (gt line part, ocr line part, expected Distance between them).
SIMPLE_EDITS = [
    (Part(text="a").substring(), Part(text="a"), Distance(match=1)),
    (Part(text="a").substring(), Part(text="b"), Distance(replace=1)),
    (Part(text="abcd").substring(), Part(text="beed"),
     Distance(match=2, replace=1, insert=1, delete=1)),
]
def extended_case_to_text(gt, ocr):
    """Translate the index-encoded extended cases into gt and ocr texts.

    Index 10 encodes an empty word, index 11 a line break.

    :return: Tuple of ground truth text and ocr text.
    """
    words = ("Eight", "happy", "frogs", "scuba", "dived",
             "Jenny", "chick", "flaps", "white", "wings",
             "", "\n")

    def to_text(indices):
        # Join the words and strip the spaces surrounding line breaks.
        return " ".join(words[i] for i in indices).replace(" \n ", "\n")

    return to_text(gt), to_text(ocr)
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
def test_flexible_character_accuracy_simple(gt, ocr, first_line_score, all_line_score):
    # End-to-end check of the full algorithm on small hand-computed cases.
    score, _ = flexible_character_accuracy(gt, ocr)
    assert score == pytest.approx(all_line_score)
@pytest.mark.xfail
@pytest.mark.parametrize("ocr", [
    "1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\"\nAlberto\nEmstein",
    "1 hav\nnospecial\ntalents. Alberto\nI am one Emstein\npassionate\ncuriousity.\"",
    "Alberto\nEmstein\n1 hav\nnospecial\ntalents.\nI am one\npassionate\ncuriousity.\""
])
def test_flexible_character_accuracy(ocr):
    """Tests from figure 3 in the paper.

    TODO: We have a 2 percent deviation from the original because of redistributed
    one character alignments (e.g. the y-insert replaces the y-delete).
    """
    gt = """"I have
no special
talent.
I am only
passionately
curious."
Albert
Einstein
"""
    # Edit counts as given in the paper for this example.
    replacements = 3
    inserts = 5
    deletes = 7
    # Characters are all non-newline characters of the ground truth.
    chars = len(gt) - gt.count("\n")
    assert replacements + inserts + deletes == 15
    edits = Distance(match=chars - deletes - replacements, replace=replacements,
                     insert=inserts, delete=deletes)
    # Sanity check the expected score against the paper's value first.
    expected = character_accuracy(edits)
    assert expected == pytest.approx(0.779, abs=0.0005)
    result, matches = flexible_character_accuracy(gt, ocr)
    assert result == pytest.approx(expected, abs=0.0005)
@pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES)
def test_flexible_character_accuracy_extended(gt, ocr, first_line_score,
                                              all_line_score):
    """Tests from figure 4 in the paper."""
    # Decode the index tuples into real texts before scoring.
    gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr)
    result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence)
    assert result == pytest.approx(all_line_score, abs=0.001)
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES, *EXTENDED_CASES])
def test_match_with_coefficients(gt, ocr, first_line_score, all_line_score):
    # Run steps 1-6 with the default coefficients and score the resulting matches.
    coef = Coefficients()
    # Extended cases are index tuples and must be converted to text first.
    if not isinstance(gt, str):
        gt, ocr = extended_case_to_text(gt, ocr)
    matches = match_with_coefficients(gt, ocr, coef)
    score = character_accuracy_for_matches(matches)
    assert score == pytest.approx(all_line_score, abs=0.001)
@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES])
def test_match_longest_gt_lines(gt, ocr, first_line_score, all_line_score):
    # Only the best match for the longest ground truth line is checked here.
    coef = Coefficients()
    gt_lines = initialize_lines(gt)
    ocr_lines = initialize_lines(ocr)
    match = match_longest_gt_lines(gt_lines, ocr_lines, coef)
    # No match scores 0 (e.g. when the ocr text is empty).
    score = 0
    if match:
        score = character_accuracy(match.dist)
    assert score == pytest.approx(first_line_score)
@pytest.mark.parametrize(CASE_ARGS, [
    *SIMPLE_CASES,
    # The "accc" case scores 1.0 here because the single line is matched
    # against the best ocr line in isolation.
    ("accc", "a\nbb\nccc", 1.0, 1.0),
])
def test_match_gt_line(gt, ocr, first_line_score, all_line_score):
    # Match the longest ground truth line against all ocr lines.
    coef = Coefficients()
    gt_lines = initialize_lines(gt)
    ocr_lines = initialize_lines(ocr)
    match, _ = match_gt_line(gt_lines[0], ocr_lines, coef)
    # No match scores 0 (e.g. when the ocr text is empty).
    score = 0
    if match:
        score = character_accuracy(match.dist)
    assert score == pytest.approx(first_line_score)
@pytest.mark.parametrize("original,match,expected_lines", [
(Part(), Part(), []),
(Part(text="abc"), Part(), [Part(text="abc")]),
(Part(text="abc"), Part("d"), [Part(text="bc", start=1)]),
(Part(text="abc"), Part("a", start=100), [Part(text="abc")]),
(Part(text="abc"), Part("a"), [Part(text="bc", start=1)]),
(Part(text="abc"), Part("b", start=1), [Part(text="a"), Part(text="c", start=2)]),
(Part(text="abc"), Part("c", start=2), [Part(text="ab")]),
])
def test_remove_or_split(original, match, expected_lines):
lines = [original]
splitted = remove_or_split(original, match, lines)
assert splitted == (len(lines) > 0)
assert lines == expected_lines
@pytest.mark.parametrize(EDIT_ARGS, [
    *SIMPLE_EDITS,
    # Local alignment: the shorter line should align with its best window.
    (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)),
    (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)),
    # Empty lines cannot be aligned at all.
    (Part(text=""), Part(text=""), None)
])
def test_match_lines(gt, ocr, expected_dist):
    match = match_lines(gt, ocr)
    if not expected_dist:
        assert match is None
    else:
        # The matched windows must be substrings of the given lines.
        assert match.gt.text in gt.text
        assert match.ocr.text in ocr.text
        assert match.dist == expected_dist
@pytest.mark.parametrize(EDIT_ARGS, [
    *SIMPLE_EDITS,
    (Part(text="").substring(), Part(text=""), Distance()),
    (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)),
    (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)),
])
def test_distance(gt, ocr, expected_dist):
    # distance() must keep the original parts and report the edit counts.
    match = distance(gt, ocr)
    assert match.gt == gt
    assert match.ocr == ocr
    assert match.dist == expected_dist
@pytest.mark.parametrize("matches,expected_dist", [
([], 1),
([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1),
([Match(gt=Part(text="abee"), ocr=Part("ac"),
dist=Distance(match=1, replace=1, delete=2), ops=[]),
Match(gt=Part(text="cd"), ocr=Part("ceff"),
dist=Distance(match=1, replace=1, insert=2), ops=[])],
1 - 6 / 6),
])
def test_character_accuracy_matches(matches, expected_dist):
assert character_accuracy_for_matches(matches) == pytest.approx(expected_dist)
@pytest.mark.parametrize("dist,expected_dist", [
(Distance(), 1),
(Distance(match=1), 1),
(Distance(replace=1), 0),
(Distance(match=1, insert=1), 0),
(Distance(match=1, insert=2), 1 - 2 / 1),
(Distance(match=2, insert=1), 0.5),
(Distance(match=1, delete=1), 0.5),
])
def test_character_accuracy_dist(dist, expected_dist):
assert character_accuracy(dist) == pytest.approx(expected_dist)
@pytest.mark.parametrize("line,subline,expected_rest", [
(Part(), Part(), []),
(Part("aaa bbb"), Part("aaa bbb"), []),
(Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]),
(Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]),
(Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]),
(Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]),
(Part("aaa bbb ccc"), Part("bbb", start=4), [Part("aaa "), Part(" ccc", start=7)]),
(Part("aaa bbb ccc", start=3), Part("bbb", start=7),
[Part("aaa ", start=3), Part(" ccc", start=10)]),
(Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]),
(Part("aaa bbb", start=3), Part(" ", start=6),
[Part("aaa", start=3), Part("bbb", start=7)]),
])
def test_split_line(line, subline, expected_rest):
rest = line.split(subline)
assert len(rest) == len(expected_rest)
assert set(rest) == set(expected_rest)
def test_initialize_lines():
    # An empty text yields no line objects.
    assert initialize_lines("") == []
    # Lines are wrapped in Part objects and sorted by length, longest first.
    expected = [
        Part("333", line=2, start=0),
        Part(text="22", line=0, start=0),
        Part("1", line=1, start=0),
    ]
    assert initialize_lines("22\n1\n333") == expected
@pytest.mark.xfail
def test_combine_lines():
    # Placeholder: combine_lines() is not finished yet (see its TODO), so this
    # test is expected to fail until real assertions are written.
    assert False
@pytest.mark.parametrize("line,start,end,expected", [
(Part(text=""), 0, None, Part(text="")),
(Part(text="a"), 0, None, Part(text="a")),
(Part(text="ab"), 0, 1, Part(text="a")),
(Part(text="abc"), 0, -1, Part(text="ab")),
(Part(text="ab"), 1, None, Part(text="b", start=1)),
])
def test_line_substring(line, start, end, expected):
assert line.substring(rel_start=start, rel_end=end) == expected
Loading…
Cancel
Save