Mirror of https://github.com/qurator-spk/dinglehopper.git, synced 2025-06-09 11:50:00 +02:00
move grapheme clusters to ExtractedText

parent f211d09f56
commit 01571f23b7

5 changed files with 70 additions and 23 deletions
character_error_rate.py

@@ -9,7 +9,7 @@ from .extracted_text import ExtractedText


 @multimethod
-def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
+def character_error_rate_n(reference: list[str], compared: list[str]) -> Tuple[float, int]:
     """
     Compute character error rate.

@@ -17,7 +17,7 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     """
     d = distance(reference, compared)
-    n = len(list(grapheme_clusters(unicodedata.normalize("NFC", reference))))
+    n = len(reference)

     if d == 0:
         return 0, n
@@ -28,11 +28,18 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     # XXX Should we really count newlines here?


+@multimethod
+def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
+    seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
+    seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
+    return character_error_rate_n(seq1, seq2)
+
+
 @multimethod
 def character_error_rate_n(
     reference: ExtractedText, compared: ExtractedText
 ) -> Tuple[float, int]:
-    return character_error_rate_n(reference.text, compared.text)
+    return character_error_rate_n(reference.grapheme_clusters, compared.grapheme_clusters)


 def character_error_rate(reference, compared) -> float:
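For context, a quick sketch (not part of this commit) of why the denominator n is counted in grapheme clusters rather than code points. Here "m̃" is an "m" followed by U+0303 COMBINING TILDE: two code points, one user-perceived character.

    import unicodedata
    from uniseg.graphemecluster import grapheme_clusters

    reference = unicodedata.normalize("NFC", "Schlym\u0303bach")
    print(len(reference))                           # 11 code points
    print(len(list(grapheme_clusters(reference))))  # 10 grapheme clusters

With the new overloads, callers pass the cluster list directly, so len(reference) is already the right count.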
cli.py

@@ -3,7 +3,6 @@ import os
 import click
 from jinja2 import Environment, FileSystemLoader
 from markupsafe import escape
-from uniseg.graphemecluster import grapheme_clusters
 from ocrd_utils import initLogging

 from .character_error_rate import character_error_rate_n
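The imported character_error_rate_n is now an overload set dispatched on argument types. A minimal sketch (illustration only) of how the multimethod package used above resolves such overloads:

    from multimethod import multimethod

    @multimethod
    def describe(x: list):
        return "list of {} items".format(len(x))

    @multimethod
    def describe(x: str):
        # strings are converted, then re-dispatched to the list overload
        return describe(list(x))

    print(describe(["a", "b"]))  # list of 2 items
    print(describe("abc"))       # list of 3 items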
@@ -45,9 +44,8 @@ def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
     if isinstance(gt_in, ExtractedText):
         if not isinstance(ocr_in, ExtractedText):
             raise TypeError()
-        # XXX splitting should be done in ExtractedText
-        gt_things = list(grapheme_clusters(gt_in.text))
-        ocr_things = list(grapheme_clusters(ocr_in.text))
+        gt_things = gt_in.grapheme_clusters
+        ocr_things = ocr_in.grapheme_clusters
     else:
         gt_things = gt_in
         ocr_things = ocr_in
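The effect of this hunk: the report no longer re-splits .text on every call but reads the clusters computed once at extraction time. A sketch of the invariant this relies on (the qurator.dinglehopper package path is assumed):

    from qurator.dinglehopper.extracted_text import ExtractedText

    gt = ExtractedText.from_str("Die Wahlverwandtschaften")
    gt_things = gt.grapheme_clusters      # e.g. ['D', 'i', 'e', ' ', ...]
    assert "".join(gt_things) == gt.text  # clusters re-join to the text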
edit_distance.py

@@ -7,6 +7,16 @@ from rapidfuzz.distance import Levenshtein
 from .extracted_text import ExtractedText


+@multimethod
+def distance(seq1: list[str], seq2: list[str]):
+    """Compute the Levenshtein edit distance between two Unicode strings
+
+    Note that this is different from levenshtein() as this function knows about Unicode
+    normalization and grapheme clusters. This should be the correct way to compare two
+    Unicode strings.
+    """
+    return Levenshtein.distance(seq1, seq2)
+
 @multimethod
 def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings
@@ -22,7 +32,7 @@ def distance(s1: str, s2: str):

 @multimethod
 def distance(s1: ExtractedText, s2: ExtractedText):
-    return distance(s1.text, s2.text)
+    return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)


 def editops(word1, word2):
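This works because rapidfuzz's Levenshtein.distance accepts arbitrary sequences of hashable values, not only strings. A short illustration (not from the commit) of where the overloads disagree, comparing "m̃" (m plus U+0303) with the precomposed "ñ" (U+00F1):

    from rapidfuzz.distance import Levenshtein
    from uniseg.graphemecluster import grapheme_clusters

    s1, s2 = "m\u0303", "\u00f1"
    print(Levenshtein.distance(s1, s2))  # 2: substitution plus deletion, per code point
    print(Levenshtein.distance(list(grapheme_clusters(s1)),
                               list(grapheme_clusters(s2))))  # 1: one cluster replaced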
extracted_text.py

@@ -9,6 +9,7 @@ import attr
 import numpy as np
 from lxml import etree as ET
 from ocrd_utils import getLogger
+from uniseg.graphemecluster import grapheme_clusters


 class Normalization(enum.Enum):
@@ -133,6 +134,7 @@ class ExtractedText:
     segments = attr.ib(type=Optional[list], converter=attr.converters.optional(list))
     joiner = attr.ib(type=Optional[str])
     _text = attr.ib(type=Optional[str])
+    _grapheme_clusters = attr.ib(type=Optional[list[str]])

     @segments.validator
     def check(self, _, value):
@@ -141,12 +143,22 @@ class ExtractedText:

     @_text.validator
     def check(self, _, value):
-        if value is not None and self.segments is not None:
+        if value is None:
+            return
+
+        if self.segments is not None:
             raise ValueError("Can't have both segments and text")
-        if value is not None and unicodedata.normalize("NFC", value) != value:
+        if unicodedata.normalize("NFC", value) != value:
             raise ValueError('String "{}" is not in NFC.'.format(value))
-        if value is not None and normalize(value, self.normalization) != value:
+        if normalize(value, self.normalization) != value:
             raise ValueError('String "{}" is not normalized.'.format(value))
+        if self._grapheme_clusters is None:
+            raise ValueError("Requires both text and grapheme clusters to be set")
+
+    @_grapheme_clusters.validator
+    def check(self, _, value):
+        if value is not None and self._text is None:
+            raise ValueError("Requires both text and grapheme clusters to be set")

     normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)
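Taken together, the validators enforce that a leaf node carries text and grapheme clusters as a pair. A hedged construction sketch (the segment id is invented for illustration):

    import unicodedata
    from uniseg.graphemecluster import grapheme_clusters

    text = unicodedata.normalize("NFC", "ein Beispiel")
    clusters = list(grapheme_clusters(text))

    ExtractedText("r0", None, None, text, clusters)    # ok: both set
    # ExtractedText("r0", None, None, text, None)      # ValueError, clusters missing
    # ExtractedText("r0", None, None, None, clusters)  # ValueError, text missing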
@@ -157,6 +169,17 @@ class ExtractedText:
         else:
             return self.joiner.join(s.text for s in self.segments)

+    @property
+    def grapheme_clusters(self):
+        if self._text is not None:
+            return self._grapheme_clusters
+        else:
+            clusters = []
+            for seg in self.segments:
+                # todo could there be cases where joiner is no grapheme cluster?
+                clusters.extend(seg.grapheme_clusters + [self.joiner])
+            return clusters[:-1]
+
     _segment_id_for_pos = None

     def segment_id_for_pos(self, pos):
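For segment trees, the property splices the joiner between the segments' cluster lists, which is why the trailing joiner is dropped with clusters[:-1]. A self-contained sketch of the resulting shape, with invented two-segment data:

    seg1_clusters = ["a", "b"]
    seg2_clusters = ["c"]
    joiner = "\n"

    clusters = []
    for seg_clusters in (seg1_clusters, seg2_clusters):
        clusters.extend(seg_clusters + [joiner])
    print(clusters[:-1])  # ['a', 'b', '\n', 'c']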
@@ -197,7 +220,8 @@ class ExtractedText:
             # FIXME hardcoded SBB normalization
             segment_text = normalize_sbb(segment_text)
             segment_text = segment_text or ""
-            return cls(segment_id, None, None, segment_text)
+            clusters = list(grapheme_clusters(segment_text))
+            return cls(segment_id, None, None, segment_text, clusters)
         else:
             # Recurse
             sub_localname = children_for_localname[localname]
@@ -212,12 +236,13 @@ class ExtractedText:
                 )
             )
         joiner = joiner_for_textequiv_level[sub_textequiv_level]
-        return cls(segment_id, segments, joiner, None)
+        return cls(segment_id, segments, joiner, None, None)

     @classmethod
     def from_str(cls, text, normalization=Normalization.NFC_SBB):
         normalized_text = normalize(text, normalization)
-        return cls(None, None, None, normalized_text, normalization=normalization)
+        clusters = list(grapheme_clusters(normalized_text))
+        return cls(None, None, None, normalized_text, clusters, normalization=normalization)


 def invert_dict(d):
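Both constructors now keep text and clusters in sync. A usage sketch (input string invented):

    t = ExtractedText.from_str("Schlym\u0303bach")
    print(t.text)               # the normalized text
    print(t.grapheme_clusters)  # ['S', 'c', 'h', 'l', 'y', 'm̃', 'b', 'a', 'c', 'h']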
ocr_files.py

@@ -4,6 +4,7 @@ from typing import Iterator

 from lxml import etree as ET
 from lxml.etree import XMLSyntaxError
+from uniseg.graphemecluster import grapheme_clusters

 from .extracted_text import ExtractedText, normalize_sbb
@@ -29,13 +30,15 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
         string.attrib.get("CONTENT")
         for string in line.iterfind("alto:String", namespaces=nsmap)
     )
-    yield ExtractedText(line_id, None, None, normalize_sbb(line_text))
+    normalized_text = normalize_sbb(line_text)
+    clusters = list(grapheme_clusters(normalized_text))
+    yield ExtractedText(line_id, None, None, normalized_text, clusters)
     # FIXME hardcoded SBB normalization


 def alto_extract(tree: ET.ElementTree) -> ExtractedText:
     """Extract text from the given ALTO ElementTree."""
-    return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None)
+    return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None)


 def alto_text(tree):
@@ -83,7 +86,7 @@ def page_extract(tree, *, textequiv_level="region"):
     # Filter empty region texts
     regions = [r for r in regions if r.text != ""]

-    return ExtractedText(None, regions, "\n", None)
+    return ExtractedText(None, regions, "\n", None, None)


 def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
@@ -130,17 +133,21 @@ def page_text(tree, *, textequiv_level="region"):

 def plain_extract(filename, include_filename_in_id=False):
     id_template = "{filename} - line {no}" if include_filename_in_id else "line {no}"

+    def make_segment(no, line):
+        normalized_text = normalize_sbb(line)
+        clusters = list(grapheme_clusters(normalized_text))
+        return ExtractedText(
+            id_template.format(filename=os.path.basename(filename), no=no),
+            None, None, normalized_text, clusters)
+
     with open(filename, "r") as f:
         return ExtractedText(
             None,
-            [
-                ExtractedText(
-                    id_template.format(filename=os.path.basename(filename), no=no),
-                    None, None, normalize_sbb(line))
-                for no, line in enumerate(f.readlines())
-            ],
+            [make_segment(no, line) for no, line in enumerate(f.readlines())],
             "\n",
             None,
+            None
         )
     # XXX hardcoded SBB normalization
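A usage sketch for the plain-text path (file name hypothetical): each line becomes a leaf with clusters computed exactly once, while the file-level node passes None and derives its clusters from the segments via the property:

    text = plain_extract("gt.txt", include_filename_in_id=True)
    print(text.text)                   # lines joined with "\n"
    print(text.grapheme_clusters[:5])  # stitched from the line segments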