@ -17,9 +17,10 @@ from functools import lru_cache, reduce
from itertools import product , takewhile
from itertools import product , takewhile
from typing import List , Tuple , Optional
from typing import List , Tuple , Optional
from Levenshtein import editops
from multimethod import multimethod
from multimethod import multimethod
from . import editops, ExtractedText
from . import ExtractedText
if sys . version_info . minor == 5 :
if sys . version_info . minor == 5 :
from . flexible_character_accuracy_ds_35 import (
from . flexible_character_accuracy_ds_35 import (
@ -170,10 +171,21 @@ def match_gt_line(
"""
"""
min_penalty = float ( " inf " )
min_penalty = float ( " inf " )
best_match , best_ocr = None , None
best_match , best_ocr = None , None
gt_line_length = gt_line . length
gt_line_start = gt_line . start
for ocr_line in ocr_lines :
for ocr_line in ocr_lines :
match = match_lines ( gt_line , ocr_line )
match = match_lines ( gt_line , ocr_line )
if match :
if match :
penalty = calculate_penalty ( gt_line , ocr_line , match , coef )
penalty = calculate_penalty (
gt_line_length ,
ocr_line . length ,
gt_line_start ,
ocr_line . start ,
match . gt . start ,
match . ocr . start ,
match . dist ,
coef ,
)
if penalty < min_penalty :
if penalty < min_penalty :
min_penalty , best_match , best_ocr = penalty , match , ocr_line
min_penalty , best_match , best_ocr = penalty , match , ocr_line
return best_match , best_ocr
return best_match , best_ocr
@ -234,7 +246,7 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
for i , gt_part in gt_parts :
for i , gt_part in gt_parts :
for j , ocr_part in ocr_parts :
for j , ocr_part in ocr_parts :
match = distance ( gt_part , ocr_part )
match = distance ( gt_part , ocr_part )
edit_dist = score_edit_distance ( match )
edit_dist = score_edit_distance ( match . dist )
if edit_dist < min_edit_dist and match . dist . replace < min_length :
if edit_dist < min_edit_dist and match . dist . replace < min_length :
min_edit_dist = edit_dist
min_edit_dist = edit_dist
best_match = match
best_match = match
@ -248,13 +260,13 @@ def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
gt_line . substring ( rel_start = best_i , rel_end = best_i + k ) ,
gt_line . substring ( rel_start = best_i , rel_end = best_i + k ) ,
ocr_line . substring ( rel_start = best_j , rel_end = best_j + k ) ,
ocr_line . substring ( rel_start = best_j , rel_end = best_j + k ) ,
)
)
edit_dist = score_edit_distance ( match )
edit_dist = score_edit_distance ( match . dist )
if edit_dist < min_edit_dist and match . dist . replace < min_length :
if edit_dist < min_edit_dist and match . dist . replace < min_length :
min_edit_dist = edit_dist
min_edit_dist = edit_dist
best_match = match
best_match = match
# is delete a better option?
# is delete a better option?
match = distance ( gt_line , Part ( text = " " , line = ocr_line . line , start = ocr_line . start ) )
match = distance ( gt_line , Part ( text = " " , line = ocr_line . line , start = ocr_line . start ) )
edit_dist = score_edit_distance ( match )
edit_dist = score_edit_distance ( match . dist )
if edit_dist < min_edit_dist :
if edit_dist < min_edit_dist :
best_match = match
best_match = match
@ -278,18 +290,26 @@ def distance(gt: "Part", ocr: "Part") -> Match:
return Match ( gt = gt , ocr = ocr , dist = Distance ( * * edits ) , ops = ops )
return Match ( gt = gt , ocr = ocr , dist = Distance ( * * edits ) , ops = ops )
def score_edit_distance(dist: "Distance") -> int:
    """Calculate the weighted edit distance for a match.

    Formula: $deletes + inserts + 2 * replacements$

    A replacement is counted twice, i.e. with the same weight as one
    delete plus one insert.

    :param dist: Edit operation counts; only the ``delete``, ``insert`` and
        ``replace`` attributes are read.
    :return: Sum of deletes and inserts plus twice the replacements.
    """
    return dist.delete + dist.insert + 2 * dist.replace
@lru_cache ( 1000000 )
def calculate_penalty (
def calculate_penalty (
gt : " Part " , ocr : " Part " , match : Match , coef : Coefficients
gt_length : int ,
ocr_length : int ,
gt_start : int ,
ocr_start : int ,
gt_match_start : int ,
ocr_match_start : int ,
dist : Distance ,
coef : Coefficients ,
) - > float :
) - > float :
""" Calculate the penalty for a given match.
""" Calculate the penalty for a given match.
@ -297,12 +317,12 @@ def calculate_penalty(
: return : Penalty for the given match .
: return : Penalty for the given match .
"""
"""
min_edit_dist = score_edit_distance ( match )
min_edit_dist = score_edit_distance ( dist )
length_diff = abs ( gt . length - ocr . length)
length_diff = abs ( gt _length - ocr_ length)
substring_length = min ( gt . length , ocr . length)
substring_length = min ( gt _length, ocr_ length)
offset = 0.0
offset = 0.0
if length_diff > 1 :
if length_diff > 1 :
substring_pos = max ( match. gt . start - gt . start , match . ocr . start - ocr . start)
substring_pos = max ( gt_match_start - gt_start , ocr_match_start - ocr_ start)
offset = length_diff / 2 - abs ( substring_pos - length_diff / 2 )
offset = length_diff / 2 - abs ( substring_pos - length_diff / 2 )
return (
return (
min_edit_dist * coef . edit_dist
min_edit_dist * coef . edit_dist
@ -428,4 +448,4 @@ class Part(PartVersionSpecific):
"""
"""
text = self . text [ rel_start : rel_end ]
text = self . text [ rel_start : rel_end ]
start = self . start + rel_start
start = self . start + rel_start
return Part ( * * { * * self . _asdict ( ) , " text " : text , " start " : start } )
return Part ( line = self . line , text = text , start = start )