Mirror of https://github.com/qurator-spk/dinglehopper.git, synced 2025-06-09 11:50:00 +02:00
apply black
This commit is contained in:
  parent 01571f23b7
  commit 22c3817f45

7 changed files with 32 additions and 14 deletions
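
Every hunk below is mechanical reformatting. As an illustration (not part of the commit), black's programmatic API reproduces the first change: the signature is split because it exceeds black's default 88-character line length.

    # Illustration only: black.format_str and black.Mode are black's public API.
    import black

    src = (
        "def character_error_rate_n("
        "reference: list[str], compared: list[str]"
        ") -> Tuple[float, int]:\n"
        "    ...\n"
    )
    print(black.format_str(src, mode=black.Mode()))
    # def character_error_rate_n(
    #     reference: list[str], compared: list[str]
    # ) -> Tuple[float, int]:
    #     ...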
@@ -9,7 +9,9 @@ from .extracted_text import ExtractedText
 @multimethod
-def character_error_rate_n(reference: list[str], compared: list[str]) -> Tuple[float, int]:
+def character_error_rate_n(
+    reference: list[str], compared: list[str]
+) -> Tuple[float, int]:
     """
     Compute character error rate.

@@ -39,7 +41,9 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
 def character_error_rate_n(
     reference: ExtractedText, compared: ExtractedText
 ) -> Tuple[float, int]:
-    return character_error_rate_n(reference.grapheme_clusters, compared.grapheme_clusters)
+    return character_error_rate_n(
+        reference.grapheme_clusters, compared.grapheme_clusters
+    )


 def character_error_rate(reference, compared) -> float:

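The two character_error_rate_n definitions coexist because @multimethod dispatches on the annotated argument types; the ExtractedText overload above delegates to the list-of-clusters overload. A minimal sketch of that mechanism, assuming the multimethod package the module already imports:

    # Dispatch sketch, separate from the commit.
    from multimethod import multimethod

    @multimethod
    def describe(x: str):
        return "string"

    @multimethod
    def describe(x: list):
        return "list"

    assert describe("abc") == "string"
    assert describe(["a", "b"]) == "list"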
@@ -26,7 +26,7 @@ def common_suffix(its):

 def removesuffix(text, suffix):
     if suffix and text.endswith(suffix):
-        return text[:-len(suffix)]
+        return text[: -len(suffix)]
     return text

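For context, this helper mirrors str.removesuffix (built into Python 3.9+). The truthiness check on suffix matters: with an empty suffix, text[: -len(suffix)] would be text[:0], i.e. the empty string.

    # Behavior of the helper above (copied verbatim from the new version).
    def removesuffix(text, suffix):
        if suffix and text.endswith(suffix):
            return text[: -len(suffix)]
        return text

    assert removesuffix("page_0001.gt.txt", ".gt.txt") == "page_0001"
    assert removesuffix("page_0001", "") == "page_0001"  # without the guard: ""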
@@ -46,7 +46,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
         ocr = removesuffix(gt, gt_suffix) + ocr_suffix

         gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True)
-        ocr_text = plain_extract(os.path.join(ocr_dir, ocr), include_filename_in_id=True)
+        ocr_text = plain_extract(
+            os.path.join(ocr_dir, ocr), include_filename_in_id=True
+        )
         gt_words = words_normalized(gt_text)
         ocr_words = words_normalized(ocr_text)

@@ -56,7 +58,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
             cer, n_characters = l_cer, l_n_characters
         else:
             # Rolling update
-            cer = (cer * n_characters + l_cer * l_n_characters) / (n_characters + l_n_characters)
+            cer = (cer * n_characters + l_cer * l_n_characters) / (
+                n_characters + l_n_characters
+            )
             n_characters = n_characters + l_n_characters

     # Compute WER

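The "rolling update" is a running weighted average: each document's CER is weighted by its character count, so the aggregate equals total errors over total characters. A worked example with made-up numbers chosen to be exact in binary floating point:

    cer, n_characters = 0.25, 96      # aggregate so far: 24 errors over 96 chars
    l_cer, l_n_characters = 0.75, 32  # current document: 24 errors over 32 chars
    cer = (cer * n_characters + l_cer * l_n_characters) / (
        n_characters + l_n_characters
    )
    n_characters = n_characters + l_n_characters
    assert cer == 0.375  # (24 + 24) / 128
    assert n_characters == 128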
@@ -17,6 +17,7 @@ def distance(seq1: list[str], seq2: list[str]):
     """
     return Levenshtein.distance(seq1, seq2)

+
 @multimethod
 def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings

@@ -175,7 +175,7 @@ class ExtractedText:
             return self._grapheme_clusters
         else:
             clusters = []
-            for seg in self.segments:
+            for seg in self.segments:
                 # todo could there be cases where joiner is no grapheme cluster?
                 clusters.extend(seg.grapheme_clusters + [self.joiner])
             return clusters[:-1]

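The clusters being joined here are grapheme clusters (user-perceived characters), not code points: a base letter plus a combining mark counts as one cluster. A quick illustration, assuming the uniseg package the project uses for segmentation:

    from uniseg.graphemecluster import grapheme_clusters

    text = "Schlo\u0300ss"  # "o" followed by a combining grave accent
    assert len(text) == 8                           # 8 code points
    assert len(list(grapheme_clusters(text))) == 7  # 7 user-perceived characters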
@@ -242,7 +242,9 @@ class ExtractedText:
     def from_str(cls, text, normalization=Normalization.NFC_SBB):
         normalized_text = normalize(text, normalization)
         clusters = list(grapheme_clusters(normalized_text))
-        return cls(None, None, None, normalized_text, clusters, normalization=normalization)
+        return cls(
+            None, None, None, normalized_text, clusters, normalization=normalization
+        )


 def invert_dict(d):

@@ -98,14 +98,18 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
         ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children)
         ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"]))
-    elif ET.QName(group.tag).localname in ["UnorderedGroup","UnorderedGroupIndexed"]:
+    elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]:
         ro_children = list(group)
     else:
         raise NotImplementedError

     for ro_child in ro_children:
-        if ET.QName(ro_child.tag).localname in ["OrderedGroup", "OrderedGroupIndexed", "UnorderedGroup", "UnorderedGroupIndexed"]:
+        if ET.QName(ro_child.tag).localname in [
+            "OrderedGroup",
+            "OrderedGroupIndexed",
+            "UnorderedGroup",
+            "UnorderedGroupIndexed",
+        ]:
             regions.extend(
                 extract_texts_from_reading_order_group(
                     ro_child, tree, nsmap, textequiv_level

@@ -139,7 +143,11 @@ def plain_extract(filename, include_filename_in_id=False):
         clusters = list(grapheme_clusters(normalized_text))
         return ExtractedText(
             id_template.format(filename=os.path.basename(filename), no=no),
-            None, None, normalized_text, clusters)
+            None,
+            None,
+            normalized_text,
+            clusters,
+        )

     with open(filename, "r") as f:
         return ExtractedText(

@@ -147,7 +155,7 @@ def plain_extract(filename, include_filename_in_id=False):
             [make_segment(no, line) for no, line in enumerate(f.readlines())],
             "\n",
             None,
-            None
+            None,
         )
     # XXX hardcoded SBB normalization

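The single-character change above (None to None,) is black's "magic trailing comma": a trailing comma in a multi-line call or literal tells black to keep one element per line, whereas without it black collapses the construct when it fits. A small demonstration with a current black release, separate from the commit:

    import black

    no_comma = "f(\n    1,\n    2\n)\n"
    with_comma = "f(\n    1,\n    2,\n)\n"
    print(black.format_str(no_comma, mode=black.Mode()))    # collapses to: f(1, 2)
    print(black.format_str(with_comma, mode=black.Mode()))  # stays on four lines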
@@ -33,7 +33,7 @@ class OcrdDinglehopperEvaluate(Processor):
         textequiv_level = self.parameter["textequiv_level"]
         gt_grp, ocr_grp = self.input_file_grp.split(",")

-        input_file_tuples = self.zip_input_files(on_error='abort')
+        input_file_tuples = self.zip_input_files(on_error="abort")
         for n, (gt_file, ocr_file) in enumerate(input_file_tuples):
             if not gt_file or not ocr_file:
                 # file/page was not found in this group

@@ -40,7 +40,6 @@ def words(s: str):
     if not word_break_patched:
         patch_word_break()

-
     # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
     def unwanted(c):
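
The unwanted(c) helper shown as context above filters out whitespace, punctuation, and similar characters during word segmentation. One plausible implementation (hypothetical, not the commit's actual code) keys off Unicode general categories:

    import unicodedata

    def unwanted(c: str) -> bool:
        # Unicode general categories: Z* = separators, P* = punctuation,
        # C* = control and other
        return unicodedata.category(c)[0] in ("Z", "P", "C")

    assert unwanted(" ") and unwanted(",")
    assert not unwanted("a") and not unwanted("7")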