Performance increases

Temporarily switch to the C implementation of python-Levenshtein for
editops calculation. Also add some local variables, caching and type
changes for performance gains.
pull/47/head
Benjamin Rosemann 4 years ago
parent 0ef7810dd0
commit b24d8d5664

@ -17,9 +17,10 @@ from functools import lru_cache, reduce
from itertools import product, takewhile from itertools import product, takewhile
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
from Levenshtein import editops
from multimethod import multimethod from multimethod import multimethod
from . import editops, ExtractedText from . import ExtractedText
if sys.version_info.minor == 5: if sys.version_info.minor == 5:
from .flexible_character_accuracy_ds_35 import ( from .flexible_character_accuracy_ds_35 import (
@ -170,10 +171,21 @@ def match_gt_line(
""" """
min_penalty = float("inf") min_penalty = float("inf")
best_match, best_ocr = None, None best_match, best_ocr = None, None
gt_line_length = gt_line.length
gt_line_start = gt_line.start
for ocr_line in ocr_lines: for ocr_line in ocr_lines:
match = match_lines(gt_line, ocr_line) match = match_lines(gt_line, ocr_line)
if match: if match:
penalty = calculate_penalty(gt_line, ocr_line, match, coef) penalty = calculate_penalty(
gt_line_length,
ocr_line.length,
gt_line_start,
ocr_line.start,
match.gt.start,
match.ocr.start,
match.dist,
coef,
)
if penalty < min_penalty: if penalty < min_penalty:
min_penalty, best_match, best_ocr = penalty, match, ocr_line min_penalty, best_match, best_ocr = penalty, match, ocr_line
return best_match, best_ocr return best_match, best_ocr
@ -234,7 +246,7 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
for i, gt_part in gt_parts: for i, gt_part in gt_parts:
for j, ocr_part in ocr_parts: for j, ocr_part in ocr_parts:
match = distance(gt_part, ocr_part) match = distance(gt_part, ocr_part)
edit_dist = score_edit_distance(match) edit_dist = score_edit_distance(match.dist)
if edit_dist < min_edit_dist and match.dist.replace < min_length: if edit_dist < min_edit_dist and match.dist.replace < min_length:
min_edit_dist = edit_dist min_edit_dist = edit_dist
best_match = match best_match = match
@ -248,13 +260,13 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
gt_line.substring(rel_start=best_i, rel_end=best_i + k), gt_line.substring(rel_start=best_i, rel_end=best_i + k),
ocr_line.substring(rel_start=best_j, rel_end=best_j + k), ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
) )
edit_dist = score_edit_distance(match) edit_dist = score_edit_distance(match.dist)
if edit_dist < min_edit_dist and match.dist.replace < min_length: if edit_dist < min_edit_dist and match.dist.replace < min_length:
min_edit_dist = edit_dist min_edit_dist = edit_dist
best_match = match best_match = match
# is delete a better option? # is delete a better option?
match = distance(gt_line, Part(text="", line=ocr_line.line, start=ocr_line.start)) match = distance(gt_line, Part(text="", line=ocr_line.line, start=ocr_line.start))
edit_dist = score_edit_distance(match) edit_dist = score_edit_distance(match.dist)
if edit_dist < min_edit_dist: if edit_dist < min_edit_dist:
best_match = match best_match = match
@ -278,18 +290,26 @@ def distance(gt: "Part", ocr: "Part") -> Match:
return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops) return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops)
def score_edit_distance(match: Match) -> int: def score_edit_distance(dist: Distance) -> int:
"""Calculate edit distance for a match. """Calculate edit distance for a match.
Formula: $deletes + inserts + 2 * replacements$ Formula: $deletes + inserts + 2 * replacements$
:return: Sum of deletes, inserts and replacements. :return: Sum of deletes, inserts and replacements.
""" """
return match.dist.delete + match.dist.insert + 2 * match.dist.replace return dist.delete + dist.insert + 2 * dist.replace
@lru_cache(1000000)
def calculate_penalty( def calculate_penalty(
gt: "Part", ocr: "Part", match: Match, coef: Coefficients gt_length: int,
ocr_length: int,
gt_start: int,
ocr_start: int,
gt_match_start: int,
ocr_match_start: int,
dist: Distance,
coef: Coefficients,
) -> float: ) -> float:
"""Calculate the penalty for a given match. """Calculate the penalty for a given match.
@ -297,12 +317,12 @@ def calculate_penalty(
:return: Penalty for the given match. :return: Penalty for the given match.
""" """
min_edit_dist = score_edit_distance(match) min_edit_dist = score_edit_distance(dist)
length_diff = abs(gt.length - ocr.length) length_diff = abs(gt_length - ocr_length)
substring_length = min(gt.length, ocr.length) substring_length = min(gt_length, ocr_length)
offset = 0.0 offset = 0.0
if length_diff > 1: if length_diff > 1:
substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start) substring_pos = max(gt_match_start - gt_start, ocr_match_start - ocr_start)
offset = length_diff / 2 - abs(substring_pos - length_diff / 2) offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
return ( return (
min_edit_dist * coef.edit_dist min_edit_dist * coef.edit_dist
@ -428,4 +448,4 @@ class Part(PartVersionSpecific):
""" """
text = self.text[rel_start:rel_end] text = self.text[rel_start:rel_end]
start = self.start + rel_start start = self.start + rel_start
return Part(**{**self._asdict(), "text": text, "start": start}) return Part(line=self.line, text=text, start=start)

@ -70,9 +70,9 @@ SIMPLE_EDITS = [
(Part(text="a"), Part(text="a"), Distance(match=1)), (Part(text="a"), Part(text="a"), Distance(match=1)),
(Part(text="aaa"), Part(text="aaa"), Distance(match=3)), (Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
( (
Part(text="abcd"), Part(text="abbbbcd"),
Part(text="beed"), Part(text="bbbbede"),
Distance(match=2, replace=1, insert=1, delete=1), Distance(match=5, replace=1, insert=1, delete=1),
), ),
] ]

@ -50,7 +50,6 @@ def test_reading_order_settings(file, expected_text):
assert ocr == expected_text assert ocr == expected_text
@pytest.mark.skip(reason="Need to check performance first.")
@pytest.mark.integration @pytest.mark.integration
@pytest.mark.parametrize( @pytest.mark.parametrize(
"gt,ocr,expected", "gt,ocr,expected",

@ -9,3 +9,4 @@ ocrd >= 2.20.1
attrs attrs
multimethod == 1.3 # latest version to officially support Python 3.5 multimethod == 1.3 # latest version to officially support Python 3.5
tqdm tqdm
python-levenshtein

Loading…
Cancel
Save