move grapheme clusters to ExtractedText

pull/72/head
Max Bachmann 2 years ago
parent f211d09f56
commit 01571f23b7

@ -9,7 +9,7 @@ from .extracted_text import ExtractedText
@multimethod @multimethod
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]: def character_error_rate_n(reference: list[str], compared: list[str]) -> Tuple[float, int]:
""" """
Compute character error rate. Compute character error rate.
@ -17,7 +17,7 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
""" """
d = distance(reference, compared) d = distance(reference, compared)
n = len(list(grapheme_clusters(unicodedata.normalize("NFC", reference)))) n = len(reference)
if d == 0: if d == 0:
return 0, n return 0, n
@ -28,11 +28,18 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
# XXX Should we really count newlines here? # XXX Should we really count newlines here?
@multimethod
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
    """Compute the character error rate of two plain strings.

    Both strings are NFC-normalized and split into grapheme clusters, then
    the computation is delegated to the list-based overload.
    """
    ref_clusters = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
    cmp_clusters = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
    return character_error_rate_n(ref_clusters, cmp_clusters)
@multimethod
def character_error_rate_n(
    reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
    """Compute the character error rate of two ExtractedText objects.

    Uses the grapheme clusters precomputed on the ExtractedText instances
    and delegates to the list-based overload.
    """
    return character_error_rate_n(
        reference.grapheme_clusters, compared.grapheme_clusters
    )
def character_error_rate(reference, compared) -> float: def character_error_rate(reference, compared) -> float:

@ -3,7 +3,6 @@ import os
import click import click
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from markupsafe import escape from markupsafe import escape
from uniseg.graphemecluster import grapheme_clusters
from ocrd_utils import initLogging from ocrd_utils import initLogging
from .character_error_rate import character_error_rate_n from .character_error_rate import character_error_rate_n
@ -45,9 +44,8 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
if isinstance(gt_in, ExtractedText): if isinstance(gt_in, ExtractedText):
if not isinstance(ocr_in, ExtractedText): if not isinstance(ocr_in, ExtractedText):
raise TypeError() raise TypeError()
# XXX splitting should be done in ExtractedText gt_things = gt_in.grapheme_clusters
gt_things = list(grapheme_clusters(gt_in.text)) ocr_things = ocr_in.grapheme_clusters
ocr_things = list(grapheme_clusters(ocr_in.text))
else: else:
gt_things = gt_in gt_things = gt_in
ocr_things = ocr_in ocr_things = ocr_in

@ -7,6 +7,16 @@ from rapidfuzz.distance import Levenshtein
from .extracted_text import ExtractedText from .extracted_text import ExtractedText
@multimethod
def distance(seq1: list[str], seq2: list[str]):
    """Compute the Levenshtein edit distance between two grapheme-cluster lists.

    Unlike the plain-string overload, this variant expects its inputs to be
    already NFC-normalized and split into grapheme clusters; no further
    normalization or splitting is performed here. The edit distance is
    counted in grapheme clusters, not in code points.
    """
    return Levenshtein.distance(seq1, seq2)
@multimethod @multimethod
def distance(s1: str, s2: str): def distance(s1: str, s2: str):
"""Compute the Levenshtein edit distance between two Unicode strings """Compute the Levenshtein edit distance between two Unicode strings
@ -22,7 +32,7 @@ def distance(s1: str, s2: str):
@multimethod
def distance(s1: ExtractedText, s2: ExtractedText):
    """Compute the edit distance between two ExtractedText objects.

    Operates on the grapheme clusters precomputed on the ExtractedText
    instances, so the distance is counted in grapheme clusters.
    """
    return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)
def editops(word1, word2): def editops(word1, word2):

@ -9,6 +9,7 @@ import attr
import numpy as np import numpy as np
from lxml import etree as ET from lxml import etree as ET
from ocrd_utils import getLogger from ocrd_utils import getLogger
from uniseg.graphemecluster import grapheme_clusters
class Normalization(enum.Enum): class Normalization(enum.Enum):
@ -133,6 +134,7 @@ class ExtractedText:
segments = attr.ib(type=Optional[list], converter=attr.converters.optional(list)) segments = attr.ib(type=Optional[list], converter=attr.converters.optional(list))
joiner = attr.ib(type=Optional[str]) joiner = attr.ib(type=Optional[str])
_text = attr.ib(type=Optional[str]) _text = attr.ib(type=Optional[str])
_grapheme_clusters = attr.ib(type=Optional[list[str]])
@segments.validator @segments.validator
def check(self, _, value): def check(self, _, value):
@ -141,12 +143,22 @@ class ExtractedText:
@_text.validator
def check(self, _, value):
    """Validate the text attribute.

    A direct text must be NFC- and project-normalized, must not coexist
    with segments, and must come with precomputed grapheme clusters.
    """
    if value is None:
        # No direct text: this instance is segment-based, nothing to check.
        return
    if self.segments is not None:
        raise ValueError("Can't have both segments and text")
    if unicodedata.normalize("NFC", value) != value:
        raise ValueError('String "{}" is not in NFC.'.format(value))
    if normalize(value, self.normalization) != value:
        raise ValueError('String "{}" is not normalized.'.format(value))
    if self._grapheme_clusters is None:
        raise ValueError("Requires both text and grapheme clusters to be set")
@_grapheme_clusters.validator
def check(self, _, value):
    """Grapheme clusters may only be given together with a direct text."""
    if value is None:
        return
    if self._text is None:
        raise ValueError("Requires both text and grapheme clusters to be set")
normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB) normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)
@ -157,6 +169,17 @@ class ExtractedText:
else: else:
return self.joiner.join(s.text for s in self.segments) return self.joiner.join(s.text for s in self.segments)
@property
def grapheme_clusters(self):
    """Return this element's text as a list of grapheme clusters.

    Leaf nodes return their precomputed clusters; composite nodes
    concatenate their segments' clusters with the joiner in between.
    """
    if self._text is not None:
        return self._grapheme_clusters
    # TODO(review): could there be cases where the joiner is not a
    # single grapheme cluster?
    joined = []
    last = len(self.segments) - 1
    for i, seg in enumerate(self.segments):
        joined.extend(seg.grapheme_clusters)
        if i != last:
            joined.append(self.joiner)
    return joined
_segment_id_for_pos = None _segment_id_for_pos = None
def segment_id_for_pos(self, pos): def segment_id_for_pos(self, pos):
@ -197,7 +220,8 @@ class ExtractedText:
# FIXME hardcoded SBB normalization # FIXME hardcoded SBB normalization
segment_text = normalize_sbb(segment_text) segment_text = normalize_sbb(segment_text)
segment_text = segment_text or "" segment_text = segment_text or ""
return cls(segment_id, None, None, segment_text) clusters = list(grapheme_clusters(segment_text))
return cls(segment_id, None, None, segment_text, clusters)
else: else:
# Recurse # Recurse
sub_localname = children_for_localname[localname] sub_localname = children_for_localname[localname]
@ -212,12 +236,13 @@ class ExtractedText:
) )
) )
joiner = joiner_for_textequiv_level[sub_textequiv_level] joiner = joiner_for_textequiv_level[sub_textequiv_level]
return cls(segment_id, segments, joiner, None) return cls(segment_id, segments, joiner, None, None)
@classmethod
def from_str(cls, text, normalization=Normalization.NFC_SBB):
    """Build a leaf ExtractedText from a plain string.

    The string is normalized and its grapheme clusters are precomputed.
    """
    normalized = normalize(text, normalization)
    return cls(
        None,
        None,
        None,
        normalized,
        list(grapheme_clusters(normalized)),
        normalization=normalization,
    )
def invert_dict(d): def invert_dict(d):

@ -4,6 +4,7 @@ from typing import Iterator
from lxml import etree as ET from lxml import etree as ET
from lxml.etree import XMLSyntaxError from lxml.etree import XMLSyntaxError
from uniseg.graphemecluster import grapheme_clusters
from .extracted_text import ExtractedText, normalize_sbb from .extracted_text import ExtractedText, normalize_sbb
@ -29,13 +30,15 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
string.attrib.get("CONTENT") string.attrib.get("CONTENT")
for string in line.iterfind("alto:String", namespaces=nsmap) for string in line.iterfind("alto:String", namespaces=nsmap)
) )
yield ExtractedText(line_id, None, None, normalize_sbb(line_text)) normalized_text = normalize_sbb(line_text)
clusters = list(grapheme_clusters(normalized_text))
yield ExtractedText(line_id, None, None, normalized_text, clusters)
# FIXME hardcoded SBB normalization # FIXME hardcoded SBB normalization
def alto_extract(tree: ET.ElementTree) -> ExtractedText:
    """Extract text from the given ALTO ElementTree."""
    lines = list(alto_extract_lines(tree))
    return ExtractedText(None, lines, "\n", None, None)
def alto_text(tree): def alto_text(tree):
@ -83,7 +86,7 @@ def page_extract(tree, *, textequiv_level="region"):
# Filter empty region texts # Filter empty region texts
regions = [r for r in regions if r.text != ""] regions = [r for r in regions if r.text != ""]
return ExtractedText(None, regions, "\n", None) return ExtractedText(None, regions, "\n", None, None)
def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level): def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
@ -130,17 +133,21 @@ def page_text(tree, *, textequiv_level="region"):
def plain_extract(filename, include_filename_in_id=False):
    """Extract text from a plain text file, one segment per line.

    :param filename: path of the text file to read
    :param include_filename_in_id: prefix each line id with the basename
    :return: an ExtractedText with one child segment per input line
    """
    # NOTE(review): the "{filename}" placeholder was rendered away as
    # "(unknown)" in the diff view; restored here so the filename= argument
    # passed to format() below is actually used.
    id_template = "{filename} - line {no}" if include_filename_in_id else "line {no}"

    def make_segment(no, line):
        # Each segment needs both its normalized text and the matching
        # grapheme clusters, as required by ExtractedText's validators.
        normalized_text = normalize_sbb(line)
        clusters = list(grapheme_clusters(normalized_text))
        return ExtractedText(
            id_template.format(filename=os.path.basename(filename), no=no),
            None,
            None,
            normalized_text,
            clusters,
        )

    with open(filename, "r") as f:
        # Iterating the file object yields lines directly; no need to
        # materialize them with readlines() first.
        return ExtractedText(
            None,
            [make_segment(no, line) for no, line in enumerate(f)],
            "\n",
            None,
            None,
        )
# XXX hardcoded SBB normalization # XXX hardcoded SBB normalization

Loading…
Cancel
Save