🎨 Reformat using Black

pull/81/head
Gerber, Mike 2 years ago
parent 2268f32a78
commit 0f0819512e

@@ -32,7 +32,7 @@ def common_suffix(its):
def removesuffix(text, suffix): def removesuffix(text, suffix):
if suffix and text.endswith(suffix): if suffix and text.endswith(suffix):
return text[:-len(suffix)] return text[: -len(suffix)]
return text return text
@@ -52,7 +52,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
ocr = removesuffix(gt, gt_suffix) + ocr_suffix ocr = removesuffix(gt, gt_suffix) + ocr_suffix
gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True) gt_text = plain_extract(os.path.join(gt_dir, gt), include_filename_in_id=True)
ocr_text = plain_extract(os.path.join(ocr_dir, ocr), include_filename_in_id=True) ocr_text = plain_extract(
os.path.join(ocr_dir, ocr), include_filename_in_id=True
)
# Compute CER # Compute CER
l_cer, l_n_characters = character_error_rate_n(gt_text, ocr_text) l_cer, l_n_characters = character_error_rate_n(gt_text, ocr_text)
@@ -60,7 +62,9 @@ def process(gt_dir, ocr_dir, report_prefix, *, metrics=True):
cer, n_characters = l_cer, l_n_characters cer, n_characters = l_cer, l_n_characters
else: else:
# Rolling update # Rolling update
cer = (cer * n_characters + l_cer * l_n_characters) / (n_characters + l_n_characters) cer = (cer * n_characters + l_cer * l_n_characters) / (
n_characters + l_n_characters
)
n_characters = n_characters + l_n_characters n_characters = n_characters + l_n_characters
# Compute WER # Compute WER

@@ -98,14 +98,18 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children) ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children)
ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"])) ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"]))
elif ET.QName(group.tag).localname in ["UnorderedGroup","UnorderedGroupIndexed"]: elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]:
ro_children = list(group) ro_children = list(group)
else: else:
raise NotImplementedError raise NotImplementedError
for ro_child in ro_children: for ro_child in ro_children:
if ET.QName(ro_child.tag).localname in ["OrderedGroup", "OrderedGroupIndexed", "UnorderedGroup", "UnorderedGroupIndexed"]: if ET.QName(ro_child.tag).localname in [
"OrderedGroup",
"OrderedGroupIndexed",
"UnorderedGroup",
"UnorderedGroupIndexed",
]:
regions.extend( regions.extend(
extract_texts_from_reading_order_group( extract_texts_from_reading_order_group(
ro_child, tree, nsmap, textequiv_level ro_child, tree, nsmap, textequiv_level
@@ -139,7 +143,10 @@ def plain_extract(filename, include_filename_in_id=False):
[ [
ExtractedText( ExtractedText(
id_template.format(filename=os.path.basename(filename), no=no), id_template.format(filename=os.path.basename(filename), no=no),
None, None, normalize_sbb(line)) None,
None,
normalize_sbb(line),
)
for no, line in enumerate(f.readlines()) for no, line in enumerate(f.readlines())
], ],
"\n", "\n",

@@ -33,7 +33,7 @@ class OcrdDinglehopperEvaluate(Processor):
textequiv_level = self.parameter["textequiv_level"] textequiv_level = self.parameter["textequiv_level"]
gt_grp, ocr_grp = self.input_file_grp.split(",") gt_grp, ocr_grp = self.input_file_grp.split(",")
input_file_tuples = self.zip_input_files(on_error='abort') input_file_tuples = self.zip_input_files(on_error="abort")
for n, (gt_file, ocr_file) in enumerate(input_file_tuples): for n, (gt_file, ocr_file) in enumerate(input_file_tuples):
if not gt_file or not ocr_file: if not gt_file or not ocr_file:
# file/page was not found in this group # file/page was not found in this group

@@ -15,7 +15,7 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
@pytest.mark.integration @pytest.mark.integration
@pytest.mark.skipif(sys.platform == 'win32', reason="only on unix") @pytest.mark.skipif(sys.platform == "win32", reason="only on unix")
def test_ocrd_cli(tmp_path): def test_ocrd_cli(tmp_path):
"""Test OCR-D interface""" """Test OCR-D interface"""

@@ -42,10 +42,8 @@ def words(s: str):
if not word_break_patched: if not word_break_patched:
patch_word_break() patch_word_break()
# Check if c is an unwanted character, i.e. whitespace, punctuation, or similar # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar
def unwanted(c): def unwanted(c):
# See https://www.fileformat.info/info/unicode/category/index.htm # See https://www.fileformat.info/info/unicode/category/index.htm
# and https://unicodebook.readthedocs.io/unicode.html#categories # and https://unicodebook.readthedocs.io/unicode.html#categories
unwanted_categories = "O", "M", "P", "Z", "S" unwanted_categories = "O", "M", "P", "Z", "S"

@@ -4,7 +4,7 @@ from setuptools import find_packages, setup
with open("requirements.txt") as fp: with open("requirements.txt") as fp:
install_requires = fp.read() install_requires = fp.read()
with open('requirements-dev.txt') as fp: with open("requirements-dev.txt") as fp:
tests_require = fp.read() tests_require = fp.read()
setup( setup(

Loading…
Cancel
Save