diff --git a/qurator/dinglehopper/character_error_rate.py b/qurator/dinglehopper/character_error_rate.py
index 7116660..3b8c0cc 100644
--- a/qurator/dinglehopper/character_error_rate.py
+++ b/qurator/dinglehopper/character_error_rate.py
@@ -9,7 +9,9 @@ from .extracted_text import ExtractedText
 
 
 @multimethod
-def character_error_rate_n(reference: list[str], compared: list[str]) -> Tuple[float, int]:
+def character_error_rate_n(
+    reference: list[str], compared: list[str]
+) -> Tuple[float, int]:
     """
     Compute character error rate.
 
@@ -39,7 +41,9 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
 def character_error_rate_n(
     reference: ExtractedText, compared: ExtractedText
 ) -> Tuple[float, int]:
-    return character_error_rate_n(reference.grapheme_clusters, compared.grapheme_clusters)
+    return character_error_rate_n(
+        reference.grapheme_clusters, compared.grapheme_clusters
+    )
 
 
 def character_error_rate(reference, compared) -> float:
diff --git a/qurator/dinglehopper/cli_line_dirs.py b/qurator/dinglehopper/cli_line_dirs.py
index 59c4a1f..3f8e3fc 100644
--- a/qurator/dinglehopper/cli_line_dirs.py
+++ b/qurator/dinglehopper/cli_line_dirs.py
@@ -26,7 +26,7 @@ def common_suffix(its):
 
 def removesuffix(text, suffix):
     if suffix and text.endswith(suffix):
-        return text[:-len(suffix)]
+        return text[: -len(suffix)]
     return text
 
 
@@ -46,7 +46,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
         ocr = removesuffix(gt, gt_suffix) + ocr_suffix
 
         gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True)
-        ocr_text = plain_extract(os.path.join(ocr_dir, ocr), include_filename_in_id=True)
+        ocr_text = plain_extract(
+            os.path.join(ocr_dir, ocr), include_filename_in_id=True
+        )
 
         gt_words = words_normalized(gt_text)
         ocr_words = words_normalized(ocr_text)
@@ -56,7 +58,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
             cer, n_characters = l_cer, l_n_characters
         else:
             # Rolling update
-            cer = (cer * n_characters + l_cer * l_n_characters) / (n_characters + l_n_characters)
+            cer = (cer * n_characters + l_cer * l_n_characters) / (
+                n_characters + l_n_characters
+            )
             n_characters = n_characters + l_n_characters
 
         # Compute WER
diff --git a/qurator/dinglehopper/edit_distance.py b/qurator/dinglehopper/edit_distance.py
index ad8eaf2..2120b80 100644
--- a/qurator/dinglehopper/edit_distance.py
+++ b/qurator/dinglehopper/edit_distance.py
@@ -17,6 +17,7 @@ def distance(seq1: list[str], seq2: list[str]):
     """
     return Levenshtein.distance(seq1, seq2)
 
+
 @multimethod
 def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings
diff --git a/qurator/dinglehopper/extracted_text.py b/qurator/dinglehopper/extracted_text.py
index 0ddebf5..19ad9c1 100644
--- a/qurator/dinglehopper/extracted_text.py
+++ b/qurator/dinglehopper/extracted_text.py
@@ -175,7 +175,7 @@ class ExtractedText:
             return self._grapheme_clusters
         else:
             clusters = []
-            for seg in self.segments:
+            for seg in self.segments:  # todo could there be cases where joiner is no grapheme cluster?
                 clusters.extend(seg.grapheme_clusters + [self.joiner])
             return clusters[:-1]
 
@@ -242,7 +242,9 @@ class ExtractedText:
     def from_str(cls, text, normalization=Normalization.NFC_SBB):
         normalized_text = normalize(text, normalization)
         clusters = list(grapheme_clusters(normalized_text))
-        return cls(None, None, None, normalized_text, clusters, normalization=normalization)
+        return cls(
+            None, None, None, normalized_text, clusters, normalization=normalization
+        )
 
 
 def invert_dict(d):
diff --git a/qurator/dinglehopper/ocr_files.py b/qurator/dinglehopper/ocr_files.py
index 38190da..6384dfa 100644
--- a/qurator/dinglehopper/ocr_files.py
+++ b/qurator/dinglehopper/ocr_files.py
@@ -98,14 +98,18 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
 
         ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children)
         ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"]))
-    elif ET.QName(group.tag).localname in ["UnorderedGroup","UnorderedGroupIndexed"]:
+    elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]:
         ro_children = list(group)
     else:
         raise NotImplementedError
 
-
     for ro_child in ro_children:
-        if ET.QName(ro_child.tag).localname in ["OrderedGroup", "OrderedGroupIndexed", "UnorderedGroup", "UnorderedGroupIndexed"]:
+        if ET.QName(ro_child.tag).localname in [
+            "OrderedGroup",
+            "OrderedGroupIndexed",
+            "UnorderedGroup",
+            "UnorderedGroupIndexed",
+        ]:
             regions.extend(
                 extract_texts_from_reading_order_group(
                     ro_child, tree, nsmap, textequiv_level
@@ -139,7 +143,11 @@ def plain_extract(filename, include_filename_in_id=False):
         clusters = list(grapheme_clusters(normalized_text))
         return ExtractedText(
             id_template.format(filename=os.path.basename(filename), no=no),
-            None, None, normalized_text, clusters)
+            None,
+            None,
+            normalized_text,
+            clusters,
+        )
 
     with open(filename, "r") as f:
         return ExtractedText(
@@ -147,7 +155,7 @@ def plain_extract(filename, include_filename_in_id=False):
             [make_segment(no, line) for no, line in enumerate(f.readlines())],
             "\n",
             None,
-            None
+            None,
         )
     # XXX hardcoded SBB normalization
 
diff --git a/qurator/dinglehopper/ocrd_cli.py b/qurator/dinglehopper/ocrd_cli.py
index 7c513e6..9578a0a 100644
--- a/qurator/dinglehopper/ocrd_cli.py
+++ b/qurator/dinglehopper/ocrd_cli.py
@@ -33,7 +33,7 @@ class OcrdDinglehopperEvaluate(Processor):
         textequiv_level = self.parameter["textequiv_level"]
         gt_grp, ocr_grp = self.input_file_grp.split(",")
 
-        input_file_tuples = self.zip_input_files(on_error='abort')
+        input_file_tuples = self.zip_input_files(on_error="abort")
         for n, (gt_file, ocr_file) in enumerate(input_file_tuples):
             if not gt_file or not ocr_file:
                 # file/page was not found in this group
diff --git a/qurator/dinglehopper/word_error_rate.py b/qurator/dinglehopper/word_error_rate.py
index 0976921..3b9ff5e 100644
--- a/qurator/dinglehopper/word_error_rate.py
+++ b/qurator/dinglehopper/word_error_rate.py
@@ -40,7 +40,6 @@ def words(s: str):
     if not word_break_patched:
         patch_word_break()
 
-
     # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
     def unwanted(c):