diff --git a/src/dinglehopper/cli_line_dirs.py b/src/dinglehopper/cli_line_dirs.py index 01fd585..43e4f1a 100644 --- a/src/dinglehopper/cli_line_dirs.py +++ b/src/dinglehopper/cli_line_dirs.py @@ -1,5 +1,6 @@ import itertools import os +from typing import Iterator, Tuple import click from jinja2 import Environment, FileSystemLoader @@ -12,11 +13,36 @@ from .ocr_files import plain_extract from .word_error_rate import word_error_rate_n, words_normalized +def removesuffix(text, suffix): + if suffix and text.endswith(suffix): + return text[: -len(suffix)] + return text + +def is_hidden(filepath): + filename = os.path.basename(os.path.abspath(filepath)) + return filename.startswith(".") + +def find_all_files(dir_: str, pred=None, return_hidden=False) -> Iterator[str]: + """ + Find all files in dir_, returning filenames + + If pred is given, pred(filename) must be True for the filename. + + Does not return hidden files by default. + """ + for root, _, filenames in os.walk(dir_): + for fn in filenames: + if not return_hidden and is_hidden(fn): + continue + if pred and not pred(fn): + continue + yield os.path.join(root, fn) + + def all_equal(iterable): g = itertools.groupby(iterable) return next(g, True) and not next(g, False) - def common_prefix(its): return [p[0] for p in itertools.takewhile(all_equal, zip(*its))] @@ -24,16 +50,49 @@ def common_prefix(its): def common_suffix(its): return reversed(common_prefix(reversed(it) for it in its)) +def find_gt_and_ocr_files(gt_dir, gt_suffix, ocr_dir, ocr_suffix) -> Iterator[Tuple[str, str]]: + """ + Find GT files and matching OCR files. -def removesuffix(text, suffix): - if suffix and text.endswith(suffix): - return text[: -len(suffix)] - return text + Returns pairs of GT and OCR files. + """ + for gt_fn in find_all_files(gt_dir, lambda fn: fn.endswith(gt_suffix)): + ocr_fn = os.path.join( + ocr_dir, + os.path.relpath(gt_fn, start=gt_dir).removesuffix(gt_suffix) + + ocr_suffix, + ) + if not os.path.exists(ocr_fn): + raise RuntimeError(f"{ocr_fn} (matching {gt_fn}) does not exist") + + yield gt_fn, ocr_fn + + +def find_gt_and_ocr_files_autodetect(gt_dir, ocr_dir): + """ + Find GT files and matching OCR files, autodetect suffixes. + + This only works if gt_dir (or respectivley ocr_dir) only contains GT (OCR) + files with a common suffix. Currently the files must have a suffix, e.g. + ".gt.txt" (e.g. ".ocr.txt"). + + Returns pairs of GT and OCR files. + """ + + # Autodetect suffixes + gt_files = find_all_files(gt_dir) + gt_suffix = "".join(common_suffix(gt_files)) + if len(gt_suffix) == 0: + raise RuntimeError(f"Files in GT directory {gt_dir} do not have a common suffix") + ocr_files = find_all_files(ocr_dir) + ocr_suffix = "".join(common_suffix(ocr_files)) + if len(ocr_suffix) == 0: + raise RuntimeError(f"Files in OCR directory {ocr_dir} do not have a common suffix") + + yield from find_gt_and_ocr_files(gt_dir, gt_suffix, ocr_dir, ocr_suffix) def process(gt_dir, ocr_dir, report_prefix, *, metrics=True): - gt_suffix = "".join(common_suffix(os.listdir(gt_dir))) - ocr_suffix = "".join(common_suffix(os.listdir(ocr_dir))) cer = None n_characters = None @@ -42,14 +101,10 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True): n_words = None word_diff_report = "" - for k, gt in enumerate(os.listdir(gt_dir)): - # Find a match by replacing the suffix - ocr = removesuffix(gt, gt_suffix) + ocr_suffix + for k, (gt_fn, ocr_fn) in enumerate(find_gt_and_ocr_files_autodetect(gt_dir, ocr_dir)): - gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True) - ocr_text = plain_extract( - os.path.join(ocr_dir, ocr), include_filename_in_id=True - ) + gt_text = plain_extract(gt_fn, include_filename_in_id=True) + ocr_text = plain_extract(ocr_fn, include_filename_in_id=True) gt_words = words_normalized(gt_text) ocr_words = words_normalized(ocr_text) diff --git a/src/dinglehopper/line_dirs_test.py b/src/dinglehopper/line_dirs_test.py index 676fe22..9827f01 100644 --- a/src/dinglehopper/line_dirs_test.py +++ b/src/dinglehopper/line_dirs_test.py @@ -2,78 +2,7 @@ import os.path import itertools from typing import Iterator, Tuple -def is_hidden(filepath): - filename = os.path.basename(os.path.abspath(filepath)) - return filename.startswith(".") - -def find_all_files(dir_: str, pred=None, return_hidden=False) -> Iterator[str]: - """ - Find all files in dir_, returning filenames - - If pred is given, pred(filename) must be True for the filename. - - Does not return hidden files by default. - """ - for root, _, filenames in os.walk(dir_): - for fn in filenames: - if not return_hidden and is_hidden(fn): - continue - if pred and not pred(fn): - continue - yield os.path.join(root, fn) - - -def find_gt_and_ocr_files(gt_dir, gt_suffix, ocr_dir, ocr_suffix) -> Iterator[Tuple[str, str]]: - """ - Find GT files and matching OCR files. - - Returns pairs of GT and OCR files. - """ - for gt_fn in find_all_files(gt_dir, lambda fn: fn.endswith(gt_suffix)): - ocr_fn = os.path.join( - ocr_dir, - os.path.relpath(gt_fn, start=gt_dir).removesuffix(gt_suffix) - + ocr_suffix, - ) - if not os.path.exists(ocr_fn): - raise RuntimeError(f"{ocr_fn} (matching {gt_fn}) does not exist") - - yield gt_fn, ocr_fn - -def all_equal(iterable): - g = itertools.groupby(iterable) - return next(g, True) and not next(g, False) - -def common_prefix(its): - return [p[0] for p in itertools.takewhile(all_equal, zip(*its))] - - -def common_suffix(its): - return reversed(common_prefix(reversed(it) for it in its)) - - -def find_gt_and_ocr_files_autodetect(gt_dir, ocr_dir): - """ - Find GT files and matching OCR files, autodetect suffixes. - - This only works if gt_dir (or respectivley ocr_dir) only contains GT (OCR) - files with a common suffix. Currently the files must have a suffix, e.g. - ".gt.txt" (e.g. ".ocr.txt"). - - Returns pairs of GT and OCR files. - """ - - # Autodetect suffixes - gt_files = find_all_files(gt_dir) - gt_suffix = "".join(common_suffix(gt_files)) - if len(gt_suffix) == 0: - raise RuntimeError(f"Files in GT directory {gt_dir} do not have a common suffix") - ocr_files = find_all_files(ocr_dir) - ocr_suffix = "".join(common_suffix(ocr_files)) - if len(ocr_suffix) == 0: - raise RuntimeError(f"Files in OCR directory {ocr_dir} do not have a common suffix") - yield from find_gt_and_ocr_files(gt_dir, gt_suffix, ocr_dir, ocr_suffix) def test_basic():