mirror of
https://github.com/qurator-spk/dinglehopper.git
synced 2025-06-07 19:05:13 +02:00
Performance increases
Temporarily switch to the C implementation of python-levenshtein for the editops calculation. Also added some variables, caching and type changes for performance gains.
This commit is contained in:
parent
0ef7810dd0
commit
b24d8d5664
4 changed files with 37 additions and 17 deletions
|
@ -17,9 +17,10 @@ from functools import lru_cache, reduce
|
||||||
from itertools import product, takewhile
|
from itertools import product, takewhile
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
from Levenshtein import editops
|
||||||
from multimethod import multimethod
|
from multimethod import multimethod
|
||||||
|
|
||||||
from . import editops, ExtractedText
|
from . import ExtractedText
|
||||||
|
|
||||||
if sys.version_info.minor == 5:
|
if sys.version_info.minor == 5:
|
||||||
from .flexible_character_accuracy_ds_35 import (
|
from .flexible_character_accuracy_ds_35 import (
|
||||||
|
@ -170,10 +171,21 @@ def match_gt_line(
|
||||||
"""
|
"""
|
||||||
min_penalty = float("inf")
|
min_penalty = float("inf")
|
||||||
best_match, best_ocr = None, None
|
best_match, best_ocr = None, None
|
||||||
|
gt_line_length = gt_line.length
|
||||||
|
gt_line_start = gt_line.start
|
||||||
for ocr_line in ocr_lines:
|
for ocr_line in ocr_lines:
|
||||||
match = match_lines(gt_line, ocr_line)
|
match = match_lines(gt_line, ocr_line)
|
||||||
if match:
|
if match:
|
||||||
penalty = calculate_penalty(gt_line, ocr_line, match, coef)
|
penalty = calculate_penalty(
|
||||||
|
gt_line_length,
|
||||||
|
ocr_line.length,
|
||||||
|
gt_line_start,
|
||||||
|
ocr_line.start,
|
||||||
|
match.gt.start,
|
||||||
|
match.ocr.start,
|
||||||
|
match.dist,
|
||||||
|
coef,
|
||||||
|
)
|
||||||
if penalty < min_penalty:
|
if penalty < min_penalty:
|
||||||
min_penalty, best_match, best_ocr = penalty, match, ocr_line
|
min_penalty, best_match, best_ocr = penalty, match, ocr_line
|
||||||
return best_match, best_ocr
|
return best_match, best_ocr
|
||||||
|
@ -234,7 +246,7 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
|
||||||
for i, gt_part in gt_parts:
|
for i, gt_part in gt_parts:
|
||||||
for j, ocr_part in ocr_parts:
|
for j, ocr_part in ocr_parts:
|
||||||
match = distance(gt_part, ocr_part)
|
match = distance(gt_part, ocr_part)
|
||||||
edit_dist = score_edit_distance(match)
|
edit_dist = score_edit_distance(match.dist)
|
||||||
if edit_dist < min_edit_dist and match.dist.replace < min_length:
|
if edit_dist < min_edit_dist and match.dist.replace < min_length:
|
||||||
min_edit_dist = edit_dist
|
min_edit_dist = edit_dist
|
||||||
best_match = match
|
best_match = match
|
||||||
|
@ -248,13 +260,13 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
|
||||||
gt_line.substring(rel_start=best_i, rel_end=best_i + k),
|
gt_line.substring(rel_start=best_i, rel_end=best_i + k),
|
||||||
ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
|
ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
|
||||||
)
|
)
|
||||||
edit_dist = score_edit_distance(match)
|
edit_dist = score_edit_distance(match.dist)
|
||||||
if edit_dist < min_edit_dist and match.dist.replace < min_length:
|
if edit_dist < min_edit_dist and match.dist.replace < min_length:
|
||||||
min_edit_dist = edit_dist
|
min_edit_dist = edit_dist
|
||||||
best_match = match
|
best_match = match
|
||||||
# is delete a better option?
|
# is delete a better option?
|
||||||
match = distance(gt_line, Part(text="", line=ocr_line.line, start=ocr_line.start))
|
match = distance(gt_line, Part(text="", line=ocr_line.line, start=ocr_line.start))
|
||||||
edit_dist = score_edit_distance(match)
|
edit_dist = score_edit_distance(match.dist)
|
||||||
if edit_dist < min_edit_dist:
|
if edit_dist < min_edit_dist:
|
||||||
best_match = match
|
best_match = match
|
||||||
|
|
||||||
|
@ -278,18 +290,26 @@ def distance(gt: "Part", ocr: "Part") -> Match:
|
||||||
return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops)
|
return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops)
|
||||||
|
|
||||||
|
|
||||||
def score_edit_distance(match: Match) -> int:
|
def score_edit_distance(dist: Distance) -> int:
|
||||||
"""Calculate edit distance for a match.
|
"""Calculate edit distance for a match.
|
||||||
|
|
||||||
Formula: $deletes + inserts + 2 * replacements$
|
Formula: $deletes + inserts + 2 * replacements$
|
||||||
|
|
||||||
:return: Sum of deletes, inserts and replacements.
|
:return: Sum of deletes, inserts and replacements.
|
||||||
"""
|
"""
|
||||||
return match.dist.delete + match.dist.insert + 2 * match.dist.replace
|
return dist.delete + dist.insert + 2 * dist.replace
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(1000000)
|
||||||
def calculate_penalty(
|
def calculate_penalty(
|
||||||
gt: "Part", ocr: "Part", match: Match, coef: Coefficients
|
gt_length: int,
|
||||||
|
ocr_length: int,
|
||||||
|
gt_start: int,
|
||||||
|
ocr_start: int,
|
||||||
|
gt_match_start: int,
|
||||||
|
ocr_match_start: int,
|
||||||
|
dist: Distance,
|
||||||
|
coef: Coefficients,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Calculate the penalty for a given match.
|
"""Calculate the penalty for a given match.
|
||||||
|
|
||||||
|
@ -297,12 +317,12 @@ def calculate_penalty(
|
||||||
|
|
||||||
:return: Penalty for the given match.
|
:return: Penalty for the given match.
|
||||||
"""
|
"""
|
||||||
min_edit_dist = score_edit_distance(match)
|
min_edit_dist = score_edit_distance(dist)
|
||||||
length_diff = abs(gt.length - ocr.length)
|
length_diff = abs(gt_length - ocr_length)
|
||||||
substring_length = min(gt.length, ocr.length)
|
substring_length = min(gt_length, ocr_length)
|
||||||
offset = 0.0
|
offset = 0.0
|
||||||
if length_diff > 1:
|
if length_diff > 1:
|
||||||
substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start)
|
substring_pos = max(gt_match_start - gt_start, ocr_match_start - ocr_start)
|
||||||
offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
|
offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
|
||||||
return (
|
return (
|
||||||
min_edit_dist * coef.edit_dist
|
min_edit_dist * coef.edit_dist
|
||||||
|
@ -428,4 +448,4 @@ class Part(PartVersionSpecific):
|
||||||
"""
|
"""
|
||||||
text = self.text[rel_start:rel_end]
|
text = self.text[rel_start:rel_end]
|
||||||
start = self.start + rel_start
|
start = self.start + rel_start
|
||||||
return Part(**{**self._asdict(), "text": text, "start": start})
|
return Part(line=self.line, text=text, start=start)
|
||||||
|
|
|
@ -70,9 +70,9 @@ SIMPLE_EDITS = [
|
||||||
(Part(text="a"), Part(text="a"), Distance(match=1)),
|
(Part(text="a"), Part(text="a"), Distance(match=1)),
|
||||||
(Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
|
(Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
|
||||||
(
|
(
|
||||||
Part(text="abcd"),
|
Part(text="abbbbcd"),
|
||||||
Part(text="beed"),
|
Part(text="bbbbede"),
|
||||||
Distance(match=2, replace=1, insert=1, delete=1),
|
Distance(match=5, replace=1, insert=1, delete=1),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,6 @@ def test_reading_order_settings(file, expected_text):
|
||||||
assert ocr == expected_text
|
assert ocr == expected_text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Need to check performance first.")
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"gt,ocr,expected",
|
"gt,ocr,expected",
|
||||||
|
|
|
@ -9,3 +9,4 @@ ocrd >= 2.20.1
|
||||||
attrs
|
attrs
|
||||||
multimethod == 1.3 # latest version to officially support Python 3.5
|
multimethod == 1.3 # latest version to officially support Python 3.5
|
||||||
tqdm
|
tqdm
|
||||||
|
python-levenshtein
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue