diff --git a/qurator/dinglehopper/align.py b/qurator/dinglehopper/align.py
index ab44760..87febb7 100644
--- a/qurator/dinglehopper/align.py
+++ b/qurator/dinglehopper/align.py
@@ -28,16 +28,16 @@ def seq_align(s1, s2):
 
         if o:
             if o[0] == 'insert':
-                yield (None, s2[j])
+                yield None, s2[j]
                 j += 1
             elif o[0] == 'delete':
-                yield (s1[i], None)
+                yield s1[i], None
                 i += 1
             elif o[0] == 'replace':
-                yield (s1[i], s2[j])
+                yield s1[i], s2[j]
                 i += 1
                 j += 1
         else:
-            yield (s1[i], s2[j])
+            yield s1[i], s2[j]
             i += 1
             j += 1
diff --git a/qurator/dinglehopper/character_error_rate.py b/qurator/dinglehopper/character_error_rate.py
index 05cc931..2b13f55 100644
--- a/qurator/dinglehopper/character_error_rate.py
+++ b/qurator/dinglehopper/character_error_rate.py
@@ -3,17 +3,21 @@ from __future__ import division
 import unicodedata
 from typing import Tuple
 
+from multimethod import multimethod
 from uniseg.graphemecluster import grapheme_clusters
 
-from qurator.dinglehopper.edit_distance import distance
+from .edit_distance import distance
+from .extracted_text import ExtractedText
 
 
-def character_error_rate_n(reference, compared) -> Tuple[float, int]:
+@multimethod
+def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     """
     Compute character error rate.
 
     :return: character error rate and length of the reference
     """
+
     d = distance(reference, compared)
     n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference))))
 
@@ -26,6 +30,11 @@ def character_error_rate_n(reference, compared) -> Tuple[float, int]:
     # XXX Should we really count newlines here?
 
 
+@multimethod
+def character_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
+    return character_error_rate_n(reference.text, compared.text)
+
+
 def character_error_rate(reference, compared) -> float:
     """
     Compute character error rate.
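
Note on the pattern above: `multimethod` dispatches on the runtime types of the arguments, so the new `ExtractedText` overload can simply delegate to the plain-`str` implementation via `.text`. A minimal sketch of the mechanism, not part of the patch (`describe` is a made-up name):

    from multimethod import multimethod

    @multimethod
    def describe(x: str):
        return 'str: ' + x

    @multimethod
    def describe(x: int):
        # Delegates to the str overload, just like the ExtractedText
        # overload of character_error_rate_n delegates via .text
        return describe(str(x))

    assert describe('a') == 'str: a'
    assert describe(3) == 'str: 3'
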
diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py
index 759d040..03c35cd 100644
--- a/qurator/dinglehopper/cli.py
+++ b/qurator/dinglehopper/cli.py
@@ -3,16 +3,20 @@ import os
 
 import click
 from jinja2 import Environment, FileSystemLoader
 from markupsafe import escape
+from uniseg.graphemecluster import grapheme_clusters
 
-from qurator.dinglehopper import *
+from .character_error_rate import character_error_rate_n
+from .word_error_rate import word_error_rate_n, words_normalized
+from .align import seq_align
+from .extracted_text import ExtractedText
+from .ocr_files import extract
 
 
-def gen_diff_report(gt_things, ocr_things, css_prefix, joiner, none, align):
+def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
     gtx = ''
     ocrx = ''
 
-    def format_thing(t, css_classes=None):
+    def format_thing(t, css_classes=None, id_=None):
         if t is None:
             html_t = none
             css_classes += ' ellipsis'
@@ -21,19 +25,51 @@ def gen_diff_report(gt_things, ocr_things, css_prefix, joiner, none, align):
         else:
             html_t = escape(t)
 
+        html_custom_attrs = ""
+
+        # Set Bootstrap tooltip to the segment id
+        if id_:
+            html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)
+
         if css_classes:
-            return '<span class="{css_classes}">{html_t}</span>'.format(css_classes=css_classes, html_t=html_t)
+            return '<span class="{css_classes}" {html_custom_attrs}>{html_t}</span>'.format(css_classes=css_classes, html_t=html_t, html_custom_attrs=html_custom_attrs)
         else:
             return '{html_t}'.format(html_t=html_t)
 
-    for k, (g, o) in enumerate(align(gt_things, ocr_things)):
-        if g == o:
-            css_classes = None
-        else:
+    if isinstance(gt_in, ExtractedText):
+        if not isinstance(ocr_in, ExtractedText):
+            raise TypeError()
+        # XXX splitting should be done in ExtractedText
+        gt_things = list(grapheme_clusters(gt_in.text))
+        ocr_things = list(grapheme_clusters(ocr_in.text))
+    else:
+        gt_things = gt_in
+        ocr_things = ocr_in
+
+
+    g_pos = 0
+    o_pos = 0
+    for k, (g, o) in enumerate(seq_align(gt_things, ocr_things)):
+        css_classes = None
+        gt_id = None
+        ocr_id = None
+        if g != o:
             css_classes = '{css_prefix}diff{k} diff'.format(css_prefix=css_prefix, k=k)
+            if isinstance(gt_in, ExtractedText):
+                gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None
+                ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None
+                # Deletions and inserts only produce one id + None, UI must
+                # support this, i.e. display for the one id produced
+
+        gtx += joiner + format_thing(g, css_classes, gt_id)
+        ocrx += joiner + format_thing(o, css_classes, ocr_id)
+
+        if g is not None:
+            g_pos += len(g)
+        if o is not None:
+            o_pos += len(o)
 
-        gtx += joiner + format_thing(g, css_classes)
-        ocrx += joiner + format_thing(o, css_classes)
 
     return \
         '''
@@ -51,20 +87,17 @@ def process(gt, ocr, report_prefix, *, metrics=True):
     Click on a wrapper.
     """
 
-    gt_text = text(gt)
-    ocr_text = text(ocr)
-
-    gt_text = substitute_equivalences(gt_text)
-    ocr_text = substitute_equivalences(ocr_text)
+    gt_text = extract(gt)
+    ocr_text = extract(ocr)
 
     cer, n_characters = character_error_rate_n(gt_text, ocr_text)
     wer, n_words = word_error_rate_n(gt_text, ocr_text)
 
-    char_diff_report = gen_diff_report(gt_text, ocr_text, css_prefix='c', joiner='', none='·', align=align)
+    char_diff_report = gen_diff_report(gt_text, ocr_text, css_prefix='c', joiner='', none='·')
 
     gt_words = words_normalized(gt_text)
     ocr_words = words_normalized(ocr_text)
-    word_diff_report = gen_diff_report(gt_words, ocr_words, css_prefix='w', joiner=' ', none='⋯', align=seq_align)
+    word_diff_report = gen_diff_report(gt_words, ocr_words, css_prefix='w', joiner=' ', none='⋯')
 
     def json_float(value):
         """Convert a float value to an JSON float.
diff --git a/qurator/dinglehopper/edit_distance.py b/qurator/dinglehopper/edit_distance.py
index 8ca24d3..ec49338 100644
--- a/qurator/dinglehopper/edit_distance.py
+++ b/qurator/dinglehopper/edit_distance.py
@@ -5,8 +5,11 @@ from functools import partial, lru_cache
 from typing import Sequence, Tuple
 
 import numpy as np
+from multimethod import multimethod
 from uniseg.graphemecluster import grapheme_clusters
 
+from .extracted_text import ExtractedText
+
 
 def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
     """Compute the matrix commonly computed to produce the Levenshtein distance.
@@ -69,15 +72,21 @@ def levenshtein_matrix_cache_clear():
     _levenshtein_matrix.cache_clear()
 
 
-def distance(s1, s2):
+@multimethod
+def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings
 
     Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
     clusters. This should be the correct way to compare two Unicode strings.
     """
-    s1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1)))
-    s2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2)))
-    return levenshtein(s1, s2)
+    seq1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1)))
+    seq2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2)))
+    return levenshtein(seq1, seq2)
+
+
+@multimethod
+def distance(s1: ExtractedText, s2: ExtractedText):
+    return distance(s1.text, s2.text)
 
 
 def seq_editops(seq1, seq2):
@@ -116,7 +125,11 @@ def seq_editops(seq1, seq2):
 
 
 def editops(word1, word2):
-    # XXX Note that this returns indices to the _grapheme clusters_, not characters!
+    """
+    Return sequence of edit operations transforming one string to another.
+
+    Note that this returns indices to the _grapheme clusters_, not characters!
+    """
     word1 = list(grapheme_clusters(unicodedata.normalize('NFC', word1)))
     word2 = list(grapheme_clusters(unicodedata.normalize('NFC', word2)))
     return seq_editops(word1, word2)
+ """ word1 = list(grapheme_clusters(unicodedata.normalize('NFC', word1))) word2 = list(grapheme_clusters(unicodedata.normalize('NFC', word2))) return seq_editops(word1, word2) diff --git a/qurator/dinglehopper/extracted_text.py b/qurator/dinglehopper/extracted_text.py new file mode 100644 index 0000000..6dcd921 --- /dev/null +++ b/qurator/dinglehopper/extracted_text.py @@ -0,0 +1,118 @@ +import enum +import re +import unicodedata +from contextlib import suppress +from itertools import repeat +from typing import Optional + +import attr + +from .substitute_equivalences import substitute_equivalences + + +class Normalization(enum.Enum): + NFC = 1 + NFC_MUFI = 2 # TODO + NFC_SBB = 3 + + +def normalize(text, normalization): + if normalization == Normalization.NFC: + return unicodedata.normalize('NFC', text) + if normalization == Normalization.NFC_MUFI: + raise NotImplementedError() + if normalization == Normalization.NFC_SBB: + return substitute_equivalences(text) + else: + raise ValueError() + + +# XXX hack +def normalize_sbb(t): + return normalize(t, Normalization.NFC_SBB) + + +@attr.s(frozen=True) +class ExtractedText: + """ + Extracted text + + Objects of this class are guaranteed to be a. always in their normalization and + b. in NFC. + """ + segment_id = attr.ib(type=Optional[str]) + + @segment_id.validator + def check(self, _, value): + if value is None: + return + if not re.match(r'[\w\d_-]+', value): + raise ValueError('Malformed segment id "{}"'.format(value)) + + # An object contains either + # a. _text itself + # b. or segments (ExtractedText) and a joiner + + segments = attr.ib(type=Optional[list], converter=attr.converters.optional(list)) + joiner = attr.ib(type=Optional[str]) + _text = attr.ib(type=Optional[str]) + + @segments.validator + def check(self, _, value): + if value is not None and self._text is not None: + raise ValueError("Can't have both segments and text") + + @_text.validator + def check(self, _, value): + if value is not None and self.segments is not None: + raise ValueError("Can't have both segments and text") + if value is not None and unicodedata.normalize('NFC', value) != value: + raise ValueError('String "{}" is not in NFC.'.format(value)) + if value is not None and normalize(value, self.normalization) != value: + raise ValueError('String "{}" is not normalized.'.format(value)) + + normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB) + + @property + def text(self): + if self._text is not None: + if self._text == '': + return None + else: + return self._text + else: + return self.joiner.join(s.text for s in self.segments) + + _segment_id_for_pos = None + + def segment_id_for_pos(self, pos): + # Calculate segment ids once, on the first call + if not self._segment_id_for_pos: + segment_id_for_pos = [] + for s in self.segments: + segment_id_for_pos.extend(repeat(s.segment_id, len(s.text))) + segment_id_for_pos.extend(repeat(None, len(self.joiner))) + segment_id_for_pos = segment_id_for_pos[:-len(self.joiner)] + # This is frozen, so we have to jump through the hoop: + object.__setattr__(self, '_segment_id_for_pos', segment_id_for_pos) + assert self._segment_id_for_pos + + return self._segment_id_for_pos[pos] + + @classmethod + def from_text_segment(cls, text_segment, nsmap): + """Build an ExtractedText from a PAGE content text element""" + + segment_id = text_segment.attrib['id'] + segment_text = None + with suppress(AttributeError): + segment_text = text_segment.find('./page:TextEquiv/page:Unicode', namespaces=nsmap).text + 
segment_text = segment_text or '' + segment_text = normalize_sbb(segment_text) # FIXME hardcoded SBB normalization + segment_text = segment_text or '' + return cls(segment_id, None, None, segment_text) + + @classmethod + def from_str(cls, text, normalization=Normalization.NFC_SBB): + normalized_text = normalize(text, normalization) + return cls(None, None, None, normalized_text, normalization=normalization) \ No newline at end of file diff --git a/qurator/dinglehopper/ocr_files.py b/qurator/dinglehopper/ocr_files.py index b57a047..78648eb 100644 --- a/qurator/dinglehopper/ocr_files.py +++ b/qurator/dinglehopper/ocr_files.py @@ -1,14 +1,16 @@ from __future__ import division, print_function +from typing import Generator from warnings import warn - -from lxml import etree as ET import sys +from lxml import etree as ET from lxml.etree import XMLSyntaxError +from .extracted_text import ExtractedText, normalize_sbb -def alto_namespace(tree): + +def alto_namespace(tree: ET.ElementTree) -> str: """Return the ALTO namespace used in the given ElementTree. This relies on the assumption that, in any given ALTO file, the root element has the local name "alto". We do not @@ -21,17 +23,22 @@ def alto_namespace(tree): raise ValueError('Not an ALTO tree') -def alto_text(tree): - """Extract text from the given ALTO ElementTree.""" - +def alto_extract_lines(tree: ET.ElementTree) -> Generator[ExtractedText, None, None]: nsmap = {'alto': alto_namespace(tree)} + for line in tree.iterfind('.//alto:TextLine', namespaces=nsmap): + line_id = line.attrib.get('ID') + line_text = ' '.join(string.attrib.get('CONTENT') for string in line.iterfind('alto:String', namespaces=nsmap)) + yield ExtractedText(line_id, None, None, normalize_sbb(line_text)) + # FIXME hardcoded SBB normalization + + +def alto_extract(tree: ET.ElementTree()) -> ExtractedText: + """Extract text from the given ALTO ElementTree.""" + return ExtractedText(None, list(alto_extract_lines(tree)), '\n', None) - lines = ( - ' '.join(string.attrib.get('CONTENT') for string in line.iterfind('alto:String', namespaces=nsmap)) - for line in tree.iterfind('.//alto:TextLine', namespaces=nsmap)) - text_ = '\n'.join(lines) - return text_ +def alto_text(tree): + return alto_extract(tree).text def page_namespace(tree): @@ -47,18 +54,12 @@ def page_namespace(tree): raise ValueError('Not a PAGE tree') -def page_text(tree): +def page_extract(tree): """Extract text from the given PAGE content ElementTree.""" nsmap = {'page': page_namespace(tree)} - def region_text(region): - try: - return region.find('./page:TextEquiv/page:Unicode', namespaces=nsmap).text - except AttributeError: - return None - - region_texts = [] + regions = [] reading_order = tree.find('.//page:ReadingOrder', namespaces=nsmap) if reading_order is not None: for group in reading_order.iterfind('./*', namespaces=nsmap): @@ -68,39 +69,56 @@ def page_text(tree): region_id = region_ref_indexed.attrib['regionRef'] region = tree.find('.//page:TextRegion[@id="%s"]' % region_id, namespaces=nsmap) if region is not None: - region_texts.append(region_text(region)) + regions.append(ExtractedText.from_text_segment(region, nsmap)) else: warn('Not a TextRegion: "%s"' % region_id) else: raise NotImplementedError else: for region in tree.iterfind('.//page:TextRegion', namespaces=nsmap): - region_texts.append(region_text(region)) + regions.append(ExtractedText.from_text_segment(region, nsmap)) - # XXX Does a file have to have regions etc.? region vs lines etc. 
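
Note on the new `ExtractedText`: an object is either a leaf carrying `_text` or a list of segments plus a joiner, and `segment_id_for_pos()` maps a character position in the joined text back to the segment it came from (joiner positions map to `None`). A minimal usage sketch, assuming the package re-exports the class as the tests below do:

    from qurator.dinglehopper import ExtractedText

    page = ExtractedText(None, [
        ExtractedText('r0', None, None, 'foo'),
        ExtractedText('r1', None, None, 'bar'),
    ], '\n', None)

    assert page.text == 'foo\nbar'
    assert page.segment_id_for_pos(0) == 'r0'   # 'f'
    assert page.segment_id_for_pos(3) is None   # the '\n' joiner
    assert page.segment_id_for_pos(4) == 'r1'   # 'b'
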
diff --git a/qurator/dinglehopper/ocrd_cli.py b/qurator/dinglehopper/ocrd_cli.py
index f121a7d..b4b31e5 100644
--- a/qurator/dinglehopper/ocrd_cli.py
+++ b/qurator/dinglehopper/ocrd_cli.py
@@ -7,8 +7,8 @@ from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor
 from ocrd_utils import getLogger, make_file_id, assert_file_grp_cardinality
 from pkg_resources import resource_string
 
-from qurator.dinglehopper.cli import process as cli_process
-from qurator.dinglehopper.edit_distance import levenshtein_matrix_cache_clear
+from .cli import process as cli_process
+from .edit_distance import levenshtein_matrix_cache_clear
 
 OCRD_TOOL = json.loads(resource_string(__name__, 'ocrd-tool.json').decode('utf8'))
diff --git a/qurator/dinglehopper/substitute_equivalences.py b/qurator/dinglehopper/substitute_equivalences.py
index 1b7e0cf..39be276 100644
--- a/qurator/dinglehopper/substitute_equivalences.py
+++ b/qurator/dinglehopper/substitute_equivalences.py
@@ -1,21 +1,15 @@
 import unicodedata
 
 
-def substitute_equivalences(s):
+def unjoin_ligatures(s):
+    """Unjoin ligatures, i.e. ﬀ becomes ff."""
 
-    # These are for OCR-D GT vs Tesseract frk vs Calamari GT4HistOCR
-    # It might make sense to use different rules for GT and for the different OCR
     equivalences = {
-        '': 'ü',
         '': 'ſſ',
         "\ueba7": 'ſſi',  # MUFI: LATIN SMALL LIGATURE LONG S LONG S I
-        '': 'ä',
         '': 'ch',
-        '==': '–',  # → en-dash
-        '—': '–',  # em-dash → en-dash
         '': 'ck',
         '': 'll',
-        '': 'ö',
         '': 'ſi',
         '': 'ſt',
         'ﬁ': 'fi',
@@ -23,12 +17,7 @@ def substitute_equivalences(s):
         'ﬂ': 'fl',
         'ﬃ': 'ffi',
         '': 'ct',
-        '’': '\'',
-        '⸗': '-',
         '': 'tz',  # MUFI: LATIN SMALL LIGATURE TZ
-        'aͤ': 'ä',  # LATIN SMALL LETTER A, COMBINING LATIN SMALL LETTER E
-        'oͤ': 'ö',  # LATIN SMALL LETTER O, COMBINING LATIN SMALL LETTER E
-        'uͤ': 'ü',  # LATIN SMALL LETTER U, COMBINING LATIN SMALL LETTER E
         '\uf532': 'as',  # eMOP: Latin small ligature as
         '\uf533': 'is',  # eMOP: Latin small ligature is
         '\uf534': 'us',  # eMOP: Latin small ligature us
@@ -37,10 +26,32 @@ def substitute_equivalences(s):
         '\uE8BF': 'q&',  # MUFI: LATIN SMALL LETTER Q LIGATED WITH FINAL ET  XXX How to replace this correctly?
         '\uEBA5': 'ſp',  # MUFI: LATIN SMALL LIGATURE LONG S P
         'ﬆ': 'st',  # U+FB06 LATIN SMALL LIGATURE ST
+    }
+    s = unicodedata.normalize('NFC', s)
+    for fr, to in equivalences.items():
+        s = s.replace(fr, to)
+    return s
+
+
+def substitute_equivalences(s):
+    # These are for OCR-D GT vs Tesseract frk vs Calamari GT4HistOCR
+    # It might make sense to use different rules for GT and for the different OCR
+    equivalences = {
+        '': 'ü',
+        '': 'ä',
+        '==': '–',  # → en-dash
+        '—': '–',  # em-dash → en-dash
+        '': 'ö',
+        '’': '\'',
+        '⸗': '-',
+        'aͤ': 'ä',  # LATIN SMALL LETTER A, COMBINING LATIN SMALL LETTER E
+        'oͤ': 'ö',  # LATIN SMALL LETTER O, COMBINING LATIN SMALL LETTER E
+        'uͤ': 'ü',  # LATIN SMALL LETTER U, COMBINING LATIN SMALL LETTER E
+        '\uF50E': 'q́'  # U+F50E LATIN SMALL LETTER Q WITH ACUTE ACCENT
     }
 
     s = unicodedata.normalize('NFC', s)
+    s = unjoin_ligatures(s)
     for fr, to in equivalences.items():
         s = s.replace(fr, to)
     return s
diff --git a/qurator/dinglehopper/templates/report.html.js b/qurator/dinglehopper/templates/report.html.js
index ac43676..4c2ba28 100644
--- a/qurator/dinglehopper/templates/report.html.js
+++ b/qurator/dinglehopper/templates/report.html.js
@@ -1,14 +1,15 @@
 function find_diff_class(classes) {
-    return classes.split(/\s+/).find(x => x.match(/.diff\d.*/));
+    return $('.' + classes.split(/\s+/).find(x => x.match(/.diff\d.*/)));
 }
 
 $(document).ready(function() {
+    /* Enable Bootstrap tooltips */
+    $('[data-toggle="tooltip"]').tooltip();
+
     $('.diff').mouseover(function() {
-        let c = find_diff_class($(this).attr('class'))
-        $('.' + c).addClass('diff-highlight')
+        find_diff_class($(this).attr('class')).addClass('diff-highlight');
     });
     $('.diff').mouseout(function() {
-        let c = find_diff_class($(this).attr('class'))
-        $('.' + c).removeClass('diff-highlight')
+        find_diff_class($(this).attr('class')).removeClass('diff-highlight');
     });
 });
diff --git a/qurator/dinglehopper/tests/extracted_text_test.py b/qurator/dinglehopper/tests/extracted_text_test.py
new file mode 100644
index 0000000..98788f6
--- /dev/null
+++ b/qurator/dinglehopper/tests/extracted_text_test.py
@@ -0,0 +1,68 @@
+import unicodedata
+import pytest
+from uniseg.graphemecluster import grapheme_clusters
+from collections import namedtuple
+
+from .. import seq_align, ExtractedText
+
+
+def test_text():
+    test1 = ExtractedText(None, [
+        ExtractedText('s0', None, None, 'foo'),
+        ExtractedText('s1', None, None, 'bar'),
+        ExtractedText('s2', None, None, 'bazinga')
+    ], ' ', None)
+
+    assert test1.text == 'foo bar bazinga'
+    assert test1.segment_id_for_pos(0) == 's0'
+    assert test1.segment_id_for_pos(3) is None
+    assert test1.segment_id_for_pos(10) == 's2'
+
+
+def test_normalization_check():
+    with pytest.raises(ValueError, match=r'.*is not in NFC.*'):
+        ExtractedText('foo', None, None, unicodedata.normalize('NFD', 'Schlyñ'))
+    assert ExtractedText('foo', None, None, unicodedata.normalize('NFC', 'Schlyñ'))
+
+
+AlignmentElement = namedtuple('AlignmentElement', 'left right left_id right_id')
+
+
+def test_align():
+    """
+    Test aligning by character while retaining segment id info
+
+    The difficulty here is that aligning should work on grapheme clusters,
+    not Python characters.
+    """
+
+    test1 = ExtractedText(None, [
+        ExtractedText('s0', None, None, 'foo'),
+        ExtractedText('s1', None, None, 'bar'),
+        ExtractedText('s2', None, None, 'batzinga')
+    ], ' ', None)
+    test2 = ExtractedText(None, [
+        ExtractedText('x0', None, None, 'foo'),
+        ExtractedText('x1', None, None, 'bar'),
+        ExtractedText('x2', None, None, '.'),  # extra .
+        ExtractedText('x3', None, None, 'bazim̃ga'),  # deletion + different grapheme cluster, m̃ also is two Python characters
+    ], ' ', None)
+
+    left_pos = 0; right_pos = 0; alignment = []
+    for left, right in seq_align(grapheme_clusters(test1.text), grapheme_clusters(test2.text)):
+        left_id = test1.segment_id_for_pos(left_pos) if left is not None else None
+        right_id = test2.segment_id_for_pos(right_pos) if right is not None else None
+        el = AlignmentElement(left, right, left_id, right_id)
+        alignment.append(el)
+        if left is not None:
+            left_pos += len(left)
+        if right is not None:
+            right_pos += len(right)
+
+    print('test1: {}'.format(test1.text))
+    print('test2: {}'.format(test2.text))
+
+    assert alignment[0] == ('f', 'f', 's0', 'x0')
+    assert alignment[8] == (None, '.', None, 'x2')
+    assert alignment[12] == ('t', None, 's2', None)
+    assert alignment[15] == ('n', 'm̃', 's2', 'x3')
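
Note on the `substitute_equivalences()` split above: `unjoin_ligatures()` now handles only ligatures, and `substitute_equivalences()` applies it before the remaining SBB-specific equivalences (umlaut forms, dashes, quotes). The net behavior should be unchanged; a small sketch of the expected result for U+FB01 LATIN SMALL LIGATURE FI plus an em-dash:

    from qurator.dinglehopper.substitute_equivalences import substitute_equivalences

    # The ligature is unjoined, the em-dash is narrowed to an en-dash:
    assert substitute_equivalences('ﬁnden — bald') == 'finden – bald'
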
diff --git a/qurator/dinglehopper/tests/test_align.py b/qurator/dinglehopper/tests/test_align.py
index cc5cb43..23483f8 100644
--- a/qurator/dinglehopper/tests/test_align.py
+++ b/qurator/dinglehopper/tests/test_align.py
@@ -78,7 +78,8 @@ def test_lines():
 
 
 def test_lines_similar():
-    """Test comparing list of lines while using a "weaker equivalence".
+    """
+    Test comparing list of lines while using a "weaker equivalence".
 
     This mainly serves as documentation.
     """
@@ -88,7 +89,14 @@ def test_lines_similar():
             self._string = string
 
         def __eq__(self, other):
-            return distance(self._string, other._string) < 2    # XXX NOT the final version
+            # Just an example!
+            min_len = min(len(self._string), len(other._string))
+            if min_len > 0:
+                normalized_distance = distance(self._string, other._string)/min_len
+                similar = normalized_distance < 0.1
+            else:
+                similar = False
+            return similar
 
         def __ne__(self, other):
             return not self.__eq__(other)
@@ -106,3 +114,6 @@ def test_lines_similar():
     left, right = unzip(result)
     assert list(left) == [SimilarString('This is a line.'), SimilarString('This is another'), None, SimilarString('And the last line')]
     assert list(right) == [SimilarString('This is a ljne.'), SimilarString('This is another'), SimilarString('J u n k'), SimilarString('And the last line')]
+
+    # Test __eq__ (i.e. is it a substitution or a similar string?)
+    assert list(left)[0] == list(right)[0]
diff --git a/qurator/dinglehopper/tests/test_integ_align.py b/qurator/dinglehopper/tests/test_integ_align.py
index df1e230..b35974b 100644
--- a/qurator/dinglehopper/tests/test_integ_align.py
+++ b/qurator/dinglehopper/tests/test_integ_align.py
@@ -13,11 +13,15 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 @pytest.mark.integration
 def test_align_page_files():
     # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi.
-    # → 4 elements in the alignment should be different.
+    # → 2 elements in the alignment should be different, the ligature is
+    # (currently) not counted due to normalization.
     # NOTE: In this example, it doesn't matter that we work with "characters", not grapheme clusters.
 
     gt = page_text(ET.parse(os.path.join(data_dir, 'test-gt.page2018.xml')))
     ocr = page_text(ET.parse(os.path.join(data_dir, 'test-fake-ocr.page2018.xml')))
 
     result = list(align(gt, ocr))
-    assert sum(left != right for left, right in result) == 4
+    for left, right in result:
+        if left != right:
+            print(left, right)
+    assert sum(left != right for left, right in result) == 2
diff --git a/qurator/dinglehopper/tests/test_integ_character_error_rate_ocr.py b/qurator/dinglehopper/tests/test_integ_character_error_rate_ocr.py
index c27cd31..1c3bf52 100644
--- a/qurator/dinglehopper/tests/test_integ_character_error_rate_ocr.py
+++ b/qurator/dinglehopper/tests/test_integ_character_error_rate_ocr.py
@@ -4,6 +4,7 @@ import os
 
 import pytest
 from lxml import etree as ET
+from uniseg.graphemecluster import grapheme_clusters
 
 from .. import character_error_rate, page_text, alto_text
 
@@ -13,9 +14,14 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 @pytest.mark.integration
 def test_character_error_rate_between_page_files():
     # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi.
+    # The fi ligature does not count.
     gt = page_text(ET.parse(os.path.join(data_dir, 'test-gt.page2018.xml')))
     ocr = page_text(ET.parse(os.path.join(data_dir, 'test-fake-ocr.page2018.xml')))
-    assert character_error_rate(gt, ocr) == 4/(470 + 1 + 311)  # 2 TextRegions, 1 \n
+
+    gt_len = len(list(grapheme_clusters(gt)))
+    expected_cer = 2/gt_len
+
+    assert character_error_rate(gt, ocr) == expected_cer
 
 
 @pytest.mark.integration
diff --git a/qurator/dinglehopper/tests/test_integ_cli_valid_json.py b/qurator/dinglehopper/tests/test_integ_cli_valid_json.py
index 5699700..d71bc14 100644
--- a/qurator/dinglehopper/tests/test_integ_cli_valid_json.py
+++ b/qurator/dinglehopper/tests/test_integ_cli_valid_json.py
@@ -1,4 +1,3 @@
-import os
 import json
 
 import pytest
@@ -10,14 +9,17 @@ from ..cli import process
 def test_cli_json(tmp_path):
     """Test that the cli/process() yields a loadable JSON report"""
 
-    # XXX Path.__str__() is necessary for Python 3.5
     with working_directory(str(tmp_path)):
         with open('gt.txt', 'w') as gtf:
             gtf.write('AAAAA')
         with open('ocr.txt', 'w') as ocrf:
             ocrf.write('AAAAB')
 
+        with open('gt.txt', 'r') as gtf:
+            print(gtf.read())
         process('gt.txt', 'ocr.txt', 'report')
+        with open('report.json', 'r') as jsonf:
+            print(jsonf.read())
         with open('report.json', 'r') as jsonf:
             j = json.load(jsonf)
             assert j['cer'] == pytest.approx(0.2)
@@ -26,7 +28,6 @@ def test_cli_json(tmp_path):
 def test_cli_json_cer_is_infinity(tmp_path):
     """Test that the cli/process() yields a loadable JSON report when CER == inf"""
 
-    # XXX Path.__str__() is necessary for Python 3.5
     with working_directory(str(tmp_path)):
         with open('gt.txt', 'w') as gtf:
             gtf.write('')  # Empty to yield CER == inf
diff --git a/qurator/dinglehopper/tests/test_integ_edit_distance_ocr.py b/qurator/dinglehopper/tests/test_integ_edit_distance_ocr.py
index 2857d56..cbe12f8 100644
--- a/qurator/dinglehopper/tests/test_integ_edit_distance_ocr.py
+++ b/qurator/dinglehopper/tests/test_integ_edit_distance_ocr.py
@@ -13,9 +13,11 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 
 @pytest.mark.integration
 def test_distance_between_page_files():
     # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi.
+    # Due to normalization, we don't count the ligature.
+    # → 2 differences
     gt = page_text(ET.parse(os.path.join(data_dir, 'test-gt.page2018.xml')))
     ocr = page_text(ET.parse(os.path.join(data_dir, 'test-fake-ocr.page2018.xml')))
-    assert distance(gt, ocr) == 4
+    assert distance(gt, ocr) == 2
 
 
 @pytest.mark.integration
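
Note on the updated expectations in these integration tests: normalization now happens on extraction, so the 'fi'-vs-ligature difference disappears before any comparison, leaving only the 2 genuinely changed characters. Roughly, under the SBB normalization introduced in this patch:

    from qurator.dinglehopper.extracted_text import normalize_sbb

    # Previously: 2 character edits + edits for 'ﬁ' vs 'fi' → 4 errors.
    # Now the ligature is unjoined on extraction, so only 2 edits remain:
    assert normalize_sbb('ﬁ') == 'fi'
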
diff --git a/qurator/dinglehopper/tests/test_integ_ocrd_cli.py b/qurator/dinglehopper/tests/test_integ_ocrd_cli.py
index 8eb07c0..5e535b5 100644
--- a/qurator/dinglehopper/tests/test_integ_ocrd_cli.py
+++ b/qurator/dinglehopper/tests/test_integ_ocrd_cli.py
@@ -1,12 +1,10 @@
 import os
-import re
 import shutil
 import json
 import sys
 from pathlib import Path
 
 from click.testing import CliRunner
-import pytest
 
 from .util import working_directory
 
@@ -18,8 +16,6 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 def test_ocrd_cli(tmp_path):
     """Test OCR-D interface"""
 
-    # XXX Path.str() is necessary for Python 3.5
-
     # Copy test workspace
     test_workspace_dir_source = Path(data_dir) / 'actevedef_718448162'
     test_workspace_dir = tmp_path / 'test_ocrd_cli'
diff --git a/qurator/dinglehopper/tests/test_integ_word_error_rate_ocr.py b/qurator/dinglehopper/tests/test_integ_word_error_rate_ocr.py
index 1d2dead..f5c922b 100644
--- a/qurator/dinglehopper/tests/test_integ_word_error_rate_ocr.py
+++ b/qurator/dinglehopper/tests/test_integ_word_error_rate_ocr.py
@@ -12,14 +12,15 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 
 @pytest.mark.integration
 def test_word_error_rate_between_page_files():
-    # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi. → 3 changed words
+    # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi. So we have 3 changed words,
+    # the ligature does not count → 2 errors
     gt = page_text(ET.parse(os.path.join(data_dir, 'test-gt.page2018.xml')))
 
     gt_word_count = 7+6+5+8+7+6+7+8+6+7+7+5+6+8+8+7+7+6+5+4  # Manually verified word count per line
     assert len(list(words(gt))) == gt_word_count
 
     ocr = page_text(ET.parse(os.path.join(data_dir, 'test-fake-ocr.page2018.xml')))
-    assert word_error_rate(gt, ocr) == 3/gt_word_count
+    assert word_error_rate(gt, ocr) == 2/gt_word_count
 
 
 @pytest.mark.integration
diff --git a/qurator/dinglehopper/tests/test_ocr_files.py b/qurator/dinglehopper/tests/test_ocr_files.py
index dd9377a..3291152 100644
--- a/qurator/dinglehopper/tests/test_ocr_files.py
+++ b/qurator/dinglehopper/tests/test_ocr_files.py
@@ -6,7 +6,8 @@ import textwrap
 
 import pytest
 
-from .. import alto_namespace, alto_text, page_namespace, page_text, text
+from .util import working_directory
+from .. import alto_namespace, alto_text, page_namespace, page_text, plain_text, text
 
 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 
@@ -49,27 +50,51 @@ def test_page_namespace():
 def test_page_test():
     tree = ET.parse(os.path.join(data_dir, 'test.page2018.xml'))
     result = page_text(tree)
+
+    # We are currently normalizing on extraction, so the text is normalized.
+    #
+    # expected = textwrap.dedent("""\
+    #     ber die vielen Sorgen wegen deelben vergaß
+    #     Hartkopf, der Frau Amtmnnin das ver⸗
+    #     ſproene zu berliefern. — Ein Erpreer
+    #     wurde an ihn abgeſit, um ihn ums Him⸗
+    #     melswien zu ſagen, daß er das Verſproene
+    #     glei den Augenbli berbringen mte, die
+    #     Frau Amtmnnin htte  auf ihn verlaen,
+    #     und nun wßte e nit, was e anfangen
+    #     ſote. Den Augenbli ſote er kommen,
+    #     ſon vergieng e in ihrer Ang. — Die
+    #     Ge wren ſon angekommen, und es fehlte
+    #     ihr do no an aem. —
+    #     Hartkopf mußte  er bennen, und
+    #     endli na langem Nadenken fiel es ihm er
+    #     wieder ein. — Er langte den Zettel aus dem
+    #     Accisbue heraus, und ſagte ſeiner Frau, daß
+    #     e das, was da wre, herbeyſaffen mte.
+    #     Jndeß mangelten do einige Generalia, die
+    #     alſo wegfielen. — Hartkopf gieng ſelb
+    #     mit und berbrate es. —""")
     expected = textwrap.dedent("""\
-        ber die vielen Sorgen wegen deelben vergaß
-        Hartkopf, der Frau Amtmnnin das ver⸗
-        ſproene zu berliefern. — Ein Erpreer
-        wurde an ihn abgeſit, um ihn ums Him⸗
-        melswien zu ſagen, daß er das Verſproene
-        glei den Augenbli berbringen mte, die
-        Frau Amtmnnin htte  auf ihn verlaen,
-        und nun wßte e nit, was e anfangen
-        ſote. Den Augenbli ſote er kommen,
-        ſon vergieng e in ihrer Ang. — Die
-        Ge wren ſon angekommen, und es fehlte
-        ihr do no an aem. —
-        Hartkopf mußte  er bennen, und
-        endli na langem Nadenken fiel es ihm er
-        wieder ein. — Er langte den Zettel aus dem
-        Accisbue heraus, und ſagte ſeiner Frau, daß
-        e das, was da wre, herbeyſaffen mte.
-        Jndeß mangelten do einige Generalia, die
-        alſo wegfielen. — Hartkopf gieng ſelb
-        mit und berbrate es. —""")
+        über die vielen Sorgen wegen deſſelben vergaß
+        Hartkopf, der Frau Amtmännin das ver-
+        ſprochene zu überliefern. – Ein Erpreſſer
+        wurde an ihn abgeſchickt, um ihn ums Him-
+        melswillen zu ſagen, daß er das Verſprochene
+        gleich den Augenblick überbringen möchte, die
+        Frau Amtmännin hätte ſich auf ihn verlaſſen,
+        und nun wüßte ſie nicht, was ſie anfangen
+        ſollte. Den Augenblick ſollte er kommen,
+        ſonſt vergieng ſie in ihrer Angſt. – Die
+        Gäſte wären ſchon angekommen, und es fehlte
+        ihr doch noch an allem. –
+        Hartkopf mußte ſich erſt beſinnen, und
+        endlich nach langem Nachdenken fiel es ihm erſt
+        wieder ein. – Er langte den Zettel aus dem
+        Accisbuche heraus, und ſagte ſeiner Frau, daß
+        ſie das, was da wäre, herbeyſchaffen möchte.
+        Jndeß mangelten doch einige Generalia, die
+        alſo wegfielen. – Hartkopf gieng ſelbſt
+        mit und überbrachte es. –""")
 
     assert result == expected
 
@@ -92,7 +117,8 @@ def test_page_order():
     tree = ET.parse(os.path.join(data_dir, 'order.page.xml'))
     result = page_text(tree)
 
-    assert re.search(r'Herr Konfrater.*75.*Etwas f.r Wittwen.*Ein gewi.er Lord.*76\. Die', result, re.DOTALL)
+    print(result)
+    assert re.search(r'Herr Konfrater.*75.*Etwas f.r Wittwen.*Ein gewi.{1,2}er Lord.*76\. Die', result, re.DOTALL)
 
 
 def test_page_mixed_regions():
@@ -106,5 +132,15 @@ def test_page_mixed_regions():
 
 def test_text():
     assert "being erected at the Broadway stock" in text(os.path.join(data_dir, 'test.alto1.xml'))
-    assert "wieder ein. — Er langte den Zettel aus dem" in text(os.path.join(data_dir, 'test.page2018.xml'))
+    assert "wieder ein. – Er langte den Zettel aus dem" in text(os.path.join(data_dir, 'test.page2018.xml'))
     assert "Lorem ipsum" in text(os.path.join(data_dir, 'test.txt'))
+
+
+def test_plain(tmp_path):
+    with working_directory(str(tmp_path)):
+        with open('ocr.txt', 'w') as ocrf:
+            ocrf.write('AAAAB')
+
+        result = plain_text('ocr.txt')
+        expected = 'AAAAB'
+        assert result == expected
diff --git a/qurator/dinglehopper/tests/util.py b/qurator/dinglehopper/tests/util.py
index 52b7506..1f224e5 100644
--- a/qurator/dinglehopper/tests/util.py
+++ b/qurator/dinglehopper/tests/util.py
@@ -21,8 +21,8 @@ def diffprint(x, y):
         _diffprint(x, y)
 
 
-def unzip(l):
-    return zip(*l)
+def unzip(an_iterable_of_tuples):
+    return zip(*an_iterable_of_tuples)
 
 
 class working_directory:
diff --git a/qurator/dinglehopper/word_error_rate.py b/qurator/dinglehopper/word_error_rate.py
index 7ed56e4..2f5a1f6 100644
--- a/qurator/dinglehopper/word_error_rate.py
+++ b/qurator/dinglehopper/word_error_rate.py
@@ -1,14 +1,19 @@
 from __future__ import division
 
 import unicodedata
-from typing import Tuple
+from typing import Tuple, Iterable
 
+from multimethod import multimethod
 import uniseg.wordbreak
 
 from .edit_distance import levenshtein
+from . import ExtractedText
 
 
-def words(s):
+@multimethod
+def words(s: str):
+    """Extract words from a string"""
+
     # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
     # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
     old_word_break = uniseg.wordbreak.word_break
@@ -41,17 +46,37 @@ def words(s):
         yield word
 
 
-def words_normalized(s):
+@multimethod
+def words(s: ExtractedText):
+    return words(s.text)
+
+
+@multimethod
+def words_normalized(s: str):
     return words(unicodedata.normalize('NFC', s))
 
 
-def word_error_rate_n(reference, compared) -> Tuple[float, int]:
-    if isinstance(reference, str):
-        reference_seq = list(words_normalized(reference))
-        compared_seq = list(words_normalized(compared))
-    else:
-        reference_seq = list(reference)
-        compared_seq = list(compared)
+@multimethod
+def words_normalized(s: ExtractedText):
+    return words_normalized(s.text)
+
+
+@multimethod
+def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
+    reference_seq = list(words_normalized(reference))
+    compared_seq = list(words_normalized(compared))
+    return word_error_rate_n(reference_seq, compared_seq)
+
+
+@multimethod
+def word_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
+    return word_error_rate_n(reference.text, compared.text)
+
+
+@multimethod
+def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
+    reference_seq = list(reference)
+    compared_seq = list(compared)
 
     d = levenshtein(reference_seq, compared_seq)
     n = len(reference_seq)
diff --git a/requirements.txt b/requirements.txt
index 6dd4079..c2e47dc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,5 @@ numpy
 colorama
 MarkupSafe
 ocrd >= 2.13.1
+attrs
+multimethod == 1.3  # latest version to officially support Python 3.5
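
Closing note on `word_error_rate.py`: the `str` and `ExtractedText` overloads of `word_error_rate_n` both funnel into the `Iterable` overload, so callers may pass raw strings, extracted documents, or pre-tokenized sequences. A small usage sketch under that assumption (import path as in this patch):

    from qurator.dinglehopper.word_error_rate import word_error_rate_n

    # str overload: tokenizes via words_normalized(), 1 of 2 words differs
    assert word_error_rate_n('foo bar', 'foo baz') == (0.5, 2)

    # Iterable overload: pre-tokenized input takes the same path
    assert word_error_rate_n(['foo', 'bar'], ['foo', 'baz']) == (0.5, 2)
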