From 01571f23b74ca4e13e0934b8ac0cf0100956e31c Mon Sep 17 00:00:00 2001
From: Max Bachmann
Date: Mon, 29 Aug 2022 01:49:04 +0200
Subject: [PATCH] move grapheme clusters to ExtractedText

---
 qurator/dinglehopper/character_error_rate.py | 13 +++++--
 qurator/dinglehopper/cli.py                  |  6 ++--
 qurator/dinglehopper/edit_distance.py        | 12 ++++++-
 qurator/dinglehopper/extracted_text.py       | 37 ++++++++++++++++----
 qurator/dinglehopper/ocr_files.py            | 25 ++++++++-----
 5 files changed, 70 insertions(+), 23 deletions(-)

diff --git a/qurator/dinglehopper/character_error_rate.py b/qurator/dinglehopper/character_error_rate.py
index 2128a9f..7116660 100644
--- a/qurator/dinglehopper/character_error_rate.py
+++ b/qurator/dinglehopper/character_error_rate.py
@@ -9,7 +9,7 @@ from .extracted_text import ExtractedText
 
 
 @multimethod
-def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
+def character_error_rate_n(reference: list[str], compared: list[str]) -> Tuple[float, int]:
     """
     Compute character error rate.
 
@@ -17,7 +17,7 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     """
 
     d = distance(reference, compared)
-    n = len(list(grapheme_clusters(unicodedata.normalize("NFC", reference))))
+    n = len(reference)
 
     if d == 0:
         return 0, n
@@ -28,11 +28,18 @@
     # XXX Should we really count newlines here?
 
 
+@multimethod
+def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
+    seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
+    seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
+    return character_error_rate_n(seq1, seq2)
+
+
 @multimethod
 def character_error_rate_n(
     reference: ExtractedText, compared: ExtractedText
 ) -> Tuple[float, int]:
-    return character_error_rate_n(reference.text, compared.text)
+    return character_error_rate_n(reference.grapheme_clusters, compared.grapheme_clusters)
 
 
 def character_error_rate(reference, compared) -> float:
diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py
index be6f020..3c52c5d 100644
--- a/qurator/dinglehopper/cli.py
+++ b/qurator/dinglehopper/cli.py
@@ -3,7 +3,6 @@ import os
 import click
 from jinja2 import Environment, FileSystemLoader
 from markupsafe import escape
-from uniseg.graphemecluster import grapheme_clusters
 from ocrd_utils import initLogging
 
 from .character_error_rate import character_error_rate_n
@@ -45,9 +44,8 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
     if isinstance(gt_in, ExtractedText):
         if not isinstance(ocr_in, ExtractedText):
             raise TypeError()
-        # XXX splitting should be done in ExtractedText
-        gt_things = list(grapheme_clusters(gt_in.text))
-        ocr_things = list(grapheme_clusters(ocr_in.text))
+        gt_things = gt_in.grapheme_clusters
+        ocr_things = ocr_in.grapheme_clusters
     else:
         gt_things = gt_in
         ocr_things = ocr_in
diff --git a/qurator/dinglehopper/edit_distance.py b/qurator/dinglehopper/edit_distance.py
index 3adb059..ad8eaf2 100644
--- a/qurator/dinglehopper/edit_distance.py
+++ b/qurator/dinglehopper/edit_distance.py
@@ -7,6 +7,16 @@ from rapidfuzz.distance import Levenshtein
 from .extracted_text import ExtractedText
 
 
+@multimethod
+def distance(seq1: list[str], seq2: list[str]):
+    """Compute the Levenshtein edit distance between two lists of grapheme clusters
+
+    This assumes that the sequences have already been Unicode-normalized and
+    segmented into grapheme clusters by the caller, e.g. by ExtractedText or by
+    the str overload below.
+    """
+    return Levenshtein.distance(seq1, seq2)
+
 @multimethod
 def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings
@@ -22,7 +32,7 @@ def distance(s1: str, s2: str):
 
 @multimethod
 def distance(s1: ExtractedText, s2: ExtractedText):
-    return distance(s1.text, s2.text)
+    return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)
 
 
 def editops(word1, word2):
diff --git a/qurator/dinglehopper/extracted_text.py b/qurator/dinglehopper/extracted_text.py
index 9703b6b..0ddebf5 100644
--- a/qurator/dinglehopper/extracted_text.py
+++ b/qurator/dinglehopper/extracted_text.py
@@ -9,6 +9,7 @@ import attr
 import numpy as np
 from lxml import etree as ET
 from ocrd_utils import getLogger
+from uniseg.graphemecluster import grapheme_clusters
 
 
 class Normalization(enum.Enum):
@@ -133,6 +134,7 @@ class ExtractedText:
     segments = attr.ib(type=Optional[list], converter=attr.converters.optional(list))
     joiner = attr.ib(type=Optional[str])
     _text = attr.ib(type=Optional[str])
+    _grapheme_clusters = attr.ib(type=Optional[list[str]])
 
     @segments.validator
     def check(self, _, value):
@@ -141,12 +143,22 @@
 
     @_text.validator
     def check(self, _, value):
-        if value is not None and self.segments is not None:
+        if value is None:
+            return
+
+        if self.segments is not None:
             raise ValueError("Can't have both segments and text")
-        if value is not None and unicodedata.normalize("NFC", value) != value:
+        if unicodedata.normalize("NFC", value) != value:
             raise ValueError('String "{}" is not in NFC.'.format(value))
-        if value is not None and normalize(value, self.normalization) != value:
+        if normalize(value, self.normalization) != value:
             raise ValueError('String "{}" is not normalized.'.format(value))
+        if self._grapheme_clusters is None:
+            raise ValueError("Requires both text and grapheme clusters to be set")
+
+    @_grapheme_clusters.validator
+    def check(self, _, value):
+        if value is not None and self._text is None:
+            raise ValueError("Requires both text and grapheme clusters to be set")
 
     normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)
 
@@ -157,6 +169,17 @@
         else:
             return self.joiner.join(s.text for s in self.segments)
 
+    @property
+    def grapheme_clusters(self):
+        if self._text is not None:
+            return self._grapheme_clusters
+        else:
+            clusters = []
+            for seg in self.segments:
+                # TODO: could the joiner ever be more than one grapheme cluster?
+                clusters.extend(seg.grapheme_clusters + [self.joiner])
+            return clusters[:-1]
+
     _segment_id_for_pos = None
 
     def segment_id_for_pos(self, pos):
@@ -197,7 +220,8 @@
             # FIXME hardcoded SBB normalization
             segment_text = normalize_sbb(segment_text)
             segment_text = segment_text or ""
-            return cls(segment_id, None, None, segment_text)
+            clusters = list(grapheme_clusters(segment_text))
+            return cls(segment_id, None, None, segment_text, clusters)
         else:
             # Recurse
             sub_localname = children_for_localname[localname]
@@ -212,12 +236,13 @@
                 )
             )
             joiner = joiner_for_textequiv_level[sub_textequiv_level]
-            return cls(segment_id, segments, joiner, None)
+            return cls(segment_id, segments, joiner, None, None)
 
     @classmethod
     def from_str(cls, text, normalization=Normalization.NFC_SBB):
         normalized_text = normalize(text, normalization)
-        return cls(None, None, None, normalized_text, normalization=normalization)
+        clusters = list(grapheme_clusters(normalized_text))
+        return cls(None, None, None, normalized_text, clusters, normalization=normalization)
 
 
 def invert_dict(d):
diff --git a/qurator/dinglehopper/ocr_files.py b/qurator/dinglehopper/ocr_files.py
index 92f4fe5..38190da 100644
--- a/qurator/dinglehopper/ocr_files.py
+++ b/qurator/dinglehopper/ocr_files.py
@@ -4,6 +4,7 @@ from typing import Iterator
 
 from lxml import etree as ET
 from lxml.etree import XMLSyntaxError
+from uniseg.graphemecluster import grapheme_clusters
 
 from .extracted_text import ExtractedText, normalize_sbb
 
@@ -29,13 +30,15 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
             string.attrib.get("CONTENT")
             for string in line.iterfind("alto:String", namespaces=nsmap)
         )
-        yield ExtractedText(line_id, None, None, normalize_sbb(line_text))
+        normalized_text = normalize_sbb(line_text)
+        clusters = list(grapheme_clusters(normalized_text))
+        yield ExtractedText(line_id, None, None, normalized_text, clusters)
     # FIXME hardcoded SBB normalization
 
 
 def alto_extract(tree: ET.ElementTree) -> ExtractedText:
     """Extract text from the given ALTO ElementTree."""
-    return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None)
+    return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None)
 
 
 def alto_text(tree):
@@ -83,7 +86,7 @@ def page_extract(tree, *, textequiv_level="region"):
     # Filter empty region texts
     regions = [r for r in regions if r.text != ""]
 
-    return ExtractedText(None, regions, "\n", None)
+    return ExtractedText(None, regions, "\n", None, None)
 
 
 def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
@@ -130,17 +133,21 @@ def page_text(tree, *, textequiv_level="region"):
 
 def plain_extract(filename, include_filename_in_id=False):
     id_template = "{filename} - line {no}" if include_filename_in_id else "line {no}"
+
+    def make_segment(no, line):
+        normalized_text = normalize_sbb(line)
+        clusters = list(grapheme_clusters(normalized_text))
+        return ExtractedText(
+            id_template.format(filename=os.path.basename(filename), no=no),
+            None, None, normalized_text, clusters)
+
     with open(filename, "r") as f:
         return ExtractedText(
             None,
-            [
-                ExtractedText(
-                    id_template.format(filename=os.path.basename(filename), no=no),
-                    None, None, normalize_sbb(line))
-                for no, line in enumerate(f.readlines())
-            ],
+            [make_segment(no, line) for no, line in enumerate(f.readlines())],
            "\n",
             None,
+            None
         )
     # XXX hardcoded SBB normalization
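
Reviewer note, not part of the patch: a minimal sketch of why edit distance
is computed over grapheme clusters rather than code points. The example
strings are hypothetical; only the uniseg and rapidfuzz APIs already
imported in edit_distance.py are used.

    import unicodedata
    from uniseg.graphemecluster import grapheme_clusters
    from rapidfuzz.distance import Levenshtein

    # "n" + U+0303 COMBINING TILDE composes to "ñ" under NFC, while
    # "m" + U+0303 has no precomposed form and stays two code points.
    reference = unicodedata.normalize("NFC", "Schlyn\u0303")  # 6 code points
    compared = unicodedata.normalize("NFC", "Schlym\u0303")   # 7 code points

    # Code-point distance counts the combining mark separately: 2
    print(Levenshtein.distance(reference, compared))

    # Grapheme-cluster distance matches user perception: 1 substitution
    seq1 = list(grapheme_clusters(reference))
    seq2 = list(grapheme_clusters(compared))
    print(Levenshtein.distance(seq1, seq2))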
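
A second sketch, of the invariant the new validators enforce (text and
grapheme clusters must be set together) and of how the new property joins
segmented nodes. Constructor argument order as used in the patch itself:
segment_id, segments, joiner, text, grapheme clusters.

    from qurator.dinglehopper.extracted_text import ExtractedText

    line1 = ExtractedText("l1", None, None, "foo", ["f", "o", "o"])
    line2 = ExtractedText("l2", None, None, "bar", ["b", "a", "r"])
    region = ExtractedText("r1", [line1, line2], "\n", None, None)

    assert region.text == "foo\nbar"
    # The joiner "\n" becomes one cluster between the segments' clusters.
    assert region.grapheme_clusters == ["f", "o", "o", "\n", "b", "a", "r"]

    # Text without its clusters now fails validation:
    # ExtractedText("l3", None, None, "baz", None)  # raises ValueError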
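
Finally, the character error rate entry points after this change: str
arguments are NFC-normalized and segmented once, then dispatched to the
list[str] overload, so n is the grapheme-cluster length of the reference.

    from qurator.dinglehopper.character_error_rate import character_error_rate_n

    cer, n = character_error_rate_n("Schlyñ", "Schlym\u0303")
    assert n == 6        # grapheme-cluster length of the reference
    assert cer == 1 / 6  # one substituted cluster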