dinglehopper: Include number of characters and words in JSON report

pull/29/head
Gerber, Mike 5 years ago
parent be251a391e
commit 779472575c

@ -1,21 +1,36 @@
from __future__ import division from __future__ import division
import unicodedata import unicodedata
from typing import Tuple
from uniseg.graphemecluster import grapheme_clusters from uniseg.graphemecluster import grapheme_clusters
from qurator.dinglehopper.edit_distance import distance from qurator.dinglehopper.edit_distance import distance
def character_error_rate(reference, compared): def character_error_rate_n(reference, compared) -> Tuple[float, int]:
d = distance(reference, compared) """
if d == 0: Compute character error rate.
return 0
:return: character error rate and length of the reference
"""
d = distance(reference, compared)
n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference)))) n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference))))
if n == 0:
return float('inf')
return d/n if d == 0:
return 0, n
if n == 0:
return float('inf'), n
return d/n, n
# XXX Should we really count newlines here? # XXX Should we really count newlines here?
def character_error_rate(reference, compared) -> float:
"""
Compute character error rate.
:return: character error rate
"""
cer, _ = character_error_rate_n(reference, compared)
return cer

@ -57,8 +57,8 @@ def process(gt, ocr, report_prefix):
gt_text = substitute_equivalences(gt_text) gt_text = substitute_equivalences(gt_text)
ocr_text = substitute_equivalences(ocr_text) ocr_text = substitute_equivalences(ocr_text)
cer = character_error_rate(gt_text, ocr_text) cer, n_characters = character_error_rate_n(gt_text, ocr_text)
wer = word_error_rate(gt_text, ocr_text) wer, n_words = word_error_rate_n(gt_text, ocr_text)
char_diff_report = gen_diff_report(gt_text, ocr_text, css_prefix='c', joiner='', none='·', align=align) char_diff_report = gen_diff_report(gt_text, ocr_text, css_prefix='c', joiner='', none='·', align=align)
@ -88,7 +88,8 @@ def process(gt, ocr, report_prefix):
template = env.get_template(template_fn) template = env.get_template(template_fn)
template.stream( template.stream(
gt=gt, ocr=ocr, gt=gt, ocr=ocr,
cer=cer, wer=wer, cer=cer, n_characters=n_characters,
wer=wer, n_words=n_words,
char_diff_report=char_diff_report, char_diff_report=char_diff_report,
word_diff_report=word_diff_report word_diff_report=word_diff_report
).dump(out_fn) ).dump(out_fn)

@ -2,5 +2,7 @@
"gt": "{{ gt }}", "gt": "{{ gt }}",
"ocr": "{{ ocr }}", "ocr": "{{ ocr }}",
"cer": {{ cer|json_float }}, "cer": {{ cer|json_float }},
"wer": {{ wer|json_float }} "wer": {{ wer|json_float }},
"n_characters": {{ n_characters }},
"n_words": {{ n_words }}
} }

@ -1,6 +1,7 @@
from __future__ import division from __future__ import division
import unicodedata import unicodedata
from typing import Tuple
import uniseg.wordbreak import uniseg.wordbreak
@ -44,7 +45,7 @@ def words_normalized(s):
return words(unicodedata.normalize('NFC', s)) return words(unicodedata.normalize('NFC', s))
def word_error_rate(reference, compared): def word_error_rate_n(reference, compared) -> Tuple[float, int]:
if isinstance(reference, str): if isinstance(reference, str):
reference_seq = list(words_normalized(reference)) reference_seq = list(words_normalized(reference))
compared_seq = list(words_normalized(compared)) compared_seq = list(words_normalized(compared))
@ -53,11 +54,15 @@ def word_error_rate(reference, compared):
compared_seq = list(compared) compared_seq = list(compared)
d = levenshtein(reference_seq, compared_seq) d = levenshtein(reference_seq, compared_seq)
if d == 0:
return 0
n = len(reference_seq) n = len(reference_seq)
if d == 0:
return 0, n
if n == 0: if n == 0:
return float('inf') return float('inf'), n
return d / n, n
return d / n def word_error_rate(reference, compared) -> float:
wer, _ = word_error_rate_n(reference, compared)
return wer

Loading…
Cancel
Save