Performance increases

Temporarily switch to the c-implementation of python-levenshtein for
editops calculatation. Also added some variables, caching and type
changes for performance gains.
pull/47/head
Benjamin Rosemann 4 years ago
parent 0ef7810dd0
commit b24d8d5664

@ -17,9 +17,10 @@ from functools import lru_cache, reduce
from itertools import product, takewhile
from typing import List, Tuple, Optional
from Levenshtein import editops
from multimethod import multimethod
from . import editops, ExtractedText
from . import ExtractedText
if sys.version_info.minor == 5:
from .flexible_character_accuracy_ds_35 import (
@ -170,10 +171,21 @@ def match_gt_line(
"""
min_penalty = float("inf")
best_match, best_ocr = None, None
gt_line_length = gt_line.length
gt_line_start = gt_line.start
for ocr_line in ocr_lines:
match = match_lines(gt_line, ocr_line)
if match:
penalty = calculate_penalty(gt_line, ocr_line, match, coef)
penalty = calculate_penalty(
gt_line_length,
ocr_line.length,
gt_line_start,
ocr_line.start,
match.gt.start,
match.ocr.start,
match.dist,
coef,
)
if penalty < min_penalty:
min_penalty, best_match, best_ocr = penalty, match, ocr_line
return best_match, best_ocr
@ -234,7 +246,7 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
for i, gt_part in gt_parts:
for j, ocr_part in ocr_parts:
match = distance(gt_part, ocr_part)
edit_dist = score_edit_distance(match)
edit_dist = score_edit_distance(match.dist)
if edit_dist < min_edit_dist and match.dist.replace < min_length:
min_edit_dist = edit_dist
best_match = match
@ -248,13 +260,13 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
gt_line.substring(rel_start=best_i, rel_end=best_i + k),
ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
)
edit_dist = score_edit_distance(match)
edit_dist = score_edit_distance(match.dist)
if edit_dist < min_edit_dist and match.dist.replace < min_length:
min_edit_dist = edit_dist
best_match = match
# is delete a better option?
match = distance(gt_line, Part(text="", line=ocr_line.line, start=ocr_line.start))
edit_dist = score_edit_distance(match)
edit_dist = score_edit_distance(match.dist)
if edit_dist < min_edit_dist:
best_match = match
@ -278,18 +290,26 @@ def distance(gt: "Part", ocr: "Part") -> Match:
return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops)
def score_edit_distance(match: Match) -> int:
def score_edit_distance(dist: Distance) -> int:
"""Calculate edit distance for a match.
Formula: $deletes + inserts + 2 * replacements$
:return: Sum of deletes, inserts and replacements.
"""
return match.dist.delete + match.dist.insert + 2 * match.dist.replace
return dist.delete + dist.insert + 2 * dist.replace
@lru_cache(1000000)
def calculate_penalty(
gt: "Part", ocr: "Part", match: Match, coef: Coefficients
gt_length: int,
ocr_length: int,
gt_start: int,
ocr_start: int,
gt_match_start: int,
ocr_match_start: int,
dist: Distance,
coef: Coefficients,
) -> float:
"""Calculate the penalty for a given match.
@ -297,12 +317,12 @@ def calculate_penalty(
:return: Penalty for the given match.
"""
min_edit_dist = score_edit_distance(match)
length_diff = abs(gt.length - ocr.length)
substring_length = min(gt.length, ocr.length)
min_edit_dist = score_edit_distance(dist)
length_diff = abs(gt_length - ocr_length)
substring_length = min(gt_length, ocr_length)
offset = 0.0
if length_diff > 1:
substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start)
substring_pos = max(gt_match_start - gt_start, ocr_match_start - ocr_start)
offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
return (
min_edit_dist * coef.edit_dist
@ -428,4 +448,4 @@ class Part(PartVersionSpecific):
"""
text = self.text[rel_start:rel_end]
start = self.start + rel_start
return Part(**{**self._asdict(), "text": text, "start": start})
return Part(line=self.line, text=text, start=start)

@ -70,9 +70,9 @@ SIMPLE_EDITS = [
(Part(text="a"), Part(text="a"), Distance(match=1)),
(Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
(
Part(text="abcd"),
Part(text="beed"),
Distance(match=2, replace=1, insert=1, delete=1),
Part(text="abbbbcd"),
Part(text="bbbbede"),
Distance(match=5, replace=1, insert=1, delete=1),
),
]

@ -50,7 +50,6 @@ def test_reading_order_settings(file, expected_text):
assert ocr == expected_text
@pytest.mark.skip(reason="Need to check performance first.")
@pytest.mark.integration
@pytest.mark.parametrize(
"gt,ocr,expected",

@ -9,3 +9,4 @@ ocrd >= 2.20.1
attrs
multimethod == 1.3 # latest version to officially support Python 3.5
tqdm
python-levenshtein

Loading…
Cancel
Save