update rapidfuzz version

pull/72/head
Max Bachmann 2 years ago
parent a1f0a5e2d3
commit d2bbc8a6c7

@ -1,7 +1,6 @@
from .edit_distance import *
from rapidfuzz.distance import Levenshtein
def align(t1, t2):
"""Align text."""
s1 = list(grapheme_clusters(unicodedata.normalize("NFC", t1)))
@ -9,11 +8,11 @@ def align(t1, t2):
return seq_align(s1, s2)
def seq_align(s1, s2):
def seq_align(s1, s2, score_hint=None):
"""Align general sequences."""
s1 = list(s1)
s2 = list(s2)
ops = Levenshtein.editops(s1, s2)
ops = Levenshtein.editops(s1, s2, score_hint=score_hint)
i = 0
j = 0

@ -4,6 +4,7 @@ import click
from jinja2 import Environment, FileSystemLoader
from markupsafe import escape
from ocrd_utils import initLogging
from math import ceil
from .character_error_rate import character_error_rate_n
from .word_error_rate import word_error_rate_n, words_normalized
@ -13,7 +14,7 @@ from .ocr_files import extract
from .config import Config
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, score_hint=None):
gtx = ""
ocrx = ""
@ -52,7 +53,7 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
g_pos = 0
o_pos = 0
for k, (g, o) in enumerate(seq_align(gt_things, ocr_things)):
for k, (g, o) in enumerate(seq_align(gt_things, ocr_things, score_hint)):
css_classes = None
gt_id = None
ocr_id = None
@ -109,12 +110,12 @@ def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
cer, n_characters = character_error_rate_n(gt_text, ocr_text)
char_diff_report = gen_diff_report(
gt_text, ocr_text, css_prefix="c", joiner="", none="·"
gt_text, ocr_text, css_prefix="c", joiner="", none="·", score_hint=int(ceil(cer * n_characters))
)
wer, n_words = word_error_rate_n(gt_words, ocr_words)
word_diff_report = gen_diff_report(
gt_words, ocr_words, css_prefix="w", joiner=" ", none=""
gt_words, ocr_words, css_prefix="w", joiner=" ", none="", score_hint=int(ceil(wer * n_words))
)
env = Environment(
@ -175,24 +176,6 @@ def main(gt, ocr, report_prefix, metrics, textequiv_level, progress):
By default, the text of PAGE files is extracted on 'region' level. You may
use "--textequiv-level line" to extract from the level of TextLine tags.
"""
import cProfile
import pstats
import io
import atexit
#print("Profiling...")
#pr = cProfile.Profile()
#pr.enable()
def exit():
pr.disable()
print("Profiling completed")
s = io.StringIO()
pstats.Stats(pr, stream=s).sort_stats("cumtime").print_stats()
print(s.getvalue())
#atexit.register(exit)
initLogging()
Config.progress = progress
process(gt, ocr, report_prefix, metrics=metrics, textequiv_level=textequiv_level)

@ -4,6 +4,7 @@ import itertools
import click
from jinja2 import Environment, FileSystemLoader
from ocrd_utils import initLogging
from math import ceil
from .character_error_rate import character_error_rate_n
from .word_error_rate import word_error_rate_n, words_normalized
@ -74,10 +75,10 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
# Generate diff reports
char_diff_report += gen_diff_report(
gt_text, ocr_text, css_prefix="l{0}-c".format(k), joiner="", none="·"
gt_text, ocr_text, css_prefix="l{0}-c".format(k), joiner="", none="·", score_hint=int(ceil(l_cer * l_n_characters))
)
word_diff_report += gen_diff_report(
gt_words, ocr_words, css_prefix="l{0}-w".format(k), joiner=" ", none=""
gt_words, ocr_words, css_prefix="l{0}-w".format(k), joiner=" ", none="", score_hint=int(ceil(l_wer * l_n_words))
)
env = Environment(

@ -9,5 +9,5 @@ ocrd >= 2.20.1
attrs
multimethod == 1.3 # latest version to officially support Python 3.5
tqdm
rapidfuzz >= 2.4.2
rapidfuzz >= 2.7.0
six # XXX workaround OCR-D/core#730

Loading…
Cancel
Save