diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..7aecdd0 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,20 @@ +version: 2.1 + +jobs: + black: + parameters: + python-version: + type: string + docker: + - image: cimg/python:<< parameters.python-version >> + steps: + - checkout + - run: pip3 install --upgrade pip + - run: pip3 install black + - run: black . + +workflows: + black: + jobs: + - black: + python-version: "3.11" diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index a8312db..0000000 --- a/.dockerignore +++ /dev/null @@ -1,5 +0,0 @@ -src/dinglehopper/tests -dist -build -*.egg-info -.git diff --git a/.editorconfig b/.editorconfig index 6959d70..ea42d71 100644 --- a/.editorconfig +++ b/.editorconfig @@ -15,7 +15,7 @@ indent_size = 2 [*.json] indent_size = 2 -insert_final_newline = true +insert_final_newline = false # trailing spaces in markdown indicate word wrap [*.md] diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3f51bd7..8c193df 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v3 - name: Upgrade pip run: python3 -m pip install --upgrade pip - name: Install setuptools @@ -32,7 +32,7 @@ jobs: - name: Build package run: python3 -m pip install --upgrade build && python3 -m build - name: Upload dist - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v3 with: name: dist path: dist/ @@ -42,7 +42,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Download dist - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v3 with: name: dist path: dist/ @@ -61,7 +61,7 @@ jobs: id-token: write # IMPORTANT: this permission is mandatory for trusted publishing steps: - name: Download dist - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v3 with: name: dist path: dist/ diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index db089d0..7d55459 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,4 @@ -name: 'Test' +name: test on: @@ -6,10 +6,6 @@ on: branches: - master - pull_request: - branches: - - master - schedule: - cron: "00 16 07 * *" # = monthly @@ -25,27 +21,30 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12", "3.13" ] + python-version: [ "3.6", "3.7", "3.8", "3.9", "3.10", "3.11" ] - runs-on: "ubuntu-latest" + # For Python 3.6, we need to fall back to Ubuntu 20.04 + runs-on: ${{ matrix.python-version == '3.6' && 'ubuntu-20.04' || 'ubuntu-latest' }} + + env: + test_results_dir: test-results-${{ matrix.python-version }} steps: - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - allow-prereleases: true - name: Checkout - uses: actions/checkout@v4 - - - name: Install possible lxml build requirements (if building from source) - run: sudo apt-get install -y libxml2-dev libxslt-dev python3-dev - - name: Install possible shapely build requirements (if building from source) - run: sudo apt-get install -y libgeos-dev + uses: actions/checkout@v3 - name: Update pip run: python3 -m pip install -U pip + - name: Avoid compiling OpenCV and NumPy on Python 3.6 + run: | + if python3 --version | grep -q "Python 3.6"; then + pip install --prefer-binary -U opencv-python-headless numpy + fi - name: Install requirements*.txt run: | for requirements_txt in requirements*.txt; do @@ -55,10 +54,19 @@ jobs: - name: Test run: | cd src - python3 -m pytest --junitxml=../${{matrix.python-version}}-junit.xml -o junit_family=legacy + mkdir -p ../$test_results_dir + python3 -m pytest --junitxml=../$test_results_dir/junit.xml -o junit_family=legacy - name: Upload test results - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v3 if: success() || failure() with: - name: test-results-${{matrix.python-version}} - path: ${{matrix.python-version}}-junit.xml + name: ${{ env.test_results_dir }} + path: ${{ env.test_results_dir }} + + - name: Report tests + uses: dorny/test-reporter@v1 + if: success() || failure() + with: + name: Results on Python ${{ matrix.python-version }} + path: "${{env.test_results_dir }}/junit.xml" + reporter: java-junit diff --git a/.github/workflows/test_report.yml b/.github/workflows/test_report.yml deleted file mode 100644 index 5579d8c..0000000 --- a/.github/workflows/test_report.yml +++ /dev/null @@ -1,20 +0,0 @@ -name: 'Test - Report results' -on: - workflow_run: - workflows: ['test'] - types: - - completed -permissions: - contents: read - actions: read - checks: write -jobs: - report: - runs-on: ubuntu-latest - steps: - - uses: dorny/test-reporter@v1 - with: - artifact: /test-results-(.*)/ - name: 'test - Results ($1)' - path: '*junit.xml' - reporter: java-junit diff --git a/.gitignore b/.gitignore index 66d66bc..2291cd6 100644 --- a/.gitignore +++ b/.gitignore @@ -25,8 +25,6 @@ dmypy.json # User-specific stuff .idea -.*.swp # Build artifacts /build -/dist diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index bdcb93a..0000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,16 +0,0 @@ -variables: - http_proxy: "http://http-proxy.sbb.spk-berlin.de:3128/" - https_proxy: "http://http-proxy.sbb.spk-berlin.de:3128/" - HTTP_PROXY: "http://http-proxy.sbb.spk-berlin.de:3128/" - HTTPS_PROXY: "http://http-proxy.sbb.spk-berlin.de:3128/" - -stages: - - triggers - -mirror: - stage: triggers - trigger: - include: .gitlab/mirror.yml - strategy: depend - rules: - - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH diff --git a/.gitlab/mirror.yml b/.gitlab/mirror.yml deleted file mode 100644 index f3591a2..0000000 --- a/.gitlab/mirror.yml +++ /dev/null @@ -1,47 +0,0 @@ -stages: - - check - - pull - - push - -default: - image: debian - - -check: - stage: check - - script: - - whoami; env - - if [ -z "$CI_COMMIT_BRANCH" ]; then echo "Not on a branch" >&2; exit 3; fi - - -pull-gitlab: - stage: pull - script: - - echo "This is redundant" - -pull-github: - stage: pull - before_script: - - apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* - script: - - git remote remove github 2>/dev/null || true - - git remote add github https://github.com/qurator-spk/dinglehopper.git - - git remote -v - - - git pull github "$CI_COMMIT_BRANCH" - - -push-gitlab: - stage: push - before_script: - - apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* - script: - - git push origin "$CI_COMMIT_SHA":"$CI_COMMIT_BRANCH" - -push-github: - stage: push - before_script: - - apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* - script: - - git push github "$CI_COMMIT_SHA":"$CI_COMMIT_BRANCH" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 345060d..dd7b710 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,8 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v3.2.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -11,37 +13,17 @@ repos: - id: check-ast - repo: https://github.com/psf/black - rev: 25.1.0 + rev: 22.10.0 hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.7 + rev: v0.0.280 hooks: - - args: - - --fix - - --exit-non-zero-on-fix - id: ruff + - id: ruff + args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 + rev: v1.4.1 hooks: - - additional_dependencies: - - types-setuptools - - types-lxml - - numpy # for numpy plugin - - attrs - - multimethod - - rapidfuzz - id: mypy - -- repo: https://gitlab.com/vojko.pribudic.foss/pre-commit-update - rev: v0.6.1 - hooks: - - id: pre-commit-update - -- repo: https://github.com/dhatim/python-license-check - rev: 0.9.2 - hooks: - - id: liccheck - language: system + - id: mypy diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 7064efc..0000000 --- a/Dockerfile +++ /dev/null @@ -1,40 +0,0 @@ -ARG DOCKER_BASE_IMAGE -FROM $DOCKER_BASE_IMAGE -ARG VCS_REF -ARG BUILD_DATE -LABEL \ - maintainer="https://github.com/qurator-spk/dinglehopper/issues" \ - org.label-schema.vcs-ref=$VCS_REF \ - org.label-schema.vcs-url="https://github.com/qurator-spk/dinglehopper" \ - org.label-schema.build-date=$BUILD_DATE \ - org.opencontainers.image.vendor="Staatsbibliothek zu Berlin — SPK" \ - org.opencontainers.image.title="dinglehopper" \ - org.opencontainers.image.description="An OCR evaluation tool" \ - org.opencontainers.image.source="https://github.com/qurator-spk/dinglehopper" \ - org.opencontainers.image.documentation="https://github.com/qurator-spk/dinglehopper/blob/${VCS_REF}/README.md" \ - org.opencontainers.image.revision=$VCS_REF \ - org.opencontainers.image.created=$BUILD_DATE \ - org.opencontainers.image.base.name=ocrd/core - -ENV LANG=C.UTF-8 -ENV LC_ALL=C.UTF-8 - -# avoid HOME/.local/share (hard to predict USER here) -# so let XDG_DATA_HOME coincide with fixed system location -# (can still be overridden by derived stages) -ENV XDG_DATA_HOME /usr/local/share -# avoid the need for an extra volume for persistent resource user db -# (i.e. XDG_CONFIG_HOME/ocrd/resources.yml) -ENV XDG_CONFIG_HOME /usr/local/share/ocrd-resources - -WORKDIR /build/dinglehopper -COPY . . -COPY ocrd-tool.json . -# prepackage ocrd-tool.json as ocrd-all-tool.json -RUN ocrd ocrd-tool ocrd-tool.json dump-tools > $(dirname $(ocrd bashlib filename))/ocrd-all-tool.json -# prepackage ocrd-all-module-dir.json -RUN ocrd ocrd-tool ocrd-tool.json dump-module-dirs > $(dirname $(ocrd bashlib filename))/ocrd-all-module-dir.json -RUN make install && rm -rf /build/dinglehopper - -WORKDIR /data -VOLUME /data diff --git a/LICENSE b/LICENSE index 221c706..9b7a833 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2019-2025 Staatsbibliothek zu Berlin — SPK + Copyright 2019 qurator Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/Makefile b/Makefile deleted file mode 100644 index 3729311..0000000 --- a/Makefile +++ /dev/null @@ -1,34 +0,0 @@ -PYTHON = python3 -PIP = pip3 -PYTHONIOENCODING=utf8 -PYTEST_ARGS = -vv - -DOCKER_BASE_IMAGE ?= docker.io/ocrd/core:latest -DOCKER_TAG ?= ocrd/dinglehopper -DOCKER ?= docker - -help: - @echo - @echo " Targets" - @echo - @echo " install Install full Python package via pip" - @echo " docker Build the ocrd/dinglehopper docker image" - -# Install Python package via pip -install: - $(PIP) install . - -install-dev: - $(PIP) install -e . - -test: - pytest $(PYTEST_ARGS) - -docker: - $(DOCKER) build \ - --build-arg DOCKER_BASE_IMAGE=$(DOCKER_BASE_IMAGE) \ - --build-arg VCS_REF=$$(git rev-parse --short HEAD) \ - --build-arg BUILD_DATE=$$(date -u +"%Y-%m-%dT%H:%M:%SZ") \ - -t $(DOCKER_TAG) . - -.PHONY: help install install-dev test docker diff --git a/README-DEV.md b/README-DEV.md index 3ec432f..cdd51fd 100644 --- a/README-DEV.md +++ b/README-DEV.md @@ -10,7 +10,6 @@ pytest ``` ## Test running examples - Only unit tests: ```bash pytest -m "not integration" @@ -37,21 +36,9 @@ pytest -k "not test" --mypy pytest -k "not test" --ruff ``` -# How to use pre-commit +## How to use pre-commit This project optionally uses [pre-commit](https://pre-commit.com) to check commits. To use it: - Install pre-commit, e.g. `pip install -r requirements-dev.txt` - Install the repo-local git hooks: `pre-commit install` - - -# Releasing a new version - -- Update `ocrd-tool.json` -- `git commit` -- `git tag vx.y.z` -- `git push && git push --tags` -- The GitHub Actions workflow `release` will now create - a. a new release on GitHub and - b. a new release on PyPI -- Currently requires a review for PYPI? diff --git a/README.md b/README.md index a40db79..affcfe8 100644 --- a/README.md +++ b/README.md @@ -5,10 +5,10 @@ dinglehopper is an OCR evaluation tool and reads [ALTO](https://github.com/altoxml), [PAGE](https://github.com/PRImA-Research-Lab/PAGE-XML) and text files. It compares a ground truth (GT) document page with a OCR result page to compute -metrics and a word/character differences report. It also supports batch processing by +metrics and a word/character differences report. It also supports batch processing by generating, aggregating and summarizing multiple reports. -[![Tests](https://github.com/qurator-spk/dinglehopper/actions/workflows/test.yml/badge.svg)](https://github.com/qurator-spk/dinglehopper/actions?query=workflow:"test") +[![Tests](https://github.com/qurator-spk/dinglehopper/workflows/test/badge.svg)](https://github.com/qurator-spk/dinglehopper/actions?query=workflow:"test") [![GitHub tag](https://img.shields.io/github/tag/qurator-spk/dinglehopper?include_prereleases=&sort=semver&color=blue)](https://github.com/qurator-spk/dinglehopper/releases/) [![License](https://img.shields.io/badge/License-Apache-blue)](#license) [![issues - dinglehopper](https://img.shields.io/github/issues/qurator-spk/dinglehopper)](https://github.com/qurator-spk/dinglehopper/issues) @@ -23,11 +23,10 @@ Goals Installation ------------ - -It's best to use pip to install the package from PyPI, e.g.: -``` -pip install dinglehopper -``` +It's best to use pip, e.g.: +~~~ +sudo pip install . +~~~ Usage ----- @@ -70,19 +69,19 @@ This generates `report.html` and `report.json`. ![dinglehopper displaying metrics and character differences](.screenshots/dinglehopper.png?raw=true) -Batch comparison between folders of GT and OCR files can be done by simply providing +Batch comparison between folders of GT and OCR files can be done by simply providing folders: ~~~ dinglehopper gt/ ocr/ report output_folder/ ~~~ -This assumes that you have files with the same name in both folders, e.g. +This assumes that you have files with the same name in both folders, e.g. `gt/00000001.page.xml` and `ocr/00000001.alto.xml`. -The example generates reports for each set of files, with the prefix `report`, in the +The example generates reports for each set of files, with the prefix `report`, in the (automatically created) folder `output_folder/`. -By default, the JSON report does not contain the character and word differences, only -the calculated metrics. If you want to include the differences, use the +By default, the JSON report does not contain the character and word differences, only +the calculated metrics. If you want to include the differences, use the `--differences` flag: ~~~ @@ -90,7 +89,7 @@ dinglehopper gt/ ocr/ report output_folder/ --differences ~~~ ### dinglehopper-summarize -A set of (JSON) reports can be summarized into a single set of +A set of (JSON) reports can be summarized into a single set of reports. This is useful after having generated reports in batch. Example: ~~~ @@ -100,11 +99,11 @@ This generates `summary.html` and `summary.json` in the same `output_folder`. If you are summarizing many reports and have used the `--differences` flag while generating them, it may be useful to limit the number of differences reported by using -the `--occurrences-threshold` parameter. This will reduce the size of the generated HTML +the `--occurences-threshold` parameter. This will reduce the size of the generated HTML report, making it easier to open and navigate. Note that the JSON report will still contain all differences. Example: ~~~ -dinglehopper-summarize output_folder/ --occurrences-threshold 10 +dinglehopper-summarize output_folder/ --occurences-threshold 10 ~~~ ### dinglehopper-line-dirs @@ -112,13 +111,9 @@ You also may want to compare a directory of GT text files (i.e. `gt/line0001.gt. with a directory of OCR text files (i.e. `ocr/line0001.some-ocr.txt`) with a separate CLI interface: -``` +~~~ dinglehopper-line-dirs gt/ ocr/ -``` - -The CLI `dinglehopper-line-dirs` can also work with GT text files in the same -directories as the the OCR text files. You should read `dinglehopper-line-dirs --help` -in this case. +~~~ ### dinglehopper-extract The tool `dinglehopper-extract` extracts the text of the given input file on diff --git a/pyproject.toml b/pyproject.toml index 62fae82..2e98ae1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,10 +7,9 @@ authors = [ {name = "Mike Gerber", email = "mike.gerber@sbb.spk-berlin.de"}, {name = "The QURATOR SPK Team", email = "qurator@sbb.spk-berlin.de"}, ] -description = "An OCR evaluation tool" +description = "The OCR evaluation tool" readme = "README.md" -license.file = "LICENSE" -requires-python = ">=3.8" +requires-python = ">=3.6" keywords = ["qurator", "ocr", "evaluation", "ocr-d"] dynamic = ["version", "dependencies", "optional-dependencies"] @@ -49,7 +48,7 @@ optional-dependencies.dev = {file = ["requirements-dev.txt"]} where = ["src"] [tool.setuptools.package-data] -dinglehopper = ["templates/*", "*.json"] +dinglehopper = ["*.json", "templates/*"] [tool.pytest.ini_options] @@ -61,54 +60,11 @@ markers = [ [tool.mypy] -plugins = ["numpy.typing.mypy_plugin"] - ignore_missing_imports = true -strict = true - -disallow_subclassing_any = false -# ❗ error: Class cannot subclass "Processor" (has type "Any") -disallow_any_generics = false -disallow_untyped_defs = false -disallow_untyped_calls = false - - -[tool.ruff.lint] +[tool.ruff] select = ["E", "F", "I"] - - -[tool.liccheck] -authorized_licenses = [ - "bsd", - "new bsd", - "bsd license", - "new bsd license", - "simplified bsd", - "apache", - "apache 2.0", - "apache software license", - "apache software", - "apache license 2.0", - "gnu lgpl", - "lgpl with exceptions or zpl", - "GNU Library or Lesser General Public License (LGPL)", - "GNU Lesser General Public License v3 (LGPLv3)", - "GNU Lesser General Public License v2 or later (LGPLv2+)", - "mit", - "mit license", - "mit-cmu", - "python software foundation", - "psf", - "psf-2.0", - "Historical Permission Notice and Disclaimer (HPND)", - "public domain", - 'The Unlicense (Unlicense)', - "isc", - "ISC License (ISCL)", - 'Mozilla Public License 2.0 (MPL 2.0)', -] -unauthorized_licenses = [ - "gpl v3", +ignore = [ + "F811", # multimethods are considered redefinitions by ruff ] diff --git a/requirements-dev.txt b/requirements-dev.txt index f9f748a..4bf395e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,14 +1,8 @@ pytest pytest-cov +pytest-mypy black pre-commit -ruff -pytest-ruff - -mypy -types-lxml -types-setuptools -pytest-mypy - -liccheck +ruff ; python_version >= "3.7" +pytest-ruff ; python_version >= "3.7" diff --git a/requirements.txt b/requirements.txt index 653ec59..8ee3d1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,14 @@ click jinja2 lxml -uniseg >= 0.8.0 +uniseg numpy colorama MarkupSafe -ocrd >= 3.3.0 +ocrd >= 2.20.1 attrs -multimethod >= 1.3 +multimethod == 1.3 # latest version to officially support Python 3.5 tqdm -rapidfuzz >= 2.7.0 +rapidfuzz >= 2.4.2 +six # XXX workaround OCR-D/core#730 chardet -importlib_resources diff --git a/src/dinglehopper/tests/data/line_dirs/merged/a/a.dummy.jpg b/setup.cfg similarity index 100% rename from src/dinglehopper/tests/data/line_dirs/merged/a/a.dummy.jpg rename to setup.cfg diff --git a/src/dinglehopper/__init__.py b/src/dinglehopper/__init__.py index 2e79b69..0f6ab60 100644 --- a/src/dinglehopper/__init__.py +++ b/src/dinglehopper/__init__.py @@ -1,4 +1,4 @@ -from .align import align, score_hint, seq_align +from .align import align, seq_align from .character_error_rate import character_error_rate, character_error_rate_n from .edit_distance import distance, editops from .extracted_text import ExtractedText @@ -16,7 +16,6 @@ __all__ = [ "editops", "distance", "align", - "score_hint", "seq_align", "character_error_rate", "character_error_rate_n", diff --git a/src/dinglehopper/align.py b/src/dinglehopper/align.py index 5d1f290..988ec9a 100644 --- a/src/dinglehopper/align.py +++ b/src/dinglehopper/align.py @@ -1,10 +1,8 @@ -import math import unicodedata -from math import ceil -from typing import Optional from rapidfuzz.distance import Levenshtein -from uniseg.graphemecluster import grapheme_clusters + +from .edit_distance import grapheme_clusters def align(t1, t2): @@ -14,27 +12,11 @@ def align(t1, t2): return seq_align(s1, s2) -def score_hint(er: float, n: int) -> Optional[int]: - """Calculate RapidFuzz score hint for a given error rate and count. - - Gives the score hint for the distance functions (= expected distance) or None if - the error rate is inf. - """ - assert not math.isnan(er) - try: - score_hint = int(ceil(er * n)) - except (OverflowError, ValueError): - # ceil(er * n) can be inf or NaN (for n == 0), so int() can throw an - # OverflowError and a ValueError. - score_hint = None - return score_hint - - -def seq_align(s1, s2, score_hint=None): +def seq_align(s1, s2): """Align general sequences.""" s1 = list(s1) s2 = list(s2) - ops = Levenshtein.editops(s1, s2, score_hint=score_hint) + ops = Levenshtein.editops(s1, s2) i = 0 j = 0 diff --git a/src/dinglehopper/character_error_rate.py b/src/dinglehopper/character_error_rate.py index 88a88f8..0c3ef7d 100644 --- a/src/dinglehopper/character_error_rate.py +++ b/src/dinglehopper/character_error_rate.py @@ -1,5 +1,7 @@ +from __future__ import division + import unicodedata -from typing import List, Tuple, TypeVar +from typing import Tuple from multimethod import multimethod from uniseg.graphemecluster import grapheme_clusters @@ -7,13 +9,9 @@ from uniseg.graphemecluster import grapheme_clusters from .edit_distance import distance from .extracted_text import ExtractedText -T = TypeVar("T") - @multimethod -def character_error_rate_n( - reference: List[str], compared: List[str] -) -> Tuple[float, int]: +def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: """ Compute character error rate. @@ -21,7 +19,7 @@ def character_error_rate_n( """ d = distance(reference, compared) - n = len(reference) + n = len(list(grapheme_clusters(unicodedata.normalize("NFC", reference)))) if d == 0: return 0, n @@ -32,28 +30,18 @@ def character_error_rate_n( # XXX Should we really count newlines here? -@character_error_rate_n.register -def _(reference: str, compared: str) -> Tuple[float, int]: - seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference))) - seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared))) - cer, n = character_error_rate_n(seq1, seq2) - return cer, n +@multimethod +def character_error_rate_n( + reference: ExtractedText, compared: ExtractedText +) -> Tuple[float, int]: + return character_error_rate_n(reference.text, compared.text) -@character_error_rate_n.register -def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]: - cer, n = character_error_rate_n( - reference.grapheme_clusters, compared.grapheme_clusters - ) - return cer, n - - -def character_error_rate(reference: T, compared: T) -> float: +def character_error_rate(reference, compared) -> float: """ Compute character error rate. :return: character error rate """ - cer: float cer, _ = character_error_rate_n(reference, compared) return cer diff --git a/src/dinglehopper/cli.py b/src/dinglehopper/cli.py index 2d3c075..3f3c835 100644 --- a/src/dinglehopper/cli.py +++ b/src/dinglehopper/cli.py @@ -1,13 +1,13 @@ import os from collections import Counter -from typing import List import click from jinja2 import Environment, FileSystemLoader from markupsafe import escape from ocrd_utils import initLogging +from uniseg.graphemecluster import grapheme_clusters -from dinglehopper.align import score_hint, seq_align +from dinglehopper.align import seq_align from dinglehopper.character_error_rate import character_error_rate_n from dinglehopper.config import Config from dinglehopper.extracted_text import ExtractedText @@ -15,9 +15,7 @@ from dinglehopper.ocr_files import extract from dinglehopper.word_error_rate import word_error_rate_n, words_normalized -def gen_diff_report( - gt_in, ocr_in, css_prefix, joiner, none, *, differences=False, score_hint=None -): +def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, differences=False): gtx = "" ocrx = "" @@ -44,8 +42,9 @@ def gen_diff_report( if isinstance(gt_in, ExtractedText): if not isinstance(ocr_in, ExtractedText): raise TypeError() - gt_things = gt_in.grapheme_clusters - ocr_things = ocr_in.grapheme_clusters + # XXX splitting should be done in ExtractedText + gt_things = list(grapheme_clusters(gt_in.text)) + ocr_things = list(grapheme_clusters(ocr_in.text)) else: gt_things = gt_in ocr_things = ocr_in @@ -54,7 +53,7 @@ def gen_diff_report( o_pos = 0 found_differences = [] - for k, (g, o) in enumerate(seq_align(gt_things, ocr_things, score_hint)): + for k, (g, o) in enumerate(seq_align(gt_things, ocr_things)): css_classes = None gt_id = None ocr_id = None @@ -77,7 +76,7 @@ def gen_diff_report( if o is not None: o_pos += len(o) - counted_differences = dict(Counter(elem for elem in found_differences)) + found_differences = dict(Counter(elem for elem in found_differences)) return ( """ @@ -88,7 +87,7 @@ def gen_diff_report( """.format( gtx, ocrx ), - counted_differences, + found_differences, ) @@ -106,56 +105,39 @@ def json_float(value): def process( - gt: str, - ocr: str, - report_prefix: str, - reports_folder: str = ".", + gt, + ocr, + report_prefix, + reports_folder=".", *, - metrics: bool = True, - differences: bool = False, - textequiv_level: str = "region", - plain_encoding: str = "autodetect", -) -> None: + metrics=True, + differences=False, + textequiv_level="region", +): """Check OCR result against GT. The @click decorators change the signature of the decorated functions, so we keep this undecorated version and use Click on a wrapper. """ - gt_text = extract( - gt, textequiv_level=textequiv_level, plain_encoding=plain_encoding - ) - ocr_text = extract( - ocr, textequiv_level=textequiv_level, plain_encoding=plain_encoding - ) - gt_words: List[str] = list(words_normalized(gt_text)) - ocr_words: List[str] = list(words_normalized(ocr_text)) + gt_text = extract(gt, textequiv_level=textequiv_level) + ocr_text = extract(ocr, textequiv_level=textequiv_level) - assert isinstance(gt_text, ExtractedText) - assert isinstance(ocr_text, ExtractedText) cer, n_characters = character_error_rate_n(gt_text, ocr_text) + wer, n_words = word_error_rate_n(gt_text, ocr_text) + char_diff_report, diff_c = gen_diff_report( - gt_text, - ocr_text, - css_prefix="c", - joiner="", - none="·", - score_hint=score_hint(cer, n_characters), - differences=differences, + gt_text, ocr_text, css_prefix="c", joiner="", none="·", differences=differences ) - # {gt,ocr}_words must not be a generator, so we don't drain it for the differences - # report. - assert isinstance(gt_words, list) - assert isinstance(ocr_words, list) - wer, n_words = word_error_rate_n(gt_words, ocr_words) + gt_words = words_normalized(gt_text) + ocr_words = words_normalized(ocr_text) word_diff_report, diff_w = gen_diff_report( gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯", - score_hint=score_hint(wer, n_words), differences=differences, ) @@ -192,16 +174,8 @@ def process( def process_dir( - gt: str, - ocr: str, - report_prefix: str, - reports_folder: str = ".", - *, - metrics: bool = True, - differences: bool = False, - textequiv_level: str = "region", - plain_encoding: str = "autodetect", -) -> None: + gt, ocr, report_prefix, reports_folder, metrics, differences, textequiv_level +): for gt_file in os.listdir(gt): gt_file_path = os.path.join(gt, gt_file) ocr_file_path = os.path.join(ocr, gt_file) @@ -215,7 +189,6 @@ def process_dir( metrics=metrics, differences=differences, textequiv_level=textequiv_level, - plain_encoding=plain_encoding, ) else: print("Skipping {0} and {1}".format(gt_file_path, ocr_file_path)) @@ -240,13 +213,7 @@ def process_dir( help="PAGE TextEquiv level to extract text from", metavar="LEVEL", ) -@click.option( - "--plain-encoding", - default="autodetect", - help='Encoding (e.g. "utf-8") of plain text files', -) @click.option("--progress", default=False, is_flag=True, help="Show progress bar") -@click.version_option() def main( gt, ocr, @@ -255,7 +222,6 @@ def main( metrics, differences, textequiv_level, - plain_encoding, progress, ): """ @@ -290,10 +256,9 @@ def main( ocr, report_prefix, reports_folder, - metrics=metrics, - differences=differences, - textequiv_level=textequiv_level, - plain_encoding=plain_encoding, + metrics, + differences, + textequiv_level, ) else: process( @@ -304,7 +269,6 @@ def main( metrics=metrics, differences=differences, textequiv_level=textequiv_level, - plain_encoding=plain_encoding, ) diff --git a/src/dinglehopper/cli_extract.py b/src/dinglehopper/cli_extract.py index 5fce032..9c51d34 100644 --- a/src/dinglehopper/cli_extract.py +++ b/src/dinglehopper/cli_extract.py @@ -12,12 +12,7 @@ from .ocr_files import extract help="PAGE TextEquiv level to extract text from", metavar="LEVEL", ) -@click.option( - "--plain-encoding", - default="autodetect", - help='Encoding (e.g. "utf-8") of plain text files', -) -def main(input_file, textequiv_level, plain_encoding): +def main(input_file, textequiv_level): """ Extract the text of the given INPUT_FILE. @@ -28,9 +23,7 @@ def main(input_file, textequiv_level, plain_encoding): use "--textequiv-level line" to extract from the level of TextLine tags. """ initLogging() - input_text = extract( - input_file, textequiv_level=textequiv_level, plain_encoding=plain_encoding - ).text + input_text = extract(input_file, textequiv_level=textequiv_level).text print(input_text) diff --git a/src/dinglehopper/cli_line_dirs.py b/src/dinglehopper/cli_line_dirs.py index 0160f87..5fc3754 100644 --- a/src/dinglehopper/cli_line_dirs.py +++ b/src/dinglehopper/cli_line_dirs.py @@ -1,53 +1,16 @@ import itertools import os -from typing import Callable, Iterator, List, Optional, Tuple import click from jinja2 import Environment, FileSystemLoader from ocrd_utils import initLogging -from .align import score_hint from .character_error_rate import character_error_rate_n from .cli import gen_diff_report, json_float from .ocr_files import plain_extract from .word_error_rate import word_error_rate_n, words_normalized -def removesuffix(text, suffix): - """ - Remove suffix from text. - - Can be replaced with str.removesuffix when we only support Python >= 3.9. - """ - if suffix and text.endswith(suffix): - return text[: -len(suffix)] - return text - - -def is_hidden(filepath): - filename = os.path.basename(os.path.abspath(filepath)) - return filename.startswith(".") - - -def find_all_files( - dir_: str, pred: Optional[Callable[[str], bool]] = None, return_hidden: bool = False -) -> Iterator[str]: - """ - Find all files in dir_, returning filenames - - If pred is given, pred(filename) must be True for the filename. - - Does not return hidden files by default. - """ - for root, _, filenames in os.walk(dir_): - for fn in filenames: - if not return_hidden and is_hidden(fn): - continue - if pred and not pred(fn): - continue - yield os.path.join(root, fn) - - def all_equal(iterable): g = itertools.groupby(iterable) return next(g, True) and not next(g, False) @@ -61,63 +24,15 @@ def common_suffix(its): return reversed(common_prefix(reversed(it) for it in its)) -def find_gt_and_ocr_files( - gt_dir: str, gt_suffix: str, ocr_dir: str, ocr_suffix: str -) -> Iterator[Tuple[str, str]]: - """ - Find GT files and matching OCR files. - - Returns pairs of GT and OCR files. - """ - for gt_fn in find_all_files(gt_dir, lambda fn: fn.endswith(gt_suffix)): - ocr_fn = os.path.join( - ocr_dir, - removesuffix(os.path.relpath(gt_fn, start=gt_dir), gt_suffix) + ocr_suffix, - ) - if not os.path.exists(ocr_fn): - raise RuntimeError(f"{ocr_fn} (matching {gt_fn}) does not exist") - - yield gt_fn, ocr_fn +def removesuffix(text, suffix): + if suffix and text.endswith(suffix): + return text[: -len(suffix)] + return text -def find_gt_and_ocr_files_autodetect(gt_dir, ocr_dir): - """ - Find GT files and matching OCR files, autodetect suffixes. - - This only works if gt_dir (or respectivley ocr_dir) only contains GT (OCR) - files with a common suffix. Currently the files must have a suffix, e.g. - ".gt.txt" (e.g. ".ocr.txt"). - - Returns pairs of GT and OCR files. - """ - - # Autodetect suffixes - gt_files = find_all_files(gt_dir) - gt_suffix = "".join(common_suffix(gt_files)) - if len(gt_suffix) == 0: - raise RuntimeError( - f"Files in GT directory {gt_dir} do not have a common suffix" - ) - ocr_files = find_all_files(ocr_dir) - ocr_suffix = "".join(common_suffix(ocr_files)) - if len(ocr_suffix) == 0: - raise RuntimeError( - f"Files in OCR directory {ocr_dir} do not have a common suffix" - ) - - yield from find_gt_and_ocr_files(gt_dir, gt_suffix, ocr_dir, ocr_suffix) - - -def process( - gt_dir, - ocr_dir, - report_prefix, - *, - metrics=True, - gt_suffix=None, - ocr_suffix=None, - plain_encoding="autodetect", -): +def process(gt_dir, ocr_dir, report_prefix, *, metrics=True): + gt_suffix = "".join(common_suffix(os.listdir(gt_dir))) + ocr_suffix = "".join(common_suffix(os.listdir(ocr_dir))) cer = None n_characters = None @@ -126,20 +41,14 @@ def process( n_words = None word_diff_report = "" - if gt_suffix is not None and ocr_suffix is not None: - gt_ocr_files = find_gt_and_ocr_files(gt_dir, gt_suffix, ocr_dir, ocr_suffix) - else: - gt_ocr_files = find_gt_and_ocr_files_autodetect(gt_dir, ocr_dir) + for k, gt in enumerate(os.listdir(gt_dir)): + # Find a match by replacing the suffix + ocr = removesuffix(gt, gt_suffix) + ocr_suffix - for k, (gt_fn, ocr_fn) in enumerate(gt_ocr_files): - gt_text = plain_extract( - gt_fn, include_filename_in_id=True, encoding=plain_encoding - ) + gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True) ocr_text = plain_extract( - ocr_fn, include_filename_in_id=True, encoding=plain_encoding + os.path.join(ocr_dir, ocr), include_filename_in_id=True ) - gt_words: List[str] = list(words_normalized(gt_text)) - ocr_words: List[str] = list(words_normalized(ocr_text)) # Compute CER l_cer, l_n_characters = character_error_rate_n(gt_text, ocr_text) @@ -153,7 +62,7 @@ def process( n_characters = n_characters + l_n_characters # Compute WER - l_wer, l_n_words = word_error_rate_n(gt_words, ocr_words) + l_wer, l_n_words = word_error_rate_n(gt_text, ocr_text) if wer is None: wer, n_words = l_wer, l_n_words else: @@ -163,21 +72,13 @@ def process( # Generate diff reports char_diff_report += gen_diff_report( - gt_text, - ocr_text, - css_prefix="l{0}-c".format(k), - joiner="", - none="·", - score_hint=score_hint(l_cer, l_n_characters), - )[0] + gt_text, ocr_text, css_prefix="l{0}-c".format(k), joiner="", none="·" + ) + gt_words = words_normalized(gt_text) + ocr_words = words_normalized(ocr_text) word_diff_report += gen_diff_report( - gt_words, - ocr_words, - css_prefix="l{0}-w".format(k), - joiner=" ", - none="⋯", - score_hint=score_hint(l_wer, l_n_words), - )[0] + gt_words, ocr_words, css_prefix="l{0}-w".format(k), joiner=" ", none="⋯" + ) env = Environment( loader=FileSystemLoader( @@ -211,30 +112,17 @@ def process( @click.option( "--metrics/--no-metrics", default=True, help="Enable/disable metrics and green/red" ) -@click.option("--gt-suffix", help="Suffix of GT line text files") -@click.option("--ocr-suffix", help="Suffix of OCR line text files") -@click.option( - "--plain-encoding", - default="autodetect", - help='Encoding (e.g. "utf-8") of plain text files', -) -def main(gt, ocr, report_prefix, metrics, gt_suffix, ocr_suffix, plain_encoding): +def main(gt, ocr, report_prefix, metrics): """ Compare the GT line text directory against the OCR line text directory. This assumes that the GT line text directory contains textfiles with a common suffix like ".gt.txt", and the OCR line text directory contains textfiles with a common suffix like ".some-ocr.txt". The text files also need to be paired, - i.e. the GT filename "line001.gt.txt" needs to match a filename - "line001.some-ocr.txt" in the OCR lines directory. + i.e. the GT file "line001.gt.txt" needs to match a file "line001.some-ocr.txt" + in the OCT lines directory. - GT and OCR directories may contain line text files in matching subdirectories, - e.g. "GT/goethe_faust/line1.gt.txt" and "OCR/goethe_faust/line1.pred.txt". - - GT and OCR directories can also be the same directory, but in this case you need - to give --gt-suffix and --ocr-suffix explicitly. - - The GT and OCR directories are usually ground truth line texts and the results of + The GT and OCR directories are usually round truth line texts and the results of an OCR software, but you may use dinglehopper to compare two OCR results. In that case, use --no-metrics to disable the then meaningless metrics and also change the color scheme from green/red to blue. @@ -243,19 +131,9 @@ def main(gt, ocr, report_prefix, metrics, gt_suffix, ocr_suffix, plain_encoding) $REPORT_PREFIX defaults to "report". The reports include the character error rate (CER) and the word error rate (WER). - It is recommended to specify the encoding of the text files, for example with - --plain-encoding utf-8. If this option is not given, we try to auto-detect it. """ initLogging() - process( - gt, - ocr, - report_prefix, - metrics=metrics, - gt_suffix=gt_suffix, - ocr_suffix=ocr_suffix, - plain_encoding=plain_encoding, - ) + process(gt, ocr, report_prefix, metrics=metrics) if __name__ == "__main__": diff --git a/src/dinglehopper/cli_summarize.py b/src/dinglehopper/cli_summarize.py index c49911b..0422759 100644 --- a/src/dinglehopper/cli_summarize.py +++ b/src/dinglehopper/cli_summarize.py @@ -1,6 +1,5 @@ import json import os -from typing import Dict import click from jinja2 import Environment, FileSystemLoader @@ -14,8 +13,8 @@ def process(reports_folder, occurrences_threshold=1): wer_list = [] cer_sum = 0 wer_sum = 0 - diff_c: Dict[str, int] = {} - diff_w: Dict[str, int] = {} + diff_c = {} + diff_w = {} for report in os.listdir(reports_folder): if report.endswith(".json"): @@ -35,15 +34,10 @@ def process(reports_folder, occurrences_threshold=1): cer_sum += cer wer_sum += wer - try: - for key, value in report_data["differences"][ - "character_level" - ].items(): - diff_c[key] = diff_c.get(key, 0) + value - for key, value in report_data["differences"]["word_level"].items(): - diff_w[key] = diff_w.get(key, 0) + value - except KeyError: - pass + for key, value in report_data["differences"]["character_level"].items(): + diff_c[key] = diff_c.get(key, 0) + value + for key, value in report_data["differences"]["word_level"].items(): + diff_w[key] = diff_w.get(key, 0) + value if len(cer_list) == 0: click.echo(f"No reports found in folder '{os.path.abspath(reports_folder)}'") diff --git a/src/dinglehopper/edit_distance.py b/src/dinglehopper/edit_distance.py index ec564ae..e5194bf 100644 --- a/src/dinglehopper/edit_distance.py +++ b/src/dinglehopper/edit_distance.py @@ -1,5 +1,6 @@ +from __future__ import division, print_function + import unicodedata -from typing import List from multimethod import multimethod from rapidfuzz.distance import Levenshtein @@ -9,18 +10,7 @@ from .extracted_text import ExtractedText @multimethod -def distance(seq1: List[str], seq2: List[str]) -> int: - """Compute the Levenshtein edit distance between two lists of grapheme clusters. - - This assumes that the grapheme clusters are already normalized. - - Use distance(str, str) instead if you need to compare two Unicode strings. - """ - return Levenshtein.distance(seq1, seq2) - - -@distance.register -def _(s1: str, s2: str) -> int: +def distance(s1: str, s2: str): """Compute the Levenshtein edit distance between two Unicode strings Note that this is different from levenshtein() as this function knows about Unicode @@ -32,9 +22,9 @@ def _(s1: str, s2: str) -> int: return Levenshtein.distance(seq1, seq2) -@distance.register -def _(s1: ExtractedText, s2: ExtractedText) -> int: - return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters) +@multimethod +def distance(s1: ExtractedText, s2: ExtractedText): + return distance(s1.text, s2.text) def editops(word1, word2): diff --git a/src/dinglehopper/extracted_text.py b/src/dinglehopper/extracted_text.py index acfbf78..9703b6b 100644 --- a/src/dinglehopper/extracted_text.py +++ b/src/dinglehopper/extracted_text.py @@ -1,16 +1,14 @@ import enum -import functools import re import unicodedata from contextlib import suppress from itertools import repeat -from typing import Any, Dict, List, Optional +from typing import Optional import attr import numpy as np from lxml import etree as ET from ocrd_utils import getLogger -from uniseg.graphemecluster import grapheme_clusters class Normalization(enum.Enum): @@ -122,7 +120,7 @@ class ExtractedText: segment_id = attr.ib(type=Optional[str]) @segment_id.validator - def is_valid_segment_id(self, _, value): + def check(self, _, value): if value is None: return if not re.match(r"[\w\d_-]+", value): @@ -132,85 +130,33 @@ class ExtractedText: # a. _text itself # b. or segments (ExtractedText) and a joiner - segments = attr.ib(type=Optional[List["ExtractedText"]]) + segments = attr.ib(type=Optional[list], converter=attr.converters.optional(list)) joiner = attr.ib(type=Optional[str]) _text = attr.ib(type=Optional[str]) - _grapheme_clusters = attr.ib(type=Optional[List[str]]) @segments.validator - def cant_set_both_segments_and_text(self, _, value): + def check(self, _, value): if value is not None and self._text is not None: raise ValueError("Can't have both segments and text") - @joiner.validator - def is_valid_joiner(self, _, value): - if self.segments is None: - if value is not None: - raise ValueError("Can't have joiner without segments to join") - if self.segments is not None: - if value not in ("", " ", "\n"): - raise ValueError(f"Unexpected segment joiner value {repr(value)}") - @_text.validator - def is_valid_text(self, _, value): - if value is None: - return - - if self.segments is not None: + def check(self, _, value): + if value is not None and self.segments is not None: raise ValueError("Can't have both segments and text") - if unicodedata.normalize("NFC", value) != value: + if value is not None and unicodedata.normalize("NFC", value) != value: raise ValueError('String "{}" is not in NFC.'.format(value)) - if normalize(value, self.normalization) != value: + if value is not None and normalize(value, self.normalization) != value: raise ValueError('String "{}" is not normalized.'.format(value)) - if self._grapheme_clusters is None: - raise ValueError("Requires both text and grapheme clusters to be set") - - @_grapheme_clusters.validator - def are_valid_grapheme_clusters(self, _, value): - if value is not None and self._text is None: - raise ValueError("Requires both text and grapheme clusters to be set") normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB) @property - def text(self) -> str: + def text(self): if self._text is not None: return self._text else: - assert self.joiner is not None and self.segments is not None return self.joiner.join(s.text for s in self.segments) - @functools.cached_property - def _joiner_grapheme_cluster(self): - """We need the joiner as a list of 0 or 1 grapheme clusters. - - This property is cached. - """ - - assert self.joiner is not None - if len(self.joiner) > 0: - joiner_grapheme_cluster = list(grapheme_clusters(self.joiner)) - assert len(joiner_grapheme_cluster) == 1 # see joiner's check above - elif len(self.joiner) == 0: - joiner_grapheme_cluster = [] - else: - joiner_grapheme_cluster = None - - return joiner_grapheme_cluster - - @property - def grapheme_clusters(self): - if self._text is not None: - return self._grapheme_clusters - else: - # TODO Test with text extracted at glyph level (joiner == "") - clusters = [] - assert self.segments is not None - for seg in self.segments: - clusters += seg.grapheme_clusters + self._joiner_grapheme_cluster - clusters = clusters[:-1] - return clusters - _segment_id_for_pos = None def segment_id_for_pos(self, pos): @@ -221,7 +167,6 @@ class ExtractedText: else: # Recurse segment_id_for_pos = [] - assert self.joiner is not None and self.segments is not None for s in self.segments: seg_ids = [s.segment_id_for_pos(i) for i in range(len(s.text))] segment_id_for_pos.extend(seg_ids) @@ -235,7 +180,7 @@ class ExtractedText: return self._segment_id_for_pos[pos] @classmethod - def from_text_segment(cls, text_segment, nsmap, *, textequiv_level="region"): + def from_text_segment(cls, text_segment, nsmap, textequiv_level="region"): """Build an ExtractedText from a PAGE content text element""" localname_for_textequiv_level = {"region": "TextRegion", "line": "TextLine"} @@ -252,8 +197,7 @@ class ExtractedText: # FIXME hardcoded SBB normalization segment_text = normalize_sbb(segment_text) segment_text = segment_text or "" - clusters = list(grapheme_clusters(segment_text)) - return cls(segment_id, None, None, segment_text, clusters) + return cls(segment_id, None, None, segment_text) else: # Recurse sub_localname = children_for_localname[localname] @@ -268,15 +212,12 @@ class ExtractedText: ) ) joiner = joiner_for_textequiv_level[sub_textequiv_level] - return cls(segment_id, segments, joiner, None, None) + return cls(segment_id, segments, joiner, None) @classmethod def from_str(cls, text, normalization=Normalization.NFC_SBB): normalized_text = normalize(text, normalization) - clusters = list(grapheme_clusters(normalized_text)) - return cls( - None, None, None, normalized_text, clusters, normalization=normalization - ) + return cls(None, None, None, normalized_text, normalization=normalization) def invert_dict(d): @@ -284,7 +225,7 @@ def invert_dict(d): return {v: k for k, v in d.items()} -def get_textequiv_unicode(text_segment: Any, nsmap: Dict[str, str]) -> str: +def get_textequiv_unicode(text_segment, nsmap) -> str: """Get the TextEquiv/Unicode text of the given PAGE text element.""" segment_id = text_segment.attrib["id"] textequivs = text_segment.findall("./page:TextEquiv", namespaces=nsmap) @@ -308,7 +249,7 @@ def get_first_textequiv(textequivs, segment_id): if np.any(~nan_mask): if np.any(nan_mask): log.warning("TextEquiv without index in %s.", segment_id) - index = int(np.nanargmin(indices)) + index = np.nanargmin(indices) else: # try ordering by conf confidences = np.array([get_attr(te, "conf") for te in textequivs], dtype=float) @@ -317,7 +258,7 @@ def get_first_textequiv(textequivs, segment_id): "No index attributes, use 'conf' attribute to sort TextEquiv in %s.", segment_id, ) - index = int(np.nanargmax(confidences)) + index = np.nanargmax(confidences) else: # fallback to first entry in case of neither index or conf present log.warning("No index attributes, use first TextEquiv in %s.", segment_id) @@ -325,11 +266,11 @@ def get_first_textequiv(textequivs, segment_id): return textequivs[index] -def get_attr(te: Any, attr_name: str) -> float: +def get_attr(te, attr_name) -> float: """Extract the attribute for the given name. Note: currently only handles numeric values! - Other or non existent values are encoded as np.nan. + Other or non existend values are encoded as np.nan. """ attr_value = te.attrib.get(attr_name) try: diff --git a/src/dinglehopper/notebooks/Levenshtein.ipynb b/src/dinglehopper/notebooks/Levenshtein.ipynb index b9671d7..a27dca4 100644 --- a/src/dinglehopper/notebooks/Levenshtein.ipynb +++ b/src/dinglehopper/notebooks/Levenshtein.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "dinglehopper used to have its own (very inefficient) Levenshtein edit distance implementation, but now uses RapidFuzz." + "dinglehopper uses to have its own (very inefficient) Levenshtein edit distance implementation, but now uses RapidFuzz." ] }, { @@ -391,7 +391,7 @@ "\\text{CER} = \\frac{i + s + d}{n}\n", "$$\n", "\n", - "where $i$ is the number of inserts, $s$ the number of substitutions, $d$ the number of deletions and $n$ is the number of characters in the reference text. (The text is not super clear about $n$ being the number of characters in the reference text, but it seems appropriate as they *are* clear about this when computing the word error rate.)" + "where $i$ is the number of inserts, $s$ the number of substitutions, $d$ the number of deletions and $n$ is the number of characters in the reference text. (The text is not super clear about $n$ being the number of characters in the reference text, but it seems appropiate as they *are* clear about this when computing the word error rate.)" ] }, { @@ -680,7 +680,7 @@ " return cat in unwanted_categories or subcat in unwanted_subcategories\n", "\n", " # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on word boundaries using\n", - " # uniseg.wordbreak.words() and ignore all \"words\" that contain only whitespace, punctuation \"or similar characters.\"\n", + " # uniseg.wordbreak.words() and ignore all \"words\" that contain only whitespace, punctation \"or similar characters.\"\n", " for word in uniseg.wordbreak.words(s):\n", " if all(unwanted(c) for c in word):\n", " pass\n", diff --git a/src/dinglehopper/ocr_files.py b/src/dinglehopper/ocr_files.py index fdcaf54..5c4339b 100644 --- a/src/dinglehopper/ocr_files.py +++ b/src/dinglehopper/ocr_files.py @@ -1,56 +1,44 @@ +from __future__ import division, print_function + import os import sys -from typing import Dict, Iterator, Optional +from typing import Iterator import chardet from lxml import etree as ET from lxml.etree import XMLSyntaxError -from ocrd_utils import getLogger -from uniseg.graphemecluster import grapheme_clusters from .extracted_text import ExtractedText, normalize_sbb -log = getLogger("processor.OcrdDinglehopperEvaluate") - -def alto_namespace(tree: ET._ElementTree) -> Optional[str]: +def alto_namespace(tree: ET.ElementTree) -> str: """Return the ALTO namespace used in the given ElementTree. This relies on the assumption that, in any given ALTO file, the root element has the - local name "alto". We do not check if the file uses any valid ALTO namespace. + local name "alto". We do not check if the files uses any valid ALTO namespace. """ root_name = ET.QName(tree.getroot().tag) if root_name.localname == "alto": - assert isinstance(root_name.namespace, str) return root_name.namespace else: raise ValueError("Not an ALTO tree") -def alto_nsmap(tree: ET._ElementTree) -> Dict[str, str]: - alto_ns = alto_namespace(tree) - if alto_ns is None: - raise ValueError("Could not determine ALTO namespace") - return {"alto": alto_ns} - - -def alto_extract_lines(tree: ET._ElementTree) -> Iterator[ExtractedText]: - nsmap = alto_nsmap(tree) +def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]: + nsmap = {"alto": alto_namespace(tree)} for line in tree.iterfind(".//alto:TextLine", namespaces=nsmap): line_id = line.attrib.get("ID") line_text = " ".join( - string.attrib.get("CONTENT", "") + string.attrib.get("CONTENT") for string in line.iterfind("alto:String", namespaces=nsmap) ) - normalized_text = normalize_sbb(line_text) - clusters = list(grapheme_clusters(normalized_text)) - yield ExtractedText(line_id, None, None, normalized_text, clusters) + yield ExtractedText(line_id, None, None, normalize_sbb(line_text)) # FIXME hardcoded SBB normalization -def alto_extract(tree: ET._ElementTree) -> ExtractedText: +def alto_extract(tree: ET.ElementTree) -> ExtractedText: """Extract text from the given ALTO ElementTree.""" - return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None) + return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None) def alto_text(tree): @@ -99,7 +87,7 @@ def page_extract(tree, *, textequiv_level="region"): # Filter empty region texts regions = [r for r in regions if r.text != ""] - return ExtractedText(None, regions, "\n", None, None) + return ExtractedText(None, regions, "\n", None) def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level): @@ -109,7 +97,7 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level): if ET.QName(group.tag).localname in ["OrderedGroup", "OrderedGroupIndexed"]: ro_children = list(group) - ro_children = [child for child in ro_children if "index" in child.attrib.keys()] + ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children) ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"])) elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]: ro_children = list(group) @@ -152,44 +140,33 @@ def detect_encoding(filename): return chardet.detect(open(filename, "rb").read(1024))["encoding"] -def plain_extract(filename, include_filename_in_id=False, encoding="autodetect"): +def plain_extract(filename, include_filename_in_id=False): id_template = "{filename} - line {no}" if include_filename_in_id else "line {no}" - def make_segment(no, line): - normalized_text = normalize_sbb(line) - clusters = list(grapheme_clusters(normalized_text)) - return ExtractedText( - id_template.format(filename=os.path.basename(filename), no=no), - None, - None, - normalized_text, - clusters, - ) - - if encoding == "autodetect": - fileencoding = detect_encoding(filename) - log.warning( - f"Autodetected encoding as '{fileencoding}'" - ", it is recommended to specify it explicitly with --plain-encoding" - ) - else: - fileencoding = encoding + fileencoding = detect_encoding(filename) with open(filename, "r", encoding=fileencoding) as f: return ExtractedText( None, - [make_segment(no, line.strip()) for no, line in enumerate(f.readlines())], + [ + ExtractedText( + id_template.format(filename=os.path.basename(filename), no=no), + None, + None, + normalize_sbb(line), + ) + for no, line in enumerate(f.readlines()) + ], "\n", None, - None, ) # XXX hardcoded SBB normalization -def plain_text(filename, encoding="autodetect"): - return plain_extract(filename, encoding=encoding).text +def plain_text(filename): + return plain_extract(filename).text -def extract(filename, *, textequiv_level="region", plain_encoding="autodetect"): +def extract(filename, *, textequiv_level="region"): """Extract the text from the given file. Supports PAGE, ALTO and falls back to plain text. @@ -197,7 +174,7 @@ def extract(filename, *, textequiv_level="region", plain_encoding="autodetect"): try: tree = ET.parse(filename) except (XMLSyntaxError, UnicodeDecodeError): - return plain_extract(filename, encoding=plain_encoding) + return plain_extract(filename) try: return page_extract(tree, textequiv_level=textequiv_level) except ValueError: diff --git a/src/dinglehopper/ocrd-tool.json b/src/dinglehopper/ocrd-tool.json index ad48e51..c4f8c4e 100644 --- a/src/dinglehopper/ocrd-tool.json +++ b/src/dinglehopper/ocrd-tool.json @@ -1,13 +1,17 @@ { - "version": "0.11.0", + "version": "0.9.1", "git_url": "https://github.com/qurator-spk/dinglehopper", - "dockerhub": "ocrd/dinglehopper", "tools": { "ocrd-dinglehopper": { "executable": "ocrd-dinglehopper", - "input_file_grp_cardinality": 2, - "output_file_grp_cardinality": 1, "description": "Evaluate OCR text against ground truth with dinglehopper", + "input_file_grp": [ + "OCR-D-GT-PAGE", + "OCR-D-OCR" + ], + "output_file_grp": [ + "OCR-D-OCR-EVAL" + ], "categories": [ "Quality assurance" ], @@ -25,11 +29,6 @@ "enum": ["region", "line"], "default": "region", "description": "PAGE XML hierarchy level to extract the text from" - }, - "plain_encoding": { - "type": "string", - "default": "autodetect", - "description": "Encoding (e.g. \"utf-8\") of plain text files" } } } diff --git a/src/dinglehopper/ocrd_cli.py b/src/dinglehopper/ocrd_cli.py index 2d7da8e..8eebdc0 100644 --- a/src/dinglehopper/ocrd_cli.py +++ b/src/dinglehopper/ocrd_cli.py @@ -1,78 +1,78 @@ -from functools import cached_property +import json import os -from typing import Optional import click -from ocrd_models import OcrdFileType from ocrd import Processor from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor -from ocrd_utils import make_file_id +from ocrd_utils import assert_file_grp_cardinality, getLogger, make_file_id +from pkg_resources import resource_string from .cli import process as cli_process +OCRD_TOOL = json.loads(resource_string(__name__, "ocrd-tool.json").decode("utf8")) + + @click.command() @ocrd_cli_options def ocrd_dinglehopper(*args, **kwargs): return ocrd_cli_wrap_processor(OcrdDinglehopperEvaluate, *args, **kwargs) + class OcrdDinglehopperEvaluate(Processor): + def __init__(self, *args, **kwargs): + kwargs["ocrd_tool"] = OCRD_TOOL["tools"]["ocrd-dinglehopper"] + super(OcrdDinglehopperEvaluate, self).__init__(*args, **kwargs) - @cached_property - def executable(self): - return 'ocrd-dinglehopper' + def process(self): + assert_file_grp_cardinality(self.input_file_grp, 2, "GT and OCR") + assert_file_grp_cardinality(self.output_file_grp, 1) - def process_page_file(self, *input_files: Optional[OcrdFileType]) -> None: + log = getLogger("processor.OcrdDinglehopperEvaluate") - assert self.parameter metrics = self.parameter["metrics"] textequiv_level = self.parameter["textequiv_level"] - plain_encoding = self.parameter["plain_encoding"] + gt_grp, ocr_grp = self.input_file_grp.split(",") - # wrong number of inputs: let fail - gt_file, ocr_file = input_files - # missing on either side: skip (zip_input_files already warned) - if not gt_file or not ocr_file: - return - # missing download (i.e. OCRD_DOWNLOAD_INPUT=false): - if not gt_file.local_filename: - if config.OCRD_MISSING_INPUT == 'ABORT': - raise MissingInputFile(gt_file.fileGrp, gt_file.pageId, gt_file.mimetype) - return - if not ocr_file.local_filename: - if config.OCRD_MISSING_INPUT == 'ABORT': - raise MissingInputFile(ocr_file.fileGrp, ocr_file.pageId, ocr_file.mimetype) - return + input_file_tuples = self.zip_input_files(on_error="abort") + for n, (gt_file, ocr_file) in enumerate(input_file_tuples): + if not gt_file or not ocr_file: + # file/page was not found in this group + continue + gt_file = self.workspace.download_file(gt_file) + ocr_file = self.workspace.download_file(ocr_file) + page_id = gt_file.pageId - page_id = gt_file.pageId + log.info("INPUT FILES %i / %s↔ %s", n, gt_file, ocr_file) - file_id = make_file_id(ocr_file, self.output_file_grp) - cli_process( - gt_file.local_filename, - ocr_file.local_filename, - file_id, - self.output_file_grp, - metrics=metrics, - textequiv_level=textequiv_level, - plain_encoding=plain_encoding, - ) + file_id = make_file_id(ocr_file, self.output_file_grp) + report_prefix = os.path.join(self.output_file_grp, file_id) - # Add reports to the workspace - for report_suffix, mimetype in [ - [".html", "text/html"], - [".json", "application/json"], - ]: - output_file_id = file_id + report_suffix - output_file = next(self.workspace.mets.find_files(ID=output_file_id), None) - if output_file and config.OCRD_EXISTING_OUTPUT != 'OVERWRITE': - raise FileExistsError(f"A file with ID=={output_file_id} already exists {output_file} and neither force nor ignore are set") - self.workspace.add_file( - file_id=output_file_id, - file_grp=self.output_file_grp, - page_id=page_id, - mimetype=mimetype, - local_filename=file_id + report_suffix, + # Process the files + try: + os.mkdir(self.output_file_grp) + except FileExistsError: + pass + cli_process( + gt_file.local_filename, + ocr_file.local_filename, + report_prefix, + metrics=metrics, + textequiv_level=textequiv_level, ) + # Add reports to the workspace + for report_suffix, mimetype in [ + [".html", "text/html"], + [".json", "application/json"], + ]: + self.workspace.add_file( + file_id=file_id + report_suffix, + file_grp=self.output_file_grp, + page_id=page_id, + mimetype=mimetype, + local_filename=report_prefix + report_suffix, + ) + if __name__ == "__main__": ocrd_dinglehopper() diff --git a/src/dinglehopper/tests/data/actevedef_718448162/mets.xml b/src/dinglehopper/tests/data/actevedef_718448162/mets.xml index ed7c4f4..a6804ca 100644 --- a/src/dinglehopper/tests/data/actevedef_718448162/mets.xml +++ b/src/dinglehopper/tests/data/actevedef_718448162/mets.xml @@ -138,17 +138,17 @@ - + - + - + diff --git a/src/dinglehopper/tests/data/line_dirs/basic/gt/a.gt.txt b/src/dinglehopper/tests/data/line_dirs/basic/gt/a.gt.txt deleted file mode 100644 index 484ba93..0000000 --- a/src/dinglehopper/tests/data/line_dirs/basic/gt/a.gt.txt +++ /dev/null @@ -1 +0,0 @@ -This is a test. diff --git a/src/dinglehopper/tests/data/line_dirs/basic/gt/b.gt.txt b/src/dinglehopper/tests/data/line_dirs/basic/gt/b.gt.txt deleted file mode 100644 index fc9bd6a..0000000 --- a/src/dinglehopper/tests/data/line_dirs/basic/gt/b.gt.txt +++ /dev/null @@ -1 +0,0 @@ -Another test. diff --git a/src/dinglehopper/tests/data/line_dirs/basic/ocr/a.some-ocr.txt b/src/dinglehopper/tests/data/line_dirs/basic/ocr/a.some-ocr.txt deleted file mode 100644 index 27cf4bf..0000000 --- a/src/dinglehopper/tests/data/line_dirs/basic/ocr/a.some-ocr.txt +++ /dev/null @@ -1 +0,0 @@ -Tis is a test. diff --git a/src/dinglehopper/tests/data/line_dirs/basic/ocr/b.some-ocr.txt b/src/dinglehopper/tests/data/line_dirs/basic/ocr/b.some-ocr.txt deleted file mode 100644 index 0bc0e40..0000000 --- a/src/dinglehopper/tests/data/line_dirs/basic/ocr/b.some-ocr.txt +++ /dev/null @@ -1 +0,0 @@ -AnÖther test. diff --git a/src/dinglehopper/tests/data/line_dirs/merged/a/a.gt.txt b/src/dinglehopper/tests/data/line_dirs/merged/a/a.gt.txt deleted file mode 100644 index 484ba93..0000000 --- a/src/dinglehopper/tests/data/line_dirs/merged/a/a.gt.txt +++ /dev/null @@ -1 +0,0 @@ -This is a test. diff --git a/src/dinglehopper/tests/data/line_dirs/merged/a/a.some-ocr.txt b/src/dinglehopper/tests/data/line_dirs/merged/a/a.some-ocr.txt deleted file mode 100644 index 27cf4bf..0000000 --- a/src/dinglehopper/tests/data/line_dirs/merged/a/a.some-ocr.txt +++ /dev/null @@ -1 +0,0 @@ -Tis is a test. diff --git a/src/dinglehopper/tests/data/line_dirs/merged/b/b.dummy.jpg b/src/dinglehopper/tests/data/line_dirs/merged/b/b.dummy.jpg deleted file mode 100644 index e69de29..0000000 diff --git a/src/dinglehopper/tests/data/line_dirs/merged/b/b.gt.txt b/src/dinglehopper/tests/data/line_dirs/merged/b/b.gt.txt deleted file mode 100644 index fc9bd6a..0000000 --- a/src/dinglehopper/tests/data/line_dirs/merged/b/b.gt.txt +++ /dev/null @@ -1 +0,0 @@ -Another test. diff --git a/src/dinglehopper/tests/data/line_dirs/merged/b/b.some-ocr.txt b/src/dinglehopper/tests/data/line_dirs/merged/b/b.some-ocr.txt deleted file mode 100644 index 0bc0e40..0000000 --- a/src/dinglehopper/tests/data/line_dirs/merged/b/b.some-ocr.txt +++ /dev/null @@ -1 +0,0 @@ -AnÖther test. diff --git a/src/dinglehopper/tests/data/line_dirs/subdirs/gt/a/a.gt.txt b/src/dinglehopper/tests/data/line_dirs/subdirs/gt/a/a.gt.txt deleted file mode 100644 index 484ba93..0000000 --- a/src/dinglehopper/tests/data/line_dirs/subdirs/gt/a/a.gt.txt +++ /dev/null @@ -1 +0,0 @@ -This is a test. diff --git a/src/dinglehopper/tests/data/line_dirs/subdirs/gt/b/b.gt.txt b/src/dinglehopper/tests/data/line_dirs/subdirs/gt/b/b.gt.txt deleted file mode 100644 index fc9bd6a..0000000 --- a/src/dinglehopper/tests/data/line_dirs/subdirs/gt/b/b.gt.txt +++ /dev/null @@ -1 +0,0 @@ -Another test. diff --git a/src/dinglehopper/tests/data/line_dirs/subdirs/ocr/a/a.some-ocr.txt b/src/dinglehopper/tests/data/line_dirs/subdirs/ocr/a/a.some-ocr.txt deleted file mode 100644 index 27cf4bf..0000000 --- a/src/dinglehopper/tests/data/line_dirs/subdirs/ocr/a/a.some-ocr.txt +++ /dev/null @@ -1 +0,0 @@ -Tis is a test. diff --git a/src/dinglehopper/tests/data/line_dirs/subdirs/ocr/b/b.some-ocr.txt b/src/dinglehopper/tests/data/line_dirs/subdirs/ocr/b/b.some-ocr.txt deleted file mode 100644 index 0bc0e40..0000000 --- a/src/dinglehopper/tests/data/line_dirs/subdirs/ocr/b/b.some-ocr.txt +++ /dev/null @@ -1 +0,0 @@ -AnÖther test. diff --git a/src/dinglehopper/tests/data/test.alto1.xml b/src/dinglehopper/tests/data/test.alto1.xml index 35aa19a..ac2a50b 100644 --- a/src/dinglehopper/tests/data/test.alto1.xml +++ b/src/dinglehopper/tests/data/test.alto1.xml @@ -20183,4 +20183,4 @@ - + \ No newline at end of file diff --git a/src/dinglehopper/tests/data/test.alto2.xml b/src/dinglehopper/tests/data/test.alto2.xml index 39dd592..67d3537 100644 --- a/src/dinglehopper/tests/data/test.alto2.xml +++ b/src/dinglehopper/tests/data/test.alto2.xml @@ -61,4 +61,4 @@ - + \ No newline at end of file diff --git a/src/dinglehopper/tests/data/test.txt b/src/dinglehopper/tests/data/test.txt index 102374b..41bfe81 100644 --- a/src/dinglehopper/tests/data/test.txt +++ b/src/dinglehopper/tests/data/test.txt @@ -1 +1 @@ -Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. +Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. \ No newline at end of file diff --git a/src/dinglehopper/tests/extracted_text_test.py b/src/dinglehopper/tests/extracted_text_test.py index 1a6d99d..ae85735 100644 --- a/src/dinglehopper/tests/extracted_text_test.py +++ b/src/dinglehopper/tests/extracted_text_test.py @@ -13,13 +13,12 @@ def test_text(): test1 = ExtractedText( None, [ - ExtractedText("s0", None, None, "foo", grapheme_clusters("foo")), - ExtractedText("s1", None, None, "bar", grapheme_clusters("bar")), - ExtractedText("s2", None, None, "bazinga", grapheme_clusters("bazinga")), + ExtractedText("s0", None, None, "foo"), + ExtractedText("s1", None, None, "bar"), + ExtractedText("s2", None, None, "bazinga"), ], " ", None, - None, ) assert test1.text == "foo bar bazinga" @@ -30,20 +29,8 @@ def test_text(): def test_normalization_check(): with pytest.raises(ValueError, match=r".*is not in NFC.*"): - ExtractedText( - "foo", - None, - None, - unicodedata.normalize("NFD", "Schlyñ"), - grapheme_clusters(unicodedata.normalize("NFD", "Schlyñ")), - ) - assert ExtractedText( - "foo", - None, - None, - unicodedata.normalize("NFC", "Schlyñ"), - grapheme_clusters(unicodedata.normalize("NFC", "Schlyñ")), - ) + ExtractedText("foo", None, None, unicodedata.normalize("NFD", "Schlyñ")) + assert ExtractedText("foo", None, None, unicodedata.normalize("NFC", "Schlyñ")) AlignmentElement = namedtuple("AlignmentElement", "left right left_id right_id") @@ -60,27 +47,25 @@ def test_align(): test1 = ExtractedText( None, [ - ExtractedText("s0", None, None, "foo", grapheme_clusters("foo")), - ExtractedText("s1", None, None, "bar", grapheme_clusters("bar")), - ExtractedText("s2", None, None, "batzinga", grapheme_clusters("batzinga")), + ExtractedText("s0", None, None, "foo"), + ExtractedText("s1", None, None, "bar"), + ExtractedText("s2", None, None, "batzinga"), ], " ", None, - None, ) test2 = ExtractedText( None, [ - ExtractedText("x0", None, None, "foo", grapheme_clusters("foo")), - ExtractedText("x1", None, None, "bar", grapheme_clusters("bar")), + ExtractedText("x0", None, None, "foo"), + ExtractedText("x1", None, None, "bar"), # extra . - ExtractedText("x2", None, None, ".", grapheme_clusters(".")), + ExtractedText("x2", None, None, "."), # deletion + different grapheme cluster, m̃ also is two Python characters - ExtractedText("x3", None, None, "bazim̃ga", grapheme_clusters("bazim̃ga")), + ExtractedText("x3", None, None, "bazim̃ga"), ], " ", None, - None, ) left_pos = 0 diff --git a/src/dinglehopper/tests/test_align.py b/src/dinglehopper/tests/test_align.py index 5d1e7ab..2c4e23a 100644 --- a/src/dinglehopper/tests/test_align.py +++ b/src/dinglehopper/tests/test_align.py @@ -1,8 +1,6 @@ -import math - import pytest -from .. import align, distance, score_hint, seq_align +from .. import align, distance, seq_align from .util import unzip @@ -185,8 +183,3 @@ def test_lines_similar(): # Test __eq__ (i.e. is it a substitution or a similar string?) assert list(left)[0] == list(right)[0] - - -def test_score_hint(): - assert score_hint(0.5, 23) == 12 # int(ceil()) - assert score_hint(math.inf, 12345) is None diff --git a/src/dinglehopper/tests/test_integ_cli_dir.py b/src/dinglehopper/tests/test_integ_cli_dir.py index 65e59d9..c065130 100644 --- a/src/dinglehopper/tests/test_integ_cli_dir.py +++ b/src/dinglehopper/tests/test_integ_cli_dir.py @@ -21,9 +21,9 @@ def test_cli_directory(tmp_path): os.path.join(data_dir, "directory-test", "ocr"), "report", str(tmp_path / "reports"), - metrics=False, - differences=True, - textequiv_level="line", + False, + True, + "line", ) assert os.path.exists(tmp_path / "reports/1.xml-report.json") @@ -45,9 +45,9 @@ def test_cli_fail_without_gt(tmp_path): os.path.join(data_dir, "directory-test", "ocr"), "report", str(tmp_path / "reports"), - metrics=False, - differences=True, - textequiv_level="line", + False, + True, + "line", ) assert len(os.listdir(tmp_path / "reports")) == 2 * 2 diff --git a/src/dinglehopper/tests/test_integ_cli_line_dirs.py b/src/dinglehopper/tests/test_integ_cli_line_dirs.py deleted file mode 100644 index 90cbabf..0000000 --- a/src/dinglehopper/tests/test_integ_cli_line_dirs.py +++ /dev/null @@ -1,61 +0,0 @@ -import json -import os.path -import re - -import pytest - -from ..cli_line_dirs import process -from .util import working_directory - -data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") - - -@pytest.mark.integration -def test_cli_line_dirs_basic(tmp_path): - """Test that the cli/process() produces a good report""" - - with working_directory(tmp_path): - gt_dir = os.path.join(data_dir, "line_dirs/basic/gt") - ocr_dir = os.path.join(data_dir, "line_dirs/basic/ocr") - process(gt_dir, ocr_dir, "report") - with open("report.json", "r") as jsonf: - print(jsonf.read()) - with open("report.json", "r") as jsonf: - j = json.load(jsonf) - assert j["cer"] == pytest.approx(0.1071429) - assert j["wer"] == pytest.approx(0.5) - - -@pytest.mark.integration -def test_cli_line_dirs_basic_report_diff(tmp_path): - """Test that the cli/process() produces a report wiff char+word diff""" - - with working_directory(tmp_path): - gt_dir = os.path.join(data_dir, "line_dirs/basic/gt") - ocr_dir = os.path.join(data_dir, "line_dirs/basic/ocr") - process(gt_dir, ocr_dir, "report") - - with open("report.html", "r") as htmlf: - html_report = htmlf.read() - - # Counting GT lines in the diff - assert len(re.findall(r"gt.*l\d+-cdiff", html_report)) == 2 - assert len(re.findall(r"gt.*l\d+-wdiff", html_report)) == 2 - - -@pytest.mark.integration -def test_cli_line_dirs_merged(tmp_path): - """Test that the cli/process() produces a good report""" - - with working_directory(tmp_path): - gt_dir = os.path.join(data_dir, "line_dirs/merged") - ocr_dir = os.path.join(data_dir, "line_dirs/merged") - process( - gt_dir, ocr_dir, "report", gt_suffix=".gt.txt", ocr_suffix=".some-ocr.txt" - ) - with open("report.json", "r") as jsonf: - print(jsonf.read()) - with open("report.json", "r") as jsonf: - j = json.load(jsonf) - assert j["cer"] == pytest.approx(0.1071429) - assert j["wer"] == pytest.approx(0.5) diff --git a/src/dinglehopper/tests/test_integ_cli_valid_report.py b/src/dinglehopper/tests/test_integ_cli_valid_json.py similarity index 64% rename from src/dinglehopper/tests/test_integ_cli_valid_report.py rename to src/dinglehopper/tests/test_integ_cli_valid_json.py index fed0d28..6cbfa0c 100644 --- a/src/dinglehopper/tests/test_integ_cli_valid_report.py +++ b/src/dinglehopper/tests/test_integ_cli_valid_json.py @@ -1,5 +1,4 @@ import json -import re import pytest @@ -41,25 +40,3 @@ def test_cli_json_cer_is_infinity(tmp_path): with open("report.json", "r") as jsonf: j = json.load(jsonf) assert j["cer"] == pytest.approx(float("inf")) - - -@pytest.mark.integration -def test_cli_html(tmp_path): - """Test that the cli/process() yields complete HTML report""" - - with working_directory(tmp_path): - with open("gt.txt", "w") as gtf: - gtf.write("AAAAA") - with open("ocr.txt", "w") as ocrf: - ocrf.write("AAAAB") - - process("gt.txt", "ocr.txt", "report") - - with open("report.html", "r") as htmlf: - html_report = htmlf.read() - print(html_report) - - assert re.search(r"CER: 0\.\d+", html_report) - assert re.search(r"WER: 1\.0", html_report) - assert len(re.findall("gt.*cdiff", html_report)) == 1 - assert len(re.findall("gt.*wdiff", html_report)) == 1 diff --git a/src/dinglehopper/tests/test_integ_empty_files.py b/src/dinglehopper/tests/test_integ_empty_files.py deleted file mode 100644 index 5c90ed1..0000000 --- a/src/dinglehopper/tests/test_integ_empty_files.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import division, print_function - -import math - -import pytest - -from .. import character_error_rate, plain_text -from .util import working_directory - - -@pytest.mark.integration -@pytest.mark.parametrize( - "gt_file_content,ocr_file_content,cer_expected", - [ - ("", "Lorem ipsum", math.inf), - ("Lorem ipsum", "", 1.0), - ("\ufeff", "Lorem ipsum", math.inf), - ("Lorem ipsum", "\ufeff", 1.0), - ("", "", 0.0), - ("\ufeff", "", 0.0), - ("", "\ufeff", 0.0), - ], -) -def test_empty_files(tmp_path, gt_file_content, ocr_file_content, cer_expected): - with working_directory(tmp_path): - - with open("gt.txt", "w") as gtf: - gtf.write(gt_file_content) - with open("ocr.txt", "w") as ocrf: - ocrf.write(ocr_file_content) - - gt_text = plain_text("gt.txt") - ocr_text = plain_text("ocr.txt") - - assert character_error_rate(gt_text, ocr_text) == cer_expected diff --git a/src/dinglehopper/tests/test_integ_ocrd_cli.py b/src/dinglehopper/tests/test_integ_ocrd_cli.py index fbda5f4..b30d2b0 100644 --- a/src/dinglehopper/tests/test_integ_ocrd_cli.py +++ b/src/dinglehopper/tests/test_integ_ocrd_cli.py @@ -34,8 +34,9 @@ def test_ocrd_cli(tmp_path): "-O", "OCR-D-OCR-CALAMARI-EVAL", ] - # Hack to satisfy ocrd_cli_wrap_processor() check for arguments - sys.argv[1:] = args + sys.argv[ + 1: + ] = args # XXX Hack to satisfy ocrd_cli_wrap_processor() check for arguments result = runner.invoke(ocrd_dinglehopper, args) assert result.exit_code == 0 result_json = list((test_workspace_dir / "OCR-D-OCR-CALAMARI-EVAL").glob("*.json")) diff --git a/src/dinglehopper/tests/test_line_dirs.py b/src/dinglehopper/tests/test_line_dirs.py deleted file mode 100644 index 03966e1..0000000 --- a/src/dinglehopper/tests/test_line_dirs.py +++ /dev/null @@ -1,71 +0,0 @@ -import os - -from ..cli_line_dirs import find_gt_and_ocr_files, find_gt_and_ocr_files_autodetect - -data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") - - -def test_basic(): - """Test the dumb method: User gives directories and suffixes.""" - pairs = list( - find_gt_and_ocr_files( - os.path.join(data_dir, "line_dirs/basic/gt"), - ".gt.txt", - os.path.join(data_dir, "line_dirs/basic/ocr"), - ".some-ocr.txt", - ) - ) - - assert len(pairs) == 2 - - -def test_basic_autodetect(): - """Test autodetect: User gives directories, suffixes are autodetected if possible""" - pairs = list( - find_gt_and_ocr_files_autodetect( - os.path.join(data_dir, "line_dirs/basic/gt"), - os.path.join(data_dir, "line_dirs/basic/ocr"), - ) - ) - - assert len(pairs) == 2 - - -def test_subdirs(): - """Test the dumb method: Should also work when subdirectories are involved.""" - pairs = list( - find_gt_and_ocr_files( - os.path.join(data_dir, "line_dirs/subdirs/gt"), - ".gt.txt", - os.path.join(data_dir, "line_dirs/subdirs/ocr"), - ".some-ocr.txt", - ) - ) - - assert len(pairs) == 2 - - -def test_subdirs_autodetect(): - """Test the autodetect method: Should also work when subdirectories are involved.""" - pairs = list( - find_gt_and_ocr_files_autodetect( - os.path.join(data_dir, "line_dirs/subdirs/gt"), - os.path.join(data_dir, "line_dirs/subdirs/ocr"), - ) - ) - - assert len(pairs) == 2 - - -def test_merged(): - """Test the dumb method: GT and OCR texts are in the same directories.""" - pairs = list( - find_gt_and_ocr_files( - os.path.join(data_dir, "line_dirs/merged"), - ".gt.txt", - os.path.join(data_dir, "line_dirs/merged"), - ".some-ocr.txt", - ) - ) - - assert len(pairs) == 2 diff --git a/src/dinglehopper/tests/test_ocr_files.py b/src/dinglehopper/tests/test_ocr_files.py index 0c2a500..4790c85 100644 --- a/src/dinglehopper/tests/test_ocr_files.py +++ b/src/dinglehopper/tests/test_ocr_files.py @@ -177,20 +177,8 @@ def test_text(): def test_plain(tmp_path): with working_directory(tmp_path): with open("ocr.txt", "w") as ocrf: - ocrf.write("First, a line.\nAnd a second line.\n") + ocrf.write("AAAAB") result = plain_text("ocr.txt") - expected = "First, a line.\nAnd a second line." - assert result == expected - - -def test_plain_BOM(tmp_path): - """Test that plain text files with BOM are read correctly.""" - BOM = "\ufeff" - with working_directory(tmp_path): - with open("ocr.txt", "w") as ocrf: - ocrf.write(BOM + "First, a line.\nAnd a second line.\n") - - result = plain_text("ocr.txt") - expected = "First, a line.\nAnd a second line." + expected = "AAAAB" assert result == expected diff --git a/src/dinglehopper/word_error_rate.py b/src/dinglehopper/word_error_rate.py index f2db504..8a1c9cb 100644 --- a/src/dinglehopper/word_error_rate.py +++ b/src/dinglehopper/word_error_rate.py @@ -1,5 +1,7 @@ +from __future__ import division + import unicodedata -from typing import Generator, Iterable, Tuple, TypeVar +from typing import Iterable, Tuple import uniseg.wordbreak from multimethod import multimethod @@ -7,8 +9,6 @@ from rapidfuzz.distance import Levenshtein from .extracted_text import ExtractedText -T = TypeVar("T") - # Did we patch uniseg.wordbreak.word_break already? word_break_patched = False @@ -21,17 +21,12 @@ def patch_word_break(): https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt """ old_word_break = uniseg.wordbreak.word_break - if hasattr(uniseg.wordbreak, 'Word_Break'): - aletter = uniseg.wordbreak.Word_Break.ALetter - else: - # uniseg<0.9 - aletter = uniseg.wordbreak.WordBreak.ALETTER - def new_word_break(c): + def new_word_break(c, index=0): if 0xE000 <= ord(c) <= 0xF8FF: # Private Use Area - return aletter + return "ALetter" else: - return old_word_break(c) + return old_word_break(c, index) uniseg.wordbreak.word_break = new_word_break global word_break_patched @@ -39,7 +34,7 @@ def patch_word_break(): @multimethod -def words(s: str) -> Generator[str, None, None]: +def words(s: str): """Extract words from a string""" global word_break_patched @@ -59,7 +54,7 @@ def words(s: str) -> Generator[str, None, None]: # We follow Unicode Standard Annex #29 on Unicode Text Segmentation here: Split on # word boundaries using uniseg.wordbreak.words() and ignore all "words" that contain - # only whitespace, punctuation "or similar characters." + # only whitespace, punctation "or similar characters." for word in uniseg.wordbreak.words(s): if all(unwanted(c) for c in word): pass @@ -67,37 +62,37 @@ def words(s: str) -> Generator[str, None, None]: yield word -@words.register -def _(s: ExtractedText) -> Generator[str, None, None]: - yield from words(s.text) +@multimethod +def words(s: ExtractedText): + return words(s.text) @multimethod -def words_normalized(s: str) -> Generator[str, None, None]: - yield from words(unicodedata.normalize("NFC", s)) +def words_normalized(s: str): + return words(unicodedata.normalize("NFC", s)) -@words_normalized.register -def _(s: ExtractedText) -> Generator[str, None, None]: - yield from words_normalized(s.text) +@multimethod +def words_normalized(s: ExtractedText): + return words_normalized(s.text) @multimethod def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: reference_seq = list(words_normalized(reference)) compared_seq = list(words_normalized(compared)) - wer, n = word_error_rate_n(reference_seq, compared_seq) - return wer, n + return word_error_rate_n(reference_seq, compared_seq) -@word_error_rate_n.register -def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]: - wer, n = word_error_rate_n(reference.text, compared.text) - return wer, n +@multimethod +def word_error_rate_n( + reference: ExtractedText, compared: ExtractedText +) -> Tuple[float, int]: + return word_error_rate_n(reference.text, compared.text) -@word_error_rate_n.register -def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]: +@multimethod +def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]: reference_seq = list(reference) compared_seq = list(compared) @@ -111,7 +106,6 @@ def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]: return d / n, n -def word_error_rate(reference: T, compared: T) -> float: - wer: float +def word_error_rate(reference, compared) -> float: wer, _ = word_error_rate_n(reference, compared) return wer