move grapheme clusters to ExtractedText

pull/72/head
Max Bachmann 2 years ago
parent f211d09f56
commit 01571f23b7

@ -9,7 +9,7 @@ from .extracted_text import ExtractedText
@multimethod
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
def character_error_rate_n(reference: list[str], compared: list[str]) -> Tuple[float, int]:
"""
Compute character error rate.
@ -17,7 +17,7 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
"""
d = distance(reference, compared)
n = len(list(grapheme_clusters(unicodedata.normalize("NFC", reference))))
n = len(reference)
if d == 0:
return 0, n
@ -28,11 +28,18 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
# XXX Should we really count newlines here?
@multimethod
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
    """Compute the character error rate of two plain strings.

    Both strings are NFC-normalized and split into grapheme clusters, then the
    computation is delegated to the list-of-clusters overload.
    """

    def _to_clusters(s: str):
        # Normalize first so that canonically-equivalent strings compare equal.
        return list(grapheme_clusters(unicodedata.normalize("NFC", s)))

    return character_error_rate_n(_to_clusters(reference), _to_clusters(compared))
@multimethod
def character_error_rate_n(
    reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
    """Compute the character error rate of two ExtractedText objects.

    Uses the grapheme clusters precomputed on ExtractedText, avoiding a
    re-normalization/re-splitting of the text.
    """
    # The diff residue contained a stale `return ...(reference.text, compared.text)`
    # before this line, making the cluster-based return unreachable; only the
    # cluster-based delegation is kept.
    return character_error_rate_n(
        reference.grapheme_clusters, compared.grapheme_clusters
    )
def character_error_rate(reference, compared) -> float:

@ -3,7 +3,6 @@ import os
import click
from jinja2 import Environment, FileSystemLoader
from markupsafe import escape
from uniseg.graphemecluster import grapheme_clusters
from ocrd_utils import initLogging
from .character_error_rate import character_error_rate_n
@ -45,9 +44,8 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
if isinstance(gt_in, ExtractedText):
if not isinstance(ocr_in, ExtractedText):
raise TypeError()
# XXX splitting should be done in ExtractedText
gt_things = list(grapheme_clusters(gt_in.text))
ocr_things = list(grapheme_clusters(ocr_in.text))
gt_things = gt_in.grapheme_clusters
ocr_things = ocr_in.grapheme_clusters
else:
gt_things = gt_in
ocr_things = ocr_in

@ -7,6 +7,16 @@ from rapidfuzz.distance import Levenshtein
from .extracted_text import ExtractedText
@multimethod
def distance(seq1: list[str], seq2: list[str]):
    """Compute the Levenshtein edit distance between two sequences of grapheme clusters.

    Unlike the str overload, this performs no Unicode normalization or grapheme
    splitting itself — the inputs are presumably already NFC-normalized cluster
    lists (e.g. from ExtractedText.grapheme_clusters); verify at call sites.
    """
    return Levenshtein.distance(seq1, seq2)
@multimethod
def distance(s1: str, s2: str):
"""Compute the Levenshtein edit distance between two Unicode strings
@ -22,7 +32,7 @@ def distance(s1: str, s2: str):
@multimethod
def distance(s1: ExtractedText, s2: ExtractedText):
    """Compute the edit distance between two ExtractedText objects.

    Compares the precomputed grapheme-cluster sequences, so one edit
    corresponds to one user-perceived character.
    """
    # The diff residue contained a stale `return distance(s1.text, s2.text)`
    # before this line; only the cluster-based comparison is kept.
    return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)
def editops(word1, word2):

@ -9,6 +9,7 @@ import attr
import numpy as np
from lxml import etree as ET
from ocrd_utils import getLogger
from uniseg.graphemecluster import grapheme_clusters
class Normalization(enum.Enum):
@ -133,6 +134,7 @@ class ExtractedText:
segments = attr.ib(type=Optional[list], converter=attr.converters.optional(list))
joiner = attr.ib(type=Optional[str])
_text = attr.ib(type=Optional[str])
_grapheme_clusters = attr.ib(type=Optional[list[str]])
@segments.validator
def check(self, _, value):
@ -141,12 +143,22 @@ class ExtractedText:
@_text.validator
def check(self, _, value):
    """Validate the text attribute.

    A set text excludes segments, must be NFC-normalized, must satisfy the
    configured normalization, and requires grapheme clusters to be set too.
    """
    # The diff residue interleaved the pre-change `value is not None and ...`
    # conditions with the post-change guard-clause version; only the
    # post-change version is kept.
    if value is None:
        return
    if self.segments is not None:
        raise ValueError("Can't have both segments and text")
    if unicodedata.normalize("NFC", value) != value:
        raise ValueError('String "{}" is not in NFC.'.format(value))
    if normalize(value, self.normalization) != value:
        raise ValueError('String "{}" is not normalized.'.format(value))
    if self._grapheme_clusters is None:
        raise ValueError("Requires both text and grapheme clusters to be set")
@_grapheme_clusters.validator
def check(self, _, value):
    """Validate that grapheme clusters are only set together with text."""
    if value is None:
        return
    if self._text is None:
        raise ValueError("Requires both text and grapheme clusters to be set")
normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)
@ -157,6 +169,17 @@ class ExtractedText:
else:
return self.joiner.join(s.text for s in self.segments)
@property
def grapheme_clusters(self):
    """Return the grapheme clusters of this text.

    Leaf nodes return their stored clusters; composite nodes concatenate the
    clusters of their segments, separated by the joiner.
    """
    if self._text is not None:
        return self._grapheme_clusters
    # todo could there be cases where joiner is no grapheme cluster?
    merged = []
    for segment in self.segments:
        if merged:
            merged.append(self.joiner)
        merged.extend(segment.grapheme_clusters)
    return merged
_segment_id_for_pos = None
def segment_id_for_pos(self, pos):
@ -197,7 +220,8 @@ class ExtractedText:
# FIXME hardcoded SBB normalization
segment_text = normalize_sbb(segment_text)
segment_text = segment_text or ""
return cls(segment_id, None, None, segment_text)
clusters = list(grapheme_clusters(segment_text))
return cls(segment_id, None, None, segment_text, clusters)
else:
# Recurse
sub_localname = children_for_localname[localname]
@ -212,12 +236,13 @@ class ExtractedText:
)
)
joiner = joiner_for_textequiv_level[sub_textequiv_level]
return cls(segment_id, segments, joiner, None)
return cls(segment_id, segments, joiner, None, None)
@classmethod
def from_str(cls, text, normalization=Normalization.NFC_SBB):
    """Construct a leaf ExtractedText from a plain string.

    Normalizes the text and precomputes its grapheme clusters.
    """
    normalized_text = normalize(text, normalization)
    # The diff residue contained a stale return without the clusters argument
    # before these lines, making them unreachable; only the post-change
    # construction (with clusters) is kept.
    clusters = list(grapheme_clusters(normalized_text))
    return cls(
        None, None, None, normalized_text, clusters, normalization=normalization
    )
def invert_dict(d):

@ -4,6 +4,7 @@ from typing import Iterator
from lxml import etree as ET
from lxml.etree import XMLSyntaxError
from uniseg.graphemecluster import grapheme_clusters
from .extracted_text import ExtractedText, normalize_sbb
@ -29,13 +30,15 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
string.attrib.get("CONTENT")
for string in line.iterfind("alto:String", namespaces=nsmap)
)
yield ExtractedText(line_id, None, None, normalize_sbb(line_text))
normalized_text = normalize_sbb(line_text)
clusters = list(grapheme_clusters(normalized_text))
yield ExtractedText(line_id, None, None, normalized_text, clusters)
# FIXME hardcoded SBB normalization
def alto_extract(tree: ET.ElementTree) -> ExtractedText:
    """Extract text from the given ALTO ElementTree."""
    # The diff residue contained a stale 4-argument constructor call before
    # this line; only the post-change call (with the extra grapheme-clusters
    # argument, None for composite nodes) is kept.
    return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None)
def alto_text(tree):
@ -83,7 +86,7 @@ def page_extract(tree, *, textequiv_level="region"):
# Filter empty region texts
regions = [r for r in regions if r.text != ""]
return ExtractedText(None, regions, "\n", None)
return ExtractedText(None, regions, "\n", None, None)
def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
@ -130,17 +133,21 @@ def page_text(tree, *, textequiv_level="region"):
def plain_extract(filename, include_filename_in_id=False):
    """Extract text from a plain text file as an ExtractedText tree.

    Each line of the file becomes one segment; segment ids are derived from
    the line number (and optionally the file name).
    """
    id_template = "(unknown) - line {no}" if include_filename_in_id else "line {no}"

    def make_segment(no, line):
        # Build one leaf segment per line, with normalized text and its
        # precomputed grapheme clusters.
        normalized_text = normalize_sbb(line)
        clusters = list(grapheme_clusters(normalized_text))
        return ExtractedText(
            id_template.format(filename=os.path.basename(filename), no=no),
            None, None, normalized_text, clusters)

    # The diff residue interleaved the pre-change inline list comprehension
    # with the post-change make_segment version; only the post-change body
    # is kept.
    with open(filename, "r") as f:
        return ExtractedText(
            None,
            [make_segment(no, line) for no, line in enumerate(f.readlines())],
            "\n",
            None,
            None
        )
# XXX hardcoded SBB normalization

Loading…
Cancel
Save