You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
dinglehopper/dinglehopper/cli_line_dirs.py

147 lines
4.8 KiB
Python

import os
import sys
import itertools
import click
from jinja2 import Environment, FileSystemLoader
from markupsafe import escape
from uniseg.graphemecluster import grapheme_clusters
from ocrd_utils import initLogging
from .character_error_rate import character_error_rate_n
from .word_error_rate import word_error_rate_n, words_normalized
from .align import seq_align
from .extracted_text import ExtractedText
from .ocr_files import plain_extract
from .config import Config
from .cli import gen_diff_report, json_float
def all_equal(iterable):
    """Return True if every item of *iterable* is equal to the others.

    An empty iterable counts as "all equal". The iterable is consumed only as
    far as needed to find a second, different item.
    """
    runs = itertools.groupby(iterable)
    # groupby() yields one (key, group) tuple per run of equal items. Skip the
    # first run; the items are all equal exactly when no second run follows.
    next(runs, None)
    return next(runs, None) is None
def common_prefix(its):
    """Return the list of leading items shared by all iterables in *its*.

    The iterables are walked in lockstep; collection stops at the first
    position where they disagree or where the shortest one ends.
    """
    # Each "column" holds the items found at one position across all inputs.
    columns = zip(*its)
    return [column[0] for column in itertools.takewhile(all_equal, columns)]
def common_suffix(its):
    """Return an iterator over the trailing items shared by all of *its*.

    Each element of *its* must support ``reversed()``. The result is a reverse
    iterator (not a list) that yields the common suffix in its natural order.
    """
    # The common suffix is the common prefix of the reversed inputs, flipped
    # back into forward order.
    reversed_prefix = common_prefix(reversed(it) for it in its)
    return reversed(reversed_prefix)
def removesuffix(text, suffix):
    """Return *text* without the trailing *suffix*, if it ends with it.

    If *suffix* is empty or *text* does not end with it, *text* is returned
    unchanged.
    """
    # The explicit emptiness check matters: text[:-0] would be the empty string.
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text
def _fold_rate(rate, count, l_rate, l_count):
    """Fold one file's error rate into the running count-weighted mean.

    *rate* is ``None`` until the first file has been folded in. Returns the
    updated ``(rate, count)`` pair.
    """
    if rate is None:
        return l_rate, l_count
    total = count + l_count
    if total:
        rate = (rate * count + l_rate * l_count) / total
    # When total is 0 there is nothing to weight by; keep the running rate
    # rather than raising ZeroDivisionError.
    return rate, total


def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
    """Compare paired line text files of two directories and write reports.

    Every file in *gt_dir* is paired with the file in *ocr_dir* whose name is
    obtained by swapping the common filename suffix of *gt_dir* for the common
    filename suffix of *ocr_dir*. Character and word error rates are combined
    over all pairs as means weighted by the per-file counts, and the per-pair
    diff reports are concatenated.

    The result is rendered with the ``report.html.j2`` and ``report.json.j2``
    templates and written to ``report_prefix + ".html"`` and
    ``report_prefix + ".json"``. *metrics* is passed through to the templates.
    """
    # os.listdir() returns names in arbitrary order; sort once so that the
    # report order and the "l<k>-" CSS prefixes are reproducible.
    gt_files = sorted(os.listdir(gt_dir))
    gt_suffix = "".join(common_suffix(gt_files))
    ocr_suffix = "".join(common_suffix(os.listdir(ocr_dir)))

    cer = None
    n_characters = None
    wer = None
    n_words = None
    char_diff_parts = []
    word_diff_parts = []

    for k, gt in enumerate(gt_files):
        # Find a match by replacing the suffix
        ocr = removesuffix(gt, gt_suffix) + ocr_suffix

        gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True)
        ocr_text = plain_extract(
            os.path.join(ocr_dir, ocr), include_filename_in_id=True
        )

        # Compute CER (rolling update over all files seen so far)
        l_cer, l_n_characters = character_error_rate_n(gt_text, ocr_text)
        cer, n_characters = _fold_rate(cer, n_characters, l_cer, l_n_characters)

        # Compute WER (rolling update over all files seen so far)
        l_wer, l_n_words = word_error_rate_n(gt_text, ocr_text)
        wer, n_words = _fold_rate(wer, n_words, l_wer, l_n_words)

        # Generate diff reports
        char_diff_parts.append(
            gen_diff_report(
                gt_text, ocr_text, css_prefix="l{0}-c".format(k), joiner="", none="·"
            )
        )
        gt_words = words_normalized(gt_text)
        ocr_words = words_normalized(ocr_text)
        word_diff_parts.append(
            gen_diff_report(
                gt_words, ocr_words, css_prefix="l{0}-w".format(k), joiner=" ", none=""
            )
        )

    char_diff_report = "".join(char_diff_parts)
    word_diff_report = "".join(word_diff_parts)

    # Templates live in the "templates" directory next to this module.
    env = Environment(
        loader=FileSystemLoader(
            os.path.join(os.path.dirname(os.path.realpath(__file__)), "templates")
        )
    )
    env.filters["json_float"] = json_float

    for report_suffix in (".html", ".json"):
        template_fn = "report" + report_suffix + ".j2"
        out_fn = report_prefix + report_suffix

        template = env.get_template(template_fn)
        template.stream(
            gt=gt_dir,  # Note: directory
            ocr=ocr_dir,  # Note: directory
            cer=cer,
            n_characters=n_characters,
            wer=wer,
            n_words=n_words,
            char_diff_report=char_diff_report,
            word_diff_report=word_diff_report,
            metrics=metrics,
        ).dump(out_fn)
# Command line entry point. The docstring below is the text click shows for
# --help, so it is written for end users rather than for developers.
@click.command()
@click.argument("gt", type=click.Path(exists=True))
@click.argument("ocr", type=click.Path(exists=True))
@click.argument("report_prefix", type=click.Path(), default="report")
@click.option(
    "--metrics/--no-metrics", default=True, help="Enable/disable metrics and green/red"
)
def main(gt, ocr, report_prefix, metrics):
    """
    Compare the GT line text directory against the OCR line text directory.

    This assumes that the GT line text directory contains textfiles with a common
    suffix like ".gt.txt", and the OCR line text directory contains textfiles with
    a common suffix like ".some-ocr.txt". The text files also need to be paired,
    i.e. the GT file "line001.gt.txt" needs to match a file "line001.some-ocr.txt"
    in the OCR lines directory.

    The GT and OCR directories are usually ground truth line texts and the results of
    an OCR software, but you may use dinglehopper to compare two OCR results. In
    that case, use --no-metrics to disable the then meaningless metrics and also
    change the color scheme from green/red to blue.

    The comparison report will be written to $REPORT_PREFIX.{html,json}, where
    $REPORT_PREFIX defaults to "report". The reports include the character error
    rate (CER) and the word error rate (WER).
    """
    initLogging()
    process(gt, ocr, report_prefix, metrics=metrics)
# Run the command line interface when this module is executed as a script.
if __name__ == "__main__":
    main()