Implement version specific data structures

As OCR-D continues to support Python 3.5 until the end of this year,
version-specific data structures have been implemented.

Once support for Python 3.5 is dropped, the extra file can easily be
removed.
pull/47/head
Benjamin Rosemann 5 years ago
parent 2a215a1062
commit 4a87adc2c7

@ -11,15 +11,31 @@ DOI: https://doi.org/10.1016/j.patrec.2020.02.003
Note that we deviated from the original algorithm at some places. Note that we deviated from the original algorithm at some places.
""" """
import sys
from collections import Counter from collections import Counter
from functools import lru_cache, reduce from functools import lru_cache, reduce
from itertools import product, takewhile from itertools import product, takewhile
from typing import List, NamedTuple, Tuple, Optional from typing import List, Tuple, Optional
from . import editops from . import editops
if sys.version_info.minor == 5:
from .flexible_character_accuracy_ds_35 import (
PartVersionSpecific,
Match,
Distance,
Coefficients,
)
else:
from .flexible_character_accuracy_ds import (
PartVersionSpecific,
Match,
Distance,
Coefficients,
)
def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"]]: def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List[Match]]:
"""Calculate the flexible character accuracy. """Calculate the flexible character accuracy.
Reference: contains steps 1-7 of the flexible character accuracy algorithm. Reference: contains steps 1-7 of the flexible character accuracy algorithm.
@ -53,7 +69,7 @@ def flexible_character_accuracy(gt: str, ocr: str) -> Tuple[float, List["Match"]
return best_score, best_matches return best_score, best_matches
def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Match"]: def match_with_coefficients(gt: str, ocr: str, coef: Coefficients) -> List[Match]:
"""Match ground truth with ocr and considers a given set of coefficients. """Match ground truth with ocr and considers a given set of coefficients.
Reference: contains steps 1 - 6 of the flexible character accuracy algorithm. Reference: contains steps 1 - 6 of the flexible character accuracy algorithm.
@ -88,8 +104,8 @@ def match_with_coefficients(gt: str, ocr: str, coef: "Coefficients") -> List["Ma
def match_longest_gt_lines( def match_longest_gt_lines(
gt_lines: List["Part"], ocr_lines: List["Part"], coef: "Coefficients" gt_lines: List["Part"], ocr_lines: List["Part"], coef: Coefficients
) -> Optional["Match"]: ) -> Optional[Match]:
"""Find the best match for the longest line(s) in ground truth. """Find the best match for the longest line(s) in ground truth.
The longest lines in ground truth are matched against lines in ocr to find the The longest lines in ground truth are matched against lines in ocr to find the
@ -127,8 +143,8 @@ def match_longest_gt_lines(
def match_gt_line( def match_gt_line(
gt_line: "Part", ocr_lines: List["Part"], coef: "Coefficients" gt_line: "Part", ocr_lines: List["Part"], coef: Coefficients
) -> Tuple[Optional["Match"], Optional["Part"]]: ) -> Tuple[Optional[Match], Optional["Part"]]:
"""Match the given ground truth line against all the lines in ocr. """Match the given ground truth line against all the lines in ocr.
Reference: contains steps 3 of the flexible character accuracy algorithm. Reference: contains steps 3 of the flexible character accuracy algorithm.
@ -166,7 +182,7 @@ def remove_or_split(original: "Part", match: "Part", lines: List["Part"]) -> boo
@lru_cache(maxsize=1000000) @lru_cache(maxsize=1000000)
def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]: def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
"""Matches two lines searching for a local alignment. """Matches two lines searching for a local alignment.
The shorter line is moved along the longer line The shorter line is moved along the longer line
@ -228,7 +244,7 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional["Match"]:
@lru_cache(maxsize=1000000) @lru_cache(maxsize=1000000)
def distance(gt: "Part", ocr: "Part") -> "Match": def distance(gt: "Part", ocr: "Part") -> Match:
"""Calculate the editing distance between the two lines. """Calculate the editing distance between the two lines.
Using the already available `editops()` function with the Levenshtein distance. Using the already available `editops()` function with the Levenshtein distance.
@ -244,7 +260,7 @@ def distance(gt: "Part", ocr: "Part") -> "Match":
return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops) return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops)
def score_edit_distance(match: "Match") -> int: def score_edit_distance(match: Match) -> int:
"""Calculate edit distance for a match. """Calculate edit distance for a match.
Formula: $deletes + inserts + 2 * replacements$ Formula: $deletes + inserts + 2 * replacements$
@ -254,9 +270,7 @@ def score_edit_distance(match: "Match") -> int:
return match.dist.delete + match.dist.insert + 2 * match.dist.replace return match.dist.delete + match.dist.insert + 2 * match.dist.replace
def calculate_penalty( def calculate_penalty(gt: "Part", ocr: "Part", match: Match, coef: Coefficients) -> float:
gt: "Part", ocr: "Part", match: "Match", coef: "Coefficients"
) -> float:
"""Calculate the penalty for a given match. """Calculate the penalty for a given match.
For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003. For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003.
@ -278,21 +292,21 @@ def calculate_penalty(
) )
def character_accuracy_for_matches(matches: List["Match"]) -> float: def character_accuracy_for_matches(matches: List[Match]) -> float:
"""Character accuracy of a full text represented by a list of matches. """Character accuracy of a full text represented by a list of matches.
See other `character_accuracy` for details. See other `character_accuracy` for details.
""" """
agg: Counter = reduce( agg = reduce(
lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter() lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
) ) # type: Counter
score = character_accuracy(Distance(**agg)) score = character_accuracy(Distance(**agg))
return score return score
def character_accuracy(edits: "Distance") -> float: def character_accuracy(edits: Distance) -> float:
"""Character accuracy calculated by necessary edit operations. """Character accuracy calculated by necessary edit operations.
Edit operations are needed edits to transform one text into another. Edit operations are needed edits to transform one text into another.
@ -335,7 +349,7 @@ def initialize_lines(text: str) -> List["Part"]:
return lines return lines
def combine_lines(matches: List["Match"]) -> Tuple[str, str]: def combine_lines(matches: List[Match]) -> Tuple[str, str]:
"""Combines the matches to aligned texts. """Combines the matches to aligned texts.
TODO: just hacked, needs tests and refinement. Also missing insert/delete marking. TODO: just hacked, needs tests and refinement. Also missing insert/delete marking.
@ -356,16 +370,7 @@ def combine_lines(matches: List["Match"]) -> Tuple[str, str]:
return gt, ocr return gt, ocr
class Part(NamedTuple): class Part(PartVersionSpecific):
"""Represent a line or part of a line.
This data object is maintained to be able to reproduce the original text.
"""
text: str = ""
line: int = 0
start: int = 0
@property @property
def end(self) -> int: def end(self) -> int:
return self.start + self.length return self.start + self.length
@ -402,33 +407,3 @@ class Part(NamedTuple):
text = self.text[rel_start:rel_end] text = self.text[rel_start:rel_end]
start = self.start + rel_start start = self.start + rel_start
return Part(text=text, line=self.line, start=start) return Part(text=text, line=self.line, start=start)
class Distance(NamedTuple):
    """Counts of edit operations separating two sequences.

    Every counter defaults to zero, so ``Distance()`` is the empty distance.
    """

    # Number of matching positions.
    match: int = 0
    # Number of replacements.
    replace: int = 0
    # Number of deletions.
    delete: int = 0
    # Number of insertions.
    insert: int = 0
class Match(NamedTuple):
    """Result of matching a ground truth part against an ocr part.

    All four fields are required; none has a default.
    """

    # Ground truth side of the match.
    gt: "Part"
    # OCR side of the match.
    ocr: "Part"
    # Edit operation counts between ``gt`` and ``ocr``.
    dist: "Distance"
    # The edit operations themselves.
    ops: List
class Coefficients(NamedTuple):
    """Weights used when calculating the penalty for substrings.

    Background: Section 3 in doi:10.1016/j.patrec.2020.02.003
    """

    # Weight of the edit distance.
    edit_dist: int = 25
    # Weight of the length difference.
    length_diff: int = 20
    # Weight of the offset.
    offset: int = 1
    # Weight of the length.
    length: int = 4

@ -0,0 +1,48 @@
"""
Data structures to be used with the Flexible Character Accuracy Algorithm
Separated because of version compatibility issues with Python 3.5.
"""
from typing import List, NamedTuple
class PartVersionSpecific(NamedTuple):
    """Immutable record for a line of text or a slice of a line.

    Line number and start offset are stored next to the text so that the
    original text can be reproduced later on.
    """

    # Characters covered by this part.
    text: str = ""
    # Number of the line the part belongs to.
    line: int = 0
    # Start offset of the part.
    start: int = 0
class Distance(NamedTuple):
    """Counts of edit operations separating two sequences.

    Every counter defaults to zero, so ``Distance()`` is the empty distance.
    """

    # Number of matching positions.
    match: int = 0
    # Number of replacements.
    replace: int = 0
    # Number of deletions.
    delete: int = 0
    # Number of insertions.
    insert: int = 0
class Match(NamedTuple):
    """Result of matching a ground truth part against an ocr part.

    All four fields are required; none has a default.
    """

    # Ground truth side of the match.
    gt: "Part"
    # OCR side of the match.
    ocr: "Part"
    # Edit operation counts between ``gt`` and ``ocr``.
    dist: "Distance"
    # The edit operations themselves.
    ops: List
class Coefficients(NamedTuple):
    """Weights used when calculating the penalty for substrings.

    Background: Section 3 in doi:10.1016/j.patrec.2020.02.003
    """

    # Weight of the edit distance.
    edit_dist: int = 25
    # Weight of the length difference.
    length_diff: int = 20
    # Weight of the offset.
    offset: int = 1
    # Weight of the length.
    length: int = 4

@ -0,0 +1,76 @@
"""
Data structures to be used with the Flexible Character Accuracy Algorithm
Separated because of version compatibility issues with Python 3.5.
"""
from collections import namedtuple
from typing import Dict
class PartVersionSpecific:
    """Python 3.5 compatible record for a line of text or a slice of a line.

    Stores the text together with its line number and start offset. Two
    instances are equal, and hash equally, when all three attributes match,
    which allows them to be used as dictionary or cache keys.

    Note: instances are not frozen. Changing an attribute after the object
    has been used as a key changes its hash, so treat instances as read-only.

    :param text: characters covered by this part.
    :param line: number of the line the part belongs to.
    :param start: start offset of the part.
    """

    def __init__(self, text: str = "", line: int = 0, start: int = 0):
        self.text = text
        self.line = line
        self.start = start

    def _key(self):
        """Return the attributes that define equality, in field order."""
        return self.text, self.line, self.start

    def __eq__(self, other):
        # Returning NotImplemented (instead of touching ``other.line``) keeps
        # comparisons with unrelated objects from raising AttributeError.
        if not isinstance(other, PartVersionSpecific):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Hash the tuple rather than XOR-ing the field hashes: with XOR the
        # contributions of ``line`` and ``start`` cancel whenever they are
        # equal, producing systematic collisions.
        return hash(self._key())

    def __repr__(self):
        return "{}(text={!r}, line={!r}, start={!r})".format(
            type(self).__name__, self.text, self.line, self.start
        )
class Distance:
    """Python 3.5 compatible record of edit operation counts.

    Represents the distance between two sequences as the number of matches,
    replacements, deletions and insertions. Instances compare and hash by
    value.

    Note: instances are not frozen. Changing a counter after the object has
    been used as a key changes its hash, so treat instances as read-only.

    :param match: number of matching positions.
    :param replace: number of replacements.
    :param delete: number of deletions.
    :param insert: number of insertions.
    """

    def __init__(
        self, match: int = 0, replace: int = 0, delete: int = 0, insert: int = 0
    ):
        self.match = match
        self.replace = replace
        self.delete = delete
        self.insert = insert

    def _asdict(self) -> Dict:
        """Return the counters as a dict keyed by field name.

        The keys equal the constructor parameter names, so
        ``Distance(**d._asdict())`` rebuilds an equal object.
        """
        return {
            "match": self.match,
            "replace": self.replace,
            "delete": self.delete,
            "insert": self.insert,
        }

    def _key(self):
        """Return the counters that define equality, in field order."""
        return self.match, self.replace, self.delete, self.insert

    def __eq__(self, other):
        # Returning NotImplemented keeps comparisons with unrelated objects
        # from raising AttributeError.
        if not isinstance(other, Distance):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # A tuple hash is order sensitive. XOR-ing the field hashes is not:
        # swapped counters collide and equal counters cancel each other out.
        return hash(self._key())

    def __repr__(self):
        return "{}(match={!r}, replace={!r}, delete={!r}, insert={!r})".format(
            type(self).__name__, self.match, self.replace, self.delete, self.insert
        )
# Calculated match between ground truth and the ocr result: both parts, their
# distance and the edit operations. The functional namedtuple form is used
# because this module has to stay importable on Python 3.5.
Match = namedtuple("Match", "gt ocr dist ops")
class Coefficients:
    """Python 3.5 compatible coefficients to calculate penalty for substrings.

    See Section 3 in doi:10.1016/j.patrec.2020.02.003

    Instances compare and hash by value and have a readable ``repr``, like the
    NamedTuple based variant used on newer Python versions.

    :param edit_dist: weight of the edit distance.
    :param length_diff: weight of the length difference.
    :param offset: weight of the offset.
    :param length: weight of the length.
    """

    def __init__(
        self,
        edit_dist: int = 25,
        length_diff: int = 20,
        offset: int = 1,
        length: int = 4,
    ):
        self.edit_dist = edit_dist
        self.length_diff = length_diff
        self.offset = offset
        self.length = length

    def _key(self):
        """Return the coefficients that define equality, in field order."""
        return self.edit_dist, self.length_diff, self.offset, self.length

    def __eq__(self, other):
        # Unrelated types are left to Python's default handling.
        if not isinstance(other, Coefficients):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Defined together with __eq__ so instances stay hashable.
        return hash(self._key())

    def __repr__(self):
        return (
            "{}(edit_dist={!r}, length_diff={!r}, offset={!r}, length={!r})"
        ).format(
            type(self).__name__,
            self.edit_dist,
            self.length_diff,
            self.offset,
            self.length,
        )
Loading…
Cancel
Save