apply black

pull/72/head
Max Bachmann 2 years ago
parent 01571f23b7
commit 22c3817f45
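Every hunk below is mechanical reformatting by the Black code formatter: long lines are wrapped at Black's default 88 columns, trailing commas are added, and string quotes are normalized to double quotes, with no behavior change. As an illustration (not part of the commit), Black's documented `format_str` API reproduces the kind of rewrite seen in the first hunk:

```python
# Illustration only: apply Black programmatically to one of the long
# signatures touched in this commit. Mode() defaults to Black's 88-column
# line length, which is why these one-liners get exploded.
import black

src = (
    "def character_error_rate_n(reference: list[str], compared: list[str])"
    " -> Tuple[float, int]:\n"
    "    pass\n"
)
print(black.format_str(src, mode=black.Mode()))
# def character_error_rate_n(
#     reference: list[str], compared: list[str]
# ) -> Tuple[float, int]:
#     pass
```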

@@ -9,7 +9,9 @@ from .extracted_text import ExtractedText

 @multimethod
-def character_error_rate_n(reference: list[str], compared: list[str]) -> Tuple[float, int]:
+def character_error_rate_n(
+    reference: list[str], compared: list[str]
+) -> Tuple[float, int]:
     """
     Compute character error rate.

@@ -39,7 +41,9 @@ def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
 def character_error_rate_n(
     reference: ExtractedText, compared: ExtractedText
 ) -> Tuple[float, int]:
-    return character_error_rate_n(reference.grapheme_clusters, compared.grapheme_clusters)
+    return character_error_rate_n(
+        reference.grapheme_clusters, compared.grapheme_clusters
+    )


 def character_error_rate(reference, compared) -> float:
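Context for the two hunks above: `character_error_rate_n` has several overloads registered with the `multimethod` package, dispatched on the runtime types of the arguments (`list[str]`, `str`, `ExtractedText`). A minimal self-contained sketch of that dispatch pattern, with toy functions that are not part of this code base:

```python
# Sketch of multimethod dispatch, assuming the `multimethod` package this
# module uses; which overload runs depends on the runtime argument types.
from multimethod import multimethod


@multimethod
def describe(values: list) -> str:
    return f"list of {len(values)} items"


@multimethod
def describe(text: str) -> str:
    # Strings get their own overload, analogous to the separate
    # list[str] / str / ExtractedText variants of character_error_rate_n.
    return f"string of length {len(text)}"


assert describe(["a", "b"]) == "list of 2 items"
assert describe("ab") == "string of length 2"
```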

@@ -46,7 +46,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
         ocr = removesuffix(gt, gt_suffix) + ocr_suffix

         gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True)
-        ocr_text = plain_extract(os.path.join(ocr_dir, ocr), include_filename_in_id=True)
+        ocr_text = plain_extract(
+            os.path.join(ocr_dir, ocr), include_filename_in_id=True
+        )

         gt_words = words_normalized(gt_text)
         ocr_words = words_normalized(ocr_text)

@@ -56,7 +58,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
             cer, n_characters = l_cer, l_n_characters
         else:
             # Rolling update
-            cer = (cer * n_characters + l_cer * l_n_characters) / (n_characters + l_n_characters)
+            cer = (cer * n_characters + l_cer * l_n_characters) / (
+                n_characters + l_n_characters
+            )
             n_characters = n_characters + l_n_characters

     # Compute WER
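The "Rolling update" reformatted above is a weighted running mean: folding pages in one at a time yields the same aggregate CER as dividing total weighted errors by total characters. A quick check with made-up numbers:

```python
# Sketch: the rolling update is a weighted running mean, so folding in pages
# one at a time matches computing the aggregate in one shot.
pages = [(0.10, 1000), (0.25, 200), (0.00, 300)]  # (page CER, chars); made up

cer, n_characters = None, 0
for l_cer, l_n_characters in pages:
    if cer is None:
        cer, n_characters = l_cer, l_n_characters
    else:
        # Rolling update: weighted mean of the CERs seen so far.
        cer = (cer * n_characters + l_cer * l_n_characters) / (
            n_characters + l_n_characters
        )
        n_characters += l_n_characters

# Equivalent one-shot aggregate.
total = sum(n for _, n in pages)
assert abs(cer - sum(c * n for c, n in pages) / total) < 1e-12
```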

@@ -17,6 +17,7 @@ def distance(seq1: list[str], seq2: list[str]):
     """
     return Levenshtein.distance(seq1, seq2)

+
 @multimethod
 def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings

@@ -242,7 +242,9 @@ class ExtractedText:
     def from_str(cls, text, normalization=Normalization.NFC_SBB):
         normalized_text = normalize(text, normalization)
         clusters = list(grapheme_clusters(normalized_text))
-        return cls(None, None, None, normalized_text, clusters, normalization=normalization)
+        return cls(
+            None, None, None, normalized_text, clusters, normalization=normalization
+        )


 def invert_dict(d):
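`from_str` above normalizes the text and then splits it into grapheme clusters, so a combining sequence counts as one symbol when computing CER. A minimal sketch of that distinction, assuming the `uniseg` package that provides `grapheme_clusters` here:

```python
# Sketch of why ExtractedText stores grapheme clusters: a combining sequence
# is one user-perceived character even when it is several code points.
import unicodedata
from uniseg.graphemecluster import grapheme_clusters

decomposed = unicodedata.normalize("NFD", "Tür")  # u + combining diaeresis
assert len(decomposed) == 4  # code points
assert len(list(grapheme_clusters(decomposed))) == 3  # user-perceived characters
```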

@@ -103,9 +103,13 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
     else:
         raise NotImplementedError

     for ro_child in ro_children:
-        if ET.QName(ro_child.tag).localname in ["OrderedGroup", "OrderedGroupIndexed", "UnorderedGroup", "UnorderedGroupIndexed"]:
+        if ET.QName(ro_child.tag).localname in [
+            "OrderedGroup",
+            "OrderedGroupIndexed",
+            "UnorderedGroup",
+            "UnorderedGroupIndexed",
+        ]:
             regions.extend(
                 extract_texts_from_reading_order_group(
                     ro_child, tree, nsmap, textequiv_level
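The membership test reformatted above decides whether to recurse into nested reading-order groups; `ET` here is lxml's etree, and `QName(...).localname` strips the PAGE-XML namespace before the comparison. A small sketch of that mechanism (the namespace URL is one real PAGE version, but any namespaced tag behaves the same):

```python
# Sketch of the tag comparison, assuming lxml (imported as ET in this module):
# QName(...).localname drops the namespace part of a fully qualified tag.
from lxml import etree as ET

ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15"
elem = ET.fromstring(f'<OrderedGroup xmlns="{ns}"/>')

assert elem.tag == "{%s}OrderedGroup" % ns
assert ET.QName(elem.tag).localname == "OrderedGroup"
```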
@@ -139,7 +143,11 @@ def plain_extract(filename, include_filename_in_id=False):
         clusters = list(grapheme_clusters(normalized_text))
         return ExtractedText(
             id_template.format(filename=os.path.basename(filename), no=no),
-            None, None, normalized_text, clusters)
+            None,
+            None,
+            normalized_text,
+            clusters,
+        )

     with open(filename, "r") as f:
         return ExtractedText(
@@ -147,7 +155,7 @@ def plain_extract(filename, include_filename_in_id=False):
             [make_segment(no, line) for no, line in enumerate(f.readlines())],
             "\n",
             None,
-            None
+            None,
         )


 # XXX hardcoded SBB normalization

@@ -33,7 +33,7 @@ class OcrdDinglehopperEvaluate(Processor):
         textequiv_level = self.parameter["textequiv_level"]
         gt_grp, ocr_grp = self.input_file_grp.split(",")

-        input_file_tuples = self.zip_input_files(on_error='abort')
+        input_file_tuples = self.zip_input_files(on_error="abort")
         for n, (gt_file, ocr_file) in enumerate(input_file_tuples):
             if not gt_file or not ocr_file:
                 # file/page was not found in this group

@@ -40,7 +40,6 @@ def words(s: str):
     if not word_break_patched:
         patch_word_break()

-
     # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
     def unwanted(c):
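For context, `unwanted()` filters out characters that should not count as word content before word comparison. One way to express such a test with Unicode general categories (illustrative; not necessarily the exact category set this module uses):

```python
# Sketch: treat separators (Z*), punctuation (P*) and control/format (C*)
# characters as "unwanted" so they never contribute to word comparison.
import unicodedata


def unwanted(c: str) -> bool:
    return unicodedata.category(c)[0] in ("Z", "P", "C")


assert unwanted(" ") and unwanted(",")
assert not unwanted("a") and not unwanted("ä")
```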
