mirror of https://github.com/qurator-spk/dinglehopper.git
synced 2025-06-09 20:00:01 +02:00

🔍 mypy: Use an almost strict mypy configuration, and fix any issues

parent ad316aeabc
commit 483e809691

11 changed files with 77 additions and 41 deletions
@@ -28,6 +28,8 @@ repos:
     hooks:
       - additional_dependencies:
           - types-setuptools
+          - types-lxml
+          - numpy  # for numpy plugin
         id: mypy

   - repo: https://gitlab.com/vojko.pribudic/pre-commit-update
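For context on the hunk above: pre-commit runs mypy in an isolated environment, so every stub package the checks rely on has to be listed under the hook's additional_dependencies; with types-lxml installed, lxml calls get real types instead of Any, and numpy is pulled in for its mypy plugin. A minimal sketch of code that benefits from the lxml stubs (hypothetical helper, not from the commit):

    from lxml import etree as ET

    def root_localname(path: str) -> str:
        # With types-lxml present, parse() and QName() are precisely
        # typed, so this return type is actually verified.
        tree = ET.parse(path)
        return str(ET.QName(tree.getroot().tag).localname)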
@@ -60,9 +60,20 @@ markers = [


 [tool.mypy]
+plugins = ["numpy.typing.mypy_plugin"]
+
 ignore_missing_imports = true

+strict = true
+
+disallow_subclassing_any = false
+# ❗ error: Class cannot subclass "Processor" (has type "Any")
+disallow_any_generics = false
+disallow_untyped_defs = false
+disallow_untyped_calls = false
+

 [tool.ruff]
 select = ["E", "F", "I"]
 ignore = [
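A note on the [tool.mypy] section above: strict = true switches on mypy's whole strict bundle (including disallow_untyped_defs, disallow_any_generics and disallow_untyped_calls), and the follow-up settings then opt back out of the checks the code base cannot satisfy yet, which is what makes the configuration "almost strict". A minimal sketch of the practical effect, with hypothetical helpers:

    # Accepted: disallow_untyped_defs = false lets unannotated
    # functions through even though strict mode is on.
    def legacy_helper(value):
        return value * 2

    # Still checked strictly once annotations are present:
    def typed_helper(value: int) -> int:
        return value * 2  # returning a str here would be an error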
@@ -7,5 +7,6 @@ ruff
 pytest-ruff

 mypy
+types-lxml
 types-setuptools
 pytest-mypy
@@ -4,8 +4,7 @@ from math import ceil
 from typing import Optional

 from rapidfuzz.distance import Levenshtein
-
-from .edit_distance import grapheme_clusters
+from uniseg.graphemecluster import grapheme_clusters


 def align(t1, t2):
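The hunk above swaps a re-export from the edit-distance module for the direct uniseg import. Grapheme clusters are used instead of code points because one user-perceived character can consist of several code points; a minimal sketch:

    import unicodedata

    from uniseg.graphemecluster import grapheme_clusters

    # "e" followed by a combining acute accent: two code points, one cluster
    decomposed = unicodedata.normalize("NFD", "né")
    print(len(decomposed))                      # 3 code points
    print(list(grapheme_clusters(decomposed)))  # ['n', 'é'] – 2 clusters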
@@ -1,5 +1,5 @@
 import unicodedata
-from typing import List, Tuple
+from typing import List, Tuple, TypeVar

 from multimethod import multimethod
 from uniseg.graphemecluster import grapheme_clusters
@@ -7,6 +7,8 @@ from uniseg.graphemecluster import grapheme_clusters
 from .edit_distance import distance
 from .extracted_text import ExtractedText

+T = TypeVar("T")
+

 @multimethod
 def character_error_rate_n(
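For readers unfamiliar with the multimethod package used here: it dispatches on the runtime types of the arguments, and the T = TypeVar("T") introduced above lets the public wrappers say "both arguments have the same type". A minimal sketch of the dispatch mechanism, with hypothetical functions:

    from multimethod import multimethod

    @multimethod
    def describe(x: int) -> str:
        return f"int: {x}"

    @describe.register
    def _(x: str) -> str:
        return f"str: {x}"

    print(describe(42))    # int: 42
    print(describe("42"))  # str: 42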
@@ -34,21 +36,24 @@ def character_error_rate_n(
 def _(reference: str, compared: str) -> Tuple[float, int]:
     seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
     seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
-    return character_error_rate_n(seq1, seq2)
+    cer, n = character_error_rate_n(seq1, seq2)
+    return cer, n


 @character_error_rate_n.register
 def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
-    return character_error_rate_n(
+    cer, n = character_error_rate_n(
         reference.grapheme_clusters, compared.grapheme_clusters
     )
+    return cer, n


-def character_error_rate(reference, compared) -> float:
+def character_error_rate(reference: T, compared: T) -> float:
     """
     Compute character error rate.

     :return: character error rate
     """
+    cer: float
     cer, _ = character_error_rate_n(reference, compared)
     return cer
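A worked example of the contract these functions keep: character_error_rate_n returns the rate together with n, the grapheme-cluster length of the reference, and character_error_rate drops n. A sketch, assuming a package-level re-export of character_error_rate_n:

    from dinglehopper import character_error_rate_n  # assumed re-export

    cer, n = character_error_rate_n("Beispiel", "Bejspiel")
    print(n)    # 8 grapheme clusters in the reference
    print(cer)  # 1 substitution / 8 clusters = 0.125

Unpacking into cer, n (instead of returning the recursive call directly) is what gives mypy a concrete Tuple[float, int] to check at every level.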
@@ -1,5 +1,6 @@
 import os
 from collections import Counter
+from typing import List

 import click
 from jinja2 import Environment, FileSystemLoader
@@ -76,7 +77,7 @@ def gen_diff_report(
         if o is not None:
             o_pos += len(o)

-    found_differences = dict(Counter(elem for elem in found_differences))
+    counted_differences = dict(Counter(elem for elem in found_differences))

     return (
         """
@@ -87,7 +88,7 @@ def gen_diff_report(
         """.format(
             gtx, ocrx
         ),
-        found_differences,
+        counted_differences,
     )


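The rename from found_differences to counted_differences is plausibly itself a mypy fix: the name first holds a list, and re-binding it to the dict produced by Counter would change its type mid-function, which strict mypy rejects. A minimal sketch with hypothetical data:

    from collections import Counter
    from typing import Dict, List

    found_differences: List[str] = ["a", "b", "a"]
    # Re-binding found_differences here would switch List[str] to
    # Dict[str, int]; a fresh name keeps both types stable.
    counted_differences: Dict[str, int] = dict(Counter(found_differences))
    print(counted_differences)  # {'a': 2, 'b': 1}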
@@ -113,7 +114,7 @@ def process(
     metrics: bool = True,
     differences: bool = False,
     textequiv_level: str = "region",
-):
+) -> None:
     """Check OCR result against GT.

     The @click decorators change the signature of the decorated functions, so we keep
@@ -122,8 +123,8 @@ def process(

     gt_text = extract(gt, textequiv_level=textequiv_level)
     ocr_text = extract(ocr, textequiv_level=textequiv_level)
-    gt_words: list = list(words_normalized(gt_text))
-    ocr_words: list = list(words_normalized(ocr_text))
+    gt_words: List[str] = list(words_normalized(gt_text))
+    ocr_words: List[str] = list(words_normalized(ocr_text))

     assert isinstance(gt_text, ExtractedText)
     assert isinstance(ocr_text, ExtractedText)
@@ -1,5 +1,6 @@
 import json
 import os
+from typing import Dict

 import click
 from jinja2 import Environment, FileSystemLoader
@@ -13,8 +14,8 @@ def process(reports_folder, occurrences_threshold=1):
     wer_list = []
     cer_sum = 0
     wer_sum = 0
-    diff_c = {}
-    diff_w = {}
+    diff_c: Dict[str, int] = {}
+    diff_w: Dict[str, int] = {}

     for report in os.listdir(reports_folder):
         if report.endswith(".json"):
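Empty literals are a classic strict-mode sticking point: mypy cannot infer element types from {}, so it demands an annotation like the Dict[str, int] above. A minimal sketch:

    from typing import Dict

    # Without the annotation, mypy reports: Need type annotation for "counts"
    counts: Dict[str, int] = {}
    counts["example"] = counts.get("example", 0) + 1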
@@ -9,7 +9,7 @@ from .extracted_text import ExtractedText


 @multimethod
-def distance(seq1: List[str], seq2: List[str]):
+def distance(seq1: List[str], seq2: List[str]) -> int:
     """Compute the Levenshtein edit distance between two lists of grapheme clusters.

     This assumes that the grapheme clusters are already normalized.
@@ -20,7 +20,7 @@ def distance(seq1: List[str], seq2: List[str]):


 @distance.register
-def _(s1: str, s2: str):
+def _(s1: str, s2: str) -> int:
     """Compute the Levenshtein edit distance between two Unicode strings

     Note that this is different from levenshtein() as this function knows about Unicode
@@ -33,7 +33,7 @@ def _(s1: str, s2: str):


 @distance.register
-def _(s1: ExtractedText, s2: ExtractedText):
+def _(s1: ExtractedText, s2: ExtractedText) -> int:
     return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)


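All three distance overloads now declare -> int, matching rapidfuzz: Levenshtein.distance accepts strings as well as sequences (such as the grapheme-cluster lists above) and returns an int. A minimal sketch:

    from rapidfuzz.distance import Levenshtein

    print(Levenshtein.distance("kitten", "sitting"))       # 3
    print(Levenshtein.distance(["ka", "b"], ["ka", "c"]))  # 1 – lists work, too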
@@ -4,7 +4,7 @@ import re
 import unicodedata
 from contextlib import suppress
 from itertools import repeat
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 import attr
 import numpy as np
@@ -173,10 +173,11 @@ class ExtractedText:
     normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)

     @property
-    def text(self):
+    def text(self) -> str:
         if self._text is not None:
             return self._text
         else:
+            assert self.joiner is not None and self.segments is not None
             return self.joiner.join(s.text for s in self.segments)

     @functools.cached_property
@@ -186,6 +187,7 @@ class ExtractedText:
         This property is cached.
         """

+        assert self.joiner is not None
         if len(self.joiner) > 0:
             joiner_grapheme_cluster = list(grapheme_clusters(self.joiner))
             assert len(joiner_grapheme_cluster) == 1  # see joiner's check above
@@ -203,6 +205,7 @@ class ExtractedText:
         else:
             # TODO Test with text extracted at glyph level (joiner == "")
             clusters = []
+            assert self.segments is not None
             for seg in self.segments:
                 clusters += seg.grapheme_clusters + self._joiner_grapheme_cluster
             clusters = clusters[:-1]
@@ -218,6 +221,7 @@ class ExtractedText:
         else:
             # Recurse
             segment_id_for_pos = []
+            assert self.joiner is not None and self.segments is not None
             for s in self.segments:
                 seg_ids = [s.segment_id_for_pos(i) for i in range(len(s.text))]
                 segment_id_for_pos.extend(seg_ids)
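The three asserts inserted above all serve the same purpose: joiner and segments are Optional attributes, and assert ... is not None is the standard way to let mypy narrow them before use. A minimal sketch of the narrowing, with hypothetical names:

    from typing import List, Optional

    def join_segments(joiner: Optional[str], segments: Optional[List[str]]) -> str:
        # After the assert, mypy treats joiner as str and segments as
        # List[str], so the join below type-checks.
        assert joiner is not None and segments is not None
        return joiner.join(segments)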
@@ -280,7 +284,7 @@ def invert_dict(d):
     return {v: k for k, v in d.items()}


-def get_textequiv_unicode(text_segment, nsmap) -> str:
+def get_textequiv_unicode(text_segment: Any, nsmap: Dict[str, str]) -> str:
     """Get the TextEquiv/Unicode text of the given PAGE text element."""
     segment_id = text_segment.attrib["id"]
     textequivs = text_segment.findall("./page:TextEquiv", namespaces=nsmap)
|
@ -304,7 +308,7 @@ def get_first_textequiv(textequivs, segment_id):
|
||||||
if np.any(~nan_mask):
|
if np.any(~nan_mask):
|
||||||
if np.any(nan_mask):
|
if np.any(nan_mask):
|
||||||
log.warning("TextEquiv without index in %s.", segment_id)
|
log.warning("TextEquiv without index in %s.", segment_id)
|
||||||
index = np.nanargmin(indices)
|
index = int(np.nanargmin(indices))
|
||||||
else:
|
else:
|
||||||
# try ordering by conf
|
# try ordering by conf
|
||||||
confidences = np.array([get_attr(te, "conf") for te in textequivs], dtype=float)
|
confidences = np.array([get_attr(te, "conf") for te in textequivs], dtype=float)
|
||||||
|
@@ -313,7 +317,7 @@ def get_first_textequiv(textequivs, segment_id):
                 "No index attributes, use 'conf' attribute to sort TextEquiv in %s.",
                 segment_id,
             )
-            index = np.nanargmax(confidences)
+            index = int(np.nanargmax(confidences))
     else:
         # fallback to first entry in case of neither index or conf present
         log.warning("No index attributes, use first TextEquiv in %s.", segment_id)
@@ -321,7 +325,7 @@ def get_first_textequiv(textequivs, segment_id):
     return textequivs[index]


-def get_attr(te, attr_name) -> float:
+def get_attr(te: Any, attr_name: str) -> float:
     """Extract the attribute for the given name.

     Note: currently only handles numeric values!
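The int(...) wrappers in the two hunks above are needed because np.nanargmin/np.nanargmax return a numpy integer scalar (np.intp) rather than a plain int; the numpy mypy plugin enabled in [tool.mypy] makes that distinction visible. A minimal sketch:

    import numpy as np

    indices = np.array([2.0, np.nan, 1.0])
    raw = np.nanargmin(indices)  # numpy integer scalar, not a plain int
    index = int(raw)             # satisfies an `int` annotation
    print(index)                 # 2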
@@ -1,6 +1,6 @@
 import os
 import sys
-from typing import Iterator
+from typing import Dict, Iterator, Optional

 import chardet
 from lxml import etree as ET
@@ -10,11 +10,11 @@ from uniseg.graphemecluster import grapheme_clusters
 from .extracted_text import ExtractedText, normalize_sbb


-def alto_namespace(tree: ET.ElementTree) -> str:
+def alto_namespace(tree: ET._ElementTree) -> Optional[str]:
     """Return the ALTO namespace used in the given ElementTree.

     This relies on the assumption that, in any given ALTO file, the root element has the
-    local name "alto". We do not check if the files uses any valid ALTO namespace.
+    local name "alto". We do not check if the file uses any valid ALTO namespace.
     """
     root_name = ET.QName(tree.getroot().tag)
     if root_name.localname == "alto":
@@ -23,8 +23,15 @@ def alto_namespace(tree: ET.ElementTree) -> str:
         raise ValueError("Not an ALTO tree")


-def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
-    nsmap = {"alto": alto_namespace(tree)}
+def alto_nsmap(tree: ET._ElementTree) -> Dict[str, str]:
+    alto_ns = alto_namespace(tree)
+    if alto_ns is None:
+        raise ValueError("Could not determine ALTO namespace")
+    return {"alto": alto_ns}
+
+
+def alto_extract_lines(tree: ET._ElementTree) -> Iterator[ExtractedText]:
+    nsmap = alto_nsmap(tree)
     for line in tree.iterfind(".//alto:TextLine", namespaces=nsmap):
         line_id = line.attrib.get("ID")
         line_text = " ".join(
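Extracting alto_nsmap() also solves a typing problem: alto_namespace() now returns Optional[str], and raising on None is what narrows the value back to a plain str before it goes into the namespace dict. A minimal sketch of the pattern, with hypothetical names:

    from typing import Optional

    def require_ns(uri: Optional[str]) -> str:
        # Raising on None narrows Optional[str] to str for the return.
        if uri is None:
            raise ValueError("Could not determine namespace")
        return uri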
@@ -37,7 +44,7 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
     # FIXME hardcoded SBB normalization


-def alto_extract(tree: ET.ElementTree) -> ExtractedText:
+def alto_extract(tree: ET._ElementTree) -> ExtractedText:
     """Extract text from the given ALTO ElementTree."""
     return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None)

@@ -98,7 +105,7 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
     if ET.QName(group.tag).localname in ["OrderedGroup", "OrderedGroupIndexed"]:
         ro_children = list(group)

-        ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children)
+        ro_children = [child for child in ro_children if "index" in child.attrib.keys()]
         ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"]))
     elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]:
         ro_children = list(group)
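Replacing filter() with a list comprehension keeps ro_children a list throughout; re-binding the name to a filter iterator would change its inferred type, which strict mypy flags. A minimal sketch with hypothetical data:

    children = ["a1", "b", "a2"]
    # children = filter(lambda c: c.startswith("a"), children) would re-bind
    # the name from a list of str to a filter object; the comprehension
    # keeps it a list.
    children = [c for c in children if c.startswith("a")]
    print(children)  # ['a1', 'a2']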
@@ -1,5 +1,5 @@
 import unicodedata
-from typing import Iterable, Tuple
+from typing import Generator, Iterable, Tuple, TypeVar

 import uniseg.wordbreak
 from multimethod import multimethod
@@ -7,6 +7,8 @@ from rapidfuzz.distance import Levenshtein

 from .extracted_text import ExtractedText

+T = TypeVar("T")
+
 # Did we patch uniseg.wordbreak.word_break already?
 word_break_patched = False

@@ -32,7 +34,7 @@ def patch_word_break():


 @multimethod
-def words(s: str):
+def words(s: str) -> Generator[str, None, None]:
     """Extract words from a string"""

     global word_break_patched
@@ -61,34 +63,36 @@ def words(s: str):


 @words.register
-def _(s: ExtractedText):
-    return words(s.text)
+def _(s: ExtractedText) -> Generator[str, None, None]:
+    yield from words(s.text)


 @multimethod
-def words_normalized(s: str):
-    return words(unicodedata.normalize("NFC", s))
+def words_normalized(s: str) -> Generator[str, None, None]:
+    yield from words(unicodedata.normalize("NFC", s))


 @words_normalized.register
-def _(s: ExtractedText):
-    return words_normalized(s.text)
+def _(s: ExtractedText) -> Generator[str, None, None]:
+    yield from words_normalized(s.text)


 @multimethod
 def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     reference_seq = list(words_normalized(reference))
     compared_seq = list(words_normalized(compared))
-    return word_error_rate_n(reference_seq, compared_seq)
+    wer, n = word_error_rate_n(reference_seq, compared_seq)
+    return wer, n


 @word_error_rate_n.register
 def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
-    return word_error_rate_n(reference.text, compared.text)
+    wer, n = word_error_rate_n(reference.text, compared.text)
+    return wer, n


 @word_error_rate_n.register
-def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
+def _(reference: Iterable[T], compared: Iterable[T]) -> Tuple[float, int]:
     reference_seq = list(reference)
     compared_seq = list(compared)

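The words* overloads above are annotated as Generator[str, None, None] and switched from return words(...) to yield from words(...): the yield makes each overload a generator function in its own right, so the body matches the annotation instead of merely forwarding another iterator. A minimal sketch:

    from typing import Generator

    def tokens(s: str) -> Generator[str, None, None]:
        yield from s.split()

    print(list(tokens("ein kleines Beispiel")))  # ['ein', 'kleines', 'Beispiel']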
@@ -102,6 +106,7 @@ def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
     return d / n, n


-def word_error_rate(reference, compared) -> float:
+def word_error_rate(reference: T, compared: T) -> float:
+    wer: float
     wer, _ = word_error_rate_n(reference, compared)
     return wer
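As with the character-level pair, word_error_rate_n returns the rate together with n, the number of words in the reference, and word_error_rate keeps only the rate. A sketch of the arithmetic, assuming the same package-level re-export as above:

    from dinglehopper import word_error_rate_n  # assumed re-export

    wer, n = word_error_rate_n("ein Beispiel", "ein Beyspiel")
    print(n)    # 2 words in the reference
    print(wer)  # 1 substituted word / 2 words = 0.5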