mirror of
https://github.com/qurator-spk/dinglehopper.git
synced 2025-06-09 11:50:00 +02:00
Performance increases
Temporarily switch to the C implementation of python-levenshtein for editops calculation. Also added some variables, caching and type changes for performance gains.
This commit is contained in:
parent
0ef7810dd0
commit
b24d8d5664
4 changed files with 37 additions and 17 deletions
|
@ -17,9 +17,10 @@ from functools import lru_cache, reduce
|
|||
from itertools import product, takewhile
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from Levenshtein import editops
|
||||
from multimethod import multimethod
|
||||
|
||||
from . import editops, ExtractedText
|
||||
from . import ExtractedText
|
||||
|
||||
if sys.version_info.minor == 5:
|
||||
from .flexible_character_accuracy_ds_35 import (
|
||||
|
@ -170,10 +171,21 @@ def match_gt_line(
|
|||
"""
|
||||
min_penalty = float("inf")
|
||||
best_match, best_ocr = None, None
|
||||
gt_line_length = gt_line.length
|
||||
gt_line_start = gt_line.start
|
||||
for ocr_line in ocr_lines:
|
||||
match = match_lines(gt_line, ocr_line)
|
||||
if match:
|
||||
penalty = calculate_penalty(gt_line, ocr_line, match, coef)
|
||||
penalty = calculate_penalty(
|
||||
gt_line_length,
|
||||
ocr_line.length,
|
||||
gt_line_start,
|
||||
ocr_line.start,
|
||||
match.gt.start,
|
||||
match.ocr.start,
|
||||
match.dist,
|
||||
coef,
|
||||
)
|
||||
if penalty < min_penalty:
|
||||
min_penalty, best_match, best_ocr = penalty, match, ocr_line
|
||||
return best_match, best_ocr
|
||||
|
@ -234,7 +246,7 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
|
|||
for i, gt_part in gt_parts:
|
||||
for j, ocr_part in ocr_parts:
|
||||
match = distance(gt_part, ocr_part)
|
||||
edit_dist = score_edit_distance(match)
|
||||
edit_dist = score_edit_distance(match.dist)
|
||||
if edit_dist < min_edit_dist and match.dist.replace < min_length:
|
||||
min_edit_dist = edit_dist
|
||||
best_match = match
|
||||
|
@ -248,13 +260,13 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
|
|||
gt_line.substring(rel_start=best_i, rel_end=best_i + k),
|
||||
ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
|
||||
)
|
||||
edit_dist = score_edit_distance(match)
|
||||
edit_dist = score_edit_distance(match.dist)
|
||||
if edit_dist < min_edit_dist and match.dist.replace < min_length:
|
||||
min_edit_dist = edit_dist
|
||||
best_match = match
|
||||
# is delete a better option?
|
||||
match = distance(gt_line, Part(text="", line=ocr_line.line, start=ocr_line.start))
|
||||
edit_dist = score_edit_distance(match)
|
||||
edit_dist = score_edit_distance(match.dist)
|
||||
if edit_dist < min_edit_dist:
|
||||
best_match = match
|
||||
|
||||
|
@ -278,18 +290,26 @@ def distance(gt: "Part", ocr: "Part") -> Match:
|
|||
return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops)
|
||||
|
||||
|
||||
def score_edit_distance(match: Match) -> int:
|
||||
def score_edit_distance(dist: Distance) -> int:
|
||||
"""Calculate edit distance for a match.
|
||||
|
||||
Formula: $deletes + inserts + 2 * replacements$
|
||||
|
||||
:return: Sum of deletes, inserts and replacements.
|
||||
"""
|
||||
return match.dist.delete + match.dist.insert + 2 * match.dist.replace
|
||||
return dist.delete + dist.insert + 2 * dist.replace
|
||||
|
||||
|
||||
@lru_cache(1000000)
|
||||
def calculate_penalty(
|
||||
gt: "Part", ocr: "Part", match: Match, coef: Coefficients
|
||||
gt_length: int,
|
||||
ocr_length: int,
|
||||
gt_start: int,
|
||||
ocr_start: int,
|
||||
gt_match_start: int,
|
||||
ocr_match_start: int,
|
||||
dist: Distance,
|
||||
coef: Coefficients,
|
||||
) -> float:
|
||||
"""Calculate the penalty for a given match.
|
||||
|
||||
|
@ -297,12 +317,12 @@ def calculate_penalty(
|
|||
|
||||
:return: Penalty for the given match.
|
||||
"""
|
||||
min_edit_dist = score_edit_distance(match)
|
||||
length_diff = abs(gt.length - ocr.length)
|
||||
substring_length = min(gt.length, ocr.length)
|
||||
min_edit_dist = score_edit_distance(dist)
|
||||
length_diff = abs(gt_length - ocr_length)
|
||||
substring_length = min(gt_length, ocr_length)
|
||||
offset = 0.0
|
||||
if length_diff > 1:
|
||||
substring_pos = max(match.gt.start - gt.start, match.ocr.start - ocr.start)
|
||||
substring_pos = max(gt_match_start - gt_start, ocr_match_start - ocr_start)
|
||||
offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
|
||||
return (
|
||||
min_edit_dist * coef.edit_dist
|
||||
|
@ -428,4 +448,4 @@ class Part(PartVersionSpecific):
|
|||
"""
|
||||
text = self.text[rel_start:rel_end]
|
||||
start = self.start + rel_start
|
||||
return Part(**{**self._asdict(), "text": text, "start": start})
|
||||
return Part(line=self.line, text=text, start=start)
|
||||
|
|
|
@ -70,9 +70,9 @@ SIMPLE_EDITS = [
|
|||
(Part(text="a"), Part(text="a"), Distance(match=1)),
|
||||
(Part(text="aaa"), Part(text="aaa"), Distance(match=3)),
|
||||
(
|
||||
Part(text="abcd"),
|
||||
Part(text="beed"),
|
||||
Distance(match=2, replace=1, insert=1, delete=1),
|
||||
Part(text="abbbbcd"),
|
||||
Part(text="bbbbede"),
|
||||
Distance(match=5, replace=1, insert=1, delete=1),
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
@ -50,7 +50,6 @@ def test_reading_order_settings(file, expected_text):
|
|||
assert ocr == expected_text
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Need to check performance first.")
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize(
|
||||
"gt,ocr,expected",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue