Merge branch 'feat/display-segment-id'

pull/38/head
Gerber, Mike 4 years ago
commit f50591abac

@@ -2,11 +2,10 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="inheritedJdk" />
+    <orderEntry type="jdk" jdkName="Python 3.7 (dinglehopper-github)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="TestRunnerService">
-    <option name="projectConfiguration" value="pytest" />
     <option name="PROJECT_TEST_RUNNER" value="pytest" />
   </component>
 </module>

@@ -1,4 +1,4 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (dinglehopper)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (dinglehopper-github)" project-jdk-type="Python SDK" />
 </project>

@@ -28,16 +28,16 @@ def seq_align(s1, s2):
         if o:
             if o[0] == 'insert':
-                yield (None, s2[j])
+                yield None, s2[j]
                 j += 1
             elif o[0] == 'delete':
-                yield (s1[i], None)
+                yield s1[i], None
                 i += 1
             elif o[0] == 'replace':
-                yield (s1[i], s2[j])
+                yield s1[i], s2[j]
                 i += 1
                 j += 1
         else:
-            yield (s1[i], s2[j])
+            yield s1[i], s2[j]
             i += 1
             j += 1

@@ -3,17 +3,21 @@ from __future__ import division
 import unicodedata
 from typing import Tuple
 
+from multimethod import multimethod
 from uniseg.graphemecluster import grapheme_clusters
 
-from qurator.dinglehopper.edit_distance import distance
+from .edit_distance import distance
+from .extracted_text import ExtractedText
 
 
-def character_error_rate_n(reference, compared) -> Tuple[float, int]:
+@multimethod
+def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
     """
     Compute character error rate.
 
     :return: character error rate and length of the reference
     """
     d = distance(reference, compared)
     n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference))))

@@ -26,6 +30,11 @@ def character_error_rate_n(reference, compared) -> Tuple[float, int]:
     # XXX Should we really count newlines here?
 
 
+@multimethod
+def character_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
+    return character_error_rate_n(reference.text, compared.text)
+
+
 def character_error_rate(reference, compared) -> float:
     """
     Compute character error rate.
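Note: with the @multimethod overloads above, character_error_rate_n() accepts either plain strings or ExtractedText objects; the latter simply dispatches to the str implementation via .text. A rough sketch (example strings invented):

    from qurator.dinglehopper import ExtractedText
    from qurator.dinglehopper.character_error_rate import character_error_rate_n

    # str overload: grapheme-cluster edit distance divided by reference length
    cer, n = character_error_rate_n('Führe', 'Fuhre')   # one substitution → 1/5

    # ExtractedText overload delegates to the str overload via .text
    gt = ExtractedText.from_str('Führe')
    ocr = ExtractedText.from_str('Fuhre')
    assert character_error_rate_n(gt, ocr) == (cer, n)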

@@ -3,16 +3,20 @@ import os
 import click
 from jinja2 import Environment, FileSystemLoader
 from markupsafe import escape
+from uniseg.graphemecluster import grapheme_clusters
 
-from qurator.dinglehopper import *
+from .character_error_rate import character_error_rate_n
+from .word_error_rate import word_error_rate_n, words_normalized
+from .align import seq_align
+from .extracted_text import ExtractedText
+from .ocr_files import extract
 
 
-def gen_diff_report(gt_things, ocr_things, css_prefix, joiner, none, align):
+def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
     gtx = ''
     ocrx = ''
 
-    def format_thing(t, css_classes=None):
+    def format_thing(t, css_classes=None, id_=None):
         if t is None:
             html_t = none
             css_classes += ' ellipsis'

@@ -21,19 +25,51 @@ def gen_diff_report(gt_things, ocr_things, css_prefix, joiner, none, align):
         else:
             html_t = escape(t)
 
+        html_custom_attrs = ""
+
+        # Set Bootstrap tooltip to the segment id
+        if id_:
+            html_custom_attrs += 'data-toggle="tooltip" title="{}"'.format(id_)
+
         if css_classes:
-            return '<span class="{css_classes}">{html_t}</span>'.format(css_classes=css_classes, html_t=html_t)
+            return '<span class="{css_classes}" {html_custom_attrs}>{html_t}</span>'.format(css_classes=css_classes, html_t=html_t, html_custom_attrs=html_custom_attrs)
         else:
             return '{html_t}'.format(html_t=html_t)
 
-    for k, (g, o) in enumerate(align(gt_things, ocr_things)):
-        if g == o:
-            css_classes = None
-        else:
+    if isinstance(gt_in, ExtractedText):
+        if not isinstance(ocr_in, ExtractedText):
+            raise TypeError()
+        # XXX splitting should be done in ExtractedText
+        gt_things = list(grapheme_clusters(gt_in.text))
+        ocr_things = list(grapheme_clusters(ocr_in.text))
+    else:
+        gt_things = gt_in
+        ocr_things = ocr_in
+
+    g_pos = 0
+    o_pos = 0
+    for k, (g, o) in enumerate(seq_align(gt_things, ocr_things)):
+        css_classes = None
+        gt_id = None
+        ocr_id = None
+        if g != o:
             css_classes = '{css_prefix}diff{k} diff'.format(css_prefix=css_prefix, k=k)
+            if isinstance(gt_in, ExtractedText):
+                gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None
+                ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None
+                # Deletions and inserts only produce one id + None, UI must
+                # support this, i.e. display for the one id produced
 
-        gtx += joiner + format_thing(g, css_classes)
-        ocrx += joiner + format_thing(o, css_classes)
+        gtx += joiner + format_thing(g, css_classes, gt_id)
+        ocrx += joiner + format_thing(o, css_classes, ocr_id)
+
+        if g is not None:
+            g_pos += len(g)
+        if o is not None:
+            o_pos += len(o)
 
     return \
         '''

@@ -51,20 +87,17 @@ def process(gt, ocr, report_prefix, *, metrics=True):
     Click on a wrapper.
     """
 
-    gt_text = text(gt)
-    ocr_text = text(ocr)
-
-    gt_text = substitute_equivalences(gt_text)
-    ocr_text = substitute_equivalences(ocr_text)
+    gt_text = extract(gt)
+    ocr_text = extract(ocr)
 
     cer, n_characters = character_error_rate_n(gt_text, ocr_text)
     wer, n_words = word_error_rate_n(gt_text, ocr_text)
 
-    char_diff_report = gen_diff_report(gt_text, ocr_text, css_prefix='c', joiner='', none='·', align=align)
+    char_diff_report = gen_diff_report(gt_text, ocr_text, css_prefix='c', joiner='', none='·')
 
     gt_words = words_normalized(gt_text)
     ocr_words = words_normalized(ocr_text)
-    word_diff_report = gen_diff_report(gt_words, ocr_words, css_prefix='w', joiner=' ', none='', align=seq_align)
+    word_diff_report = gen_diff_report(gt_words, ocr_words, css_prefix='w', joiner=' ', none='')
 
 
 def json_float(value):
     """Convert a float value to an JSON float.

@@ -5,8 +5,11 @@ from functools import partial, lru_cache
 from typing import Sequence, Tuple
 
 import numpy as np
+from multimethod import multimethod
 from uniseg.graphemecluster import grapheme_clusters
 
+from .extracted_text import ExtractedText
+
 
 def levenshtein_matrix(seq1: Sequence, seq2: Sequence):
     """Compute the matrix commonly computed to produce the Levenshtein distance.

@@ -69,15 +72,21 @@ def levenshtein_matrix_cache_clear():
     _levenshtein_matrix.cache_clear()
 
 
-def distance(s1, s2):
+@multimethod
+def distance(s1: str, s2: str):
     """Compute the Levenshtein edit distance between two Unicode strings
 
     Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme
     clusters. This should be the correct way to compare two Unicode strings.
     """
-    s1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1)))
-    s2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2)))
-    return levenshtein(s1, s2)
+    seq1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1)))
+    seq2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2)))
+    return levenshtein(seq1, seq2)
+
+
+@multimethod
+def distance(s1: ExtractedText, s2: ExtractedText):
+    return distance(s1.text, s2.text)
 
 
 def seq_editops(seq1, seq2):

@@ -116,7 +125,11 @@ def seq_editops(seq1, seq2):
 def editops(word1, word2):
-    # XXX Note that this returns indices to the _grapheme clusters_, not characters!
+    """
+    Return sequence of edit operations transforming one string to another.
+
+    Note that this returns indices to the _grapheme clusters_, not characters!
+    """
     word1 = list(grapheme_clusters(unicodedata.normalize('NFC', word1)))
     word2 = list(grapheme_clusters(unicodedata.normalize('NFC', word2)))
     return seq_editops(word1, word2)
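Note: the docstring's point about Unicode normalization and grapheme clusters is easy to check, since two representations of the same text compare as equal. A small sketch (the example string is borrowed from the tests):

    import unicodedata
    from qurator.dinglehopper.edit_distance import distance

    nfc = unicodedata.normalize('NFC', 'Schlyñ')
    nfd = unicodedata.normalize('NFD', 'Schlyñ')   # ñ split into n + combining tilde
    assert nfc != nfd                # different code point sequences ...
    assert distance(nfc, nfd) == 0   # ... but no edits after NFC + grapheme clustering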

@@ -0,0 +1,118 @@
import enum
import re
import unicodedata
from contextlib import suppress
from itertools import repeat
from typing import Optional

import attr

from .substitute_equivalences import substitute_equivalences


class Normalization(enum.Enum):
    NFC = 1
    NFC_MUFI = 2  # TODO
    NFC_SBB = 3


def normalize(text, normalization):
    if normalization == Normalization.NFC:
        return unicodedata.normalize('NFC', text)
    if normalization == Normalization.NFC_MUFI:
        raise NotImplementedError()
    if normalization == Normalization.NFC_SBB:
        return substitute_equivalences(text)
    else:
        raise ValueError()


# XXX hack
def normalize_sbb(t):
    return normalize(t, Normalization.NFC_SBB)


@attr.s(frozen=True)
class ExtractedText:
    """
    Extracted text

    Objects of this class are guaranteed to be a. always in their normalization and
    b. in NFC.
    """
    segment_id = attr.ib(type=Optional[str])

    @segment_id.validator
    def check(self, _, value):
        if value is None:
            return
        if not re.match(r'[\w\d_-]+', value):
            raise ValueError('Malformed segment id "{}"'.format(value))

    # An object contains either
    # a. _text itself
    # b. or segments (ExtractedText) and a joiner

    segments = attr.ib(type=Optional[list], converter=attr.converters.optional(list))
    joiner = attr.ib(type=Optional[str])
    _text = attr.ib(type=Optional[str])

    @segments.validator
    def check(self, _, value):
        if value is not None and self._text is not None:
            raise ValueError("Can't have both segments and text")

    @_text.validator
    def check(self, _, value):
        if value is not None and self.segments is not None:
            raise ValueError("Can't have both segments and text")
        if value is not None and unicodedata.normalize('NFC', value) != value:
            raise ValueError('String "{}" is not in NFC.'.format(value))
        if value is not None and normalize(value, self.normalization) != value:
            raise ValueError('String "{}" is not normalized.'.format(value))

    normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)

    @property
    def text(self):
        if self._text is not None:
            if self._text == '':
                return None
            else:
                return self._text
        else:
            return self.joiner.join(s.text for s in self.segments)

    _segment_id_for_pos = None

    def segment_id_for_pos(self, pos):
        # Calculate segment ids once, on the first call
        if not self._segment_id_for_pos:
            segment_id_for_pos = []
            for s in self.segments:
                segment_id_for_pos.extend(repeat(s.segment_id, len(s.text)))
                segment_id_for_pos.extend(repeat(None, len(self.joiner)))
            segment_id_for_pos = segment_id_for_pos[:-len(self.joiner)]
            # This is frozen, so we have to jump through the hoop:
            object.__setattr__(self, '_segment_id_for_pos', segment_id_for_pos)
            assert self._segment_id_for_pos

        return self._segment_id_for_pos[pos]

    @classmethod
    def from_text_segment(cls, text_segment, nsmap):
        """Build an ExtractedText from a PAGE content text element"""
        segment_id = text_segment.attrib['id']
        segment_text = None
        with suppress(AttributeError):
            segment_text = text_segment.find('./page:TextEquiv/page:Unicode', namespaces=nsmap).text
            segment_text = segment_text or ''
            segment_text = normalize_sbb(segment_text)  # FIXME hardcoded SBB normalization
        segment_text = segment_text or ''
        return cls(segment_id, None, None, segment_text)

    @classmethod
    def from_str(cls, text, normalization=Normalization.NFC_SBB):
        normalized_text = normalize(text, normalization)
        return cls(None, None, None, normalized_text, normalization=normalization)
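Note: a short sketch of how the segment tree and segment_id_for_pos() fit together (segment ids and texts invented; compare the tests further down):

    from qurator.dinglehopper import ExtractedText

    t = ExtractedText(None, [
        ExtractedText('r1', None, None, 'foo'),
        ExtractedText('r2', None, None, 'bar'),
    ], '\n', None)

    assert t.text == 'foo\nbar'
    assert t.segment_id_for_pos(0) == 'r1'   # 'f' belongs to segment r1
    assert t.segment_id_for_pos(3) is None   # the joiner has no segment id
    assert t.segment_id_for_pos(4) == 'r2'   # 'b' belongs to segment r2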

@@ -1,14 +1,16 @@
 from __future__ import division, print_function
 
+from typing import Generator
 from warnings import warn
 
-from lxml import etree as ET
 import sys
 
+from lxml import etree as ET
 from lxml.etree import XMLSyntaxError
 
+from .extracted_text import ExtractedText, normalize_sbb
 
-def alto_namespace(tree):
+
+def alto_namespace(tree: ET.ElementTree) -> str:
     """Return the ALTO namespace used in the given ElementTree.
 
     This relies on the assumption that, in any given ALTO file, the root element has the local name "alto". We do not

@@ -21,17 +23,22 @@ def alto_namespace(tree):
         raise ValueError('Not an ALTO tree')
 
 
-def alto_text(tree):
-    """Extract text from the given ALTO ElementTree."""
+def alto_extract_lines(tree: ET.ElementTree) -> Generator[ExtractedText, None, None]:
     nsmap = {'alto': alto_namespace(tree)}
+    for line in tree.iterfind('.//alto:TextLine', namespaces=nsmap):
+        line_id = line.attrib.get('ID')
+        line_text = ' '.join(string.attrib.get('CONTENT') for string in line.iterfind('alto:String', namespaces=nsmap))
+        yield ExtractedText(line_id, None, None, normalize_sbb(line_text))
+        # FIXME hardcoded SBB normalization
 
-    lines = (
-        ' '.join(string.attrib.get('CONTENT') for string in line.iterfind('alto:String', namespaces=nsmap))
-        for line in tree.iterfind('.//alto:TextLine', namespaces=nsmap))
-    text_ = '\n'.join(lines)
 
-    return text_
+def alto_extract(tree: ET.ElementTree()) -> ExtractedText:
+    """Extract text from the given ALTO ElementTree."""
+    return ExtractedText(None, list(alto_extract_lines(tree)), '\n', None)
+
+
+def alto_text(tree):
+    return alto_extract(tree).text
 
 
 def page_namespace(tree):

@@ -47,18 +54,12 @@ def page_namespace(tree):
         raise ValueError('Not a PAGE tree')
 
 
-def page_text(tree):
+def page_extract(tree):
     """Extract text from the given PAGE content ElementTree."""
     nsmap = {'page': page_namespace(tree)}
 
-    def region_text(region):
-        try:
-            return region.find('./page:TextEquiv/page:Unicode', namespaces=nsmap).text
-        except AttributeError:
-            return None
-
-    region_texts = []
+    regions = []
     reading_order = tree.find('.//page:ReadingOrder', namespaces=nsmap)
     if reading_order is not None:
         for group in reading_order.iterfind('./*', namespaces=nsmap):

@@ -68,39 +69,56 @@ def page_text(tree):
                     region_id = region_ref_indexed.attrib['regionRef']
                     region = tree.find('.//page:TextRegion[@id="%s"]' % region_id, namespaces=nsmap)
                     if region is not None:
-                        region_texts.append(region_text(region))
+                        regions.append(ExtractedText.from_text_segment(region, nsmap))
                     else:
                         warn('Not a TextRegion: "%s"' % region_id)
             else:
                 raise NotImplementedError
     else:
         for region in tree.iterfind('.//page:TextRegion', namespaces=nsmap):
-            region_texts.append(region_text(region))
+            regions.append(ExtractedText.from_text_segment(region, nsmap))
 
+    # XXX Does a file have to have regions etc.? region vs lines etc.
     # Filter empty region texts
-    region_texts = (t for t in region_texts if t)
-    text_ = '\n'.join(region_texts)
+    regions = [r for r in regions if r.text is not None]
 
-    return text_
+    return ExtractedText(None, regions, '\n', None)
 
 
-def text(filename):
-    """Read the text from the given file.
+def page_text(tree):
+    return page_extract(tree).text
+
+
+def plain_extract(filename):
+    with open(filename, 'r') as f:
+        return ExtractedText(
+            None,
+            [ExtractedText('line %d' % no, None, None, line) for no, line in enumerate(f.readlines())],
+            '\n',
+            None
+        )
+
+
+def plain_text(filename):
+    return plain_extract(filename).text
+
+
+def extract(filename):
+    """Extract the text from the given file.
 
     Supports PAGE, ALTO and falls back to plain text.
     """
     try:
         tree = ET.parse(filename)
     except XMLSyntaxError:
-        with open(filename, 'r') as f:
-            return f.read()
+        return plain_extract(filename)
     try:
-        return page_text(tree)
+        return page_extract(tree)
     except ValueError:
-        return alto_text(tree)
+        return alto_extract(tree)
+
+
+def text(filename):
+    return extract(filename).text
 
 
 if __name__ == '__main__':
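Note: extract() now returns an ExtractedText tree instead of a plain string, falling back from PAGE to ALTO to plain text, while the old text() helpers remain as thin wrappers around .text. A usage sketch (file name invented for illustration):

    from qurator.dinglehopper.ocr_files import extract

    doc = extract('ocr-result.page.xml')      # PAGE, ALTO or plain text
    print(doc.text)                           # joined text of all segments
    print(doc.segments[0].segment_id)         # e.g. the first region's id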

@@ -7,8 +7,8 @@ from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor
 from ocrd_utils import getLogger, make_file_id, assert_file_grp_cardinality
 from pkg_resources import resource_string
 
-from qurator.dinglehopper.cli import process as cli_process
-from qurator.dinglehopper.edit_distance import levenshtein_matrix_cache_clear
+from .cli import process as cli_process
+from .edit_distance import levenshtein_matrix_cache_clear
 
 OCRD_TOOL = json.loads(resource_string(__name__, 'ocrd-tool.json').decode('utf8'))

@@ -1,21 +1,15 @@
 import unicodedata
 
 
-def substitute_equivalences(s):
-    # These are for OCR-D GT vs Tesseract frk vs Calamari GT4HistOCR
-    # It might make sense to use different rules for GT and for the different OCR
+def unjoin_ligatures(s):
+    """Unjoin ligatures, i.e. ff becomes ff."""
     equivalences = {
-        '': 'ü',
         '': 'ſſ',
         "\ueba7": 'ſſi',  # MUFI: LATIN SMALL LIGATURE LONG S LONG S I
-        '': 'ä',
         '': 'ch',
-        '==': '',  # → en-dash
-        '': '',  # em-dash → en-dash
         '': 'ck',
         '': 'll',
-        '': 'ö',
        '': 'ſi',
        '': 'ſt',
        '': 'fi',

@@ -23,12 +17,7 @@ def substitute_equivalences(s):
         '': 'fl',
         '': 'ffi',
         '': 'ct',
-        '': '\'',
-        '': '-',
         '': 'tz',  # MUFI: LATIN SMALL LIGATURE TZ
-        '': 'ä',  # LATIN SMALL LETTER A, COMBINING LATIN SMALL LETTER E
-        '': 'ö',  # LATIN SMALL LETTER O, COMBINING LATIN SMALL LETTER E
-        '': 'ü',  # LATIN SMALL LETTER U, COMBINING LATIN SMALL LETTER E
         '\uf532': 'as',  # eMOP: Latin small ligature as
         '\uf533': 'is',  # eMOP: Latin small ligature is
         '\uf534': 'us',  # eMOP: Latin small ligature us

@@ -37,10 +26,32 @@ def substitute_equivalences(s):
         '\uE8BF': 'q&',  # MUFI: LATIN SMALL LETTER Q LIGATED WITH FINAL ET  XXX How to replace this correctly?
         '\uEBA5': 'ſp',  # MUFI: LATIN SMALL LIGATURE LONG S P
         '': 'st',  # U+FB06 LATIN SMALL LIGATURE ST
+    }
+    s = unicodedata.normalize('NFC', s)
+    for fr, to in equivalences.items():
+        s = s.replace(fr, to)
+    return s
+
+
+def substitute_equivalences(s):
+    # These are for OCR-D GT vs Tesseract frk vs Calamari GT4HistOCR
+    # It might make sense to use different rules for GT and for the different OCR
+    equivalences = {
+        '': 'ü',
+        '': 'ä',
+        '==': '',  # → en-dash
+        '': '',  # em-dash → en-dash
+        '': 'ö',
+        '': '\'',
+        '': '-',
+        '': 'ä',  # LATIN SMALL LETTER A, COMBINING LATIN SMALL LETTER E
+        '': 'ö',  # LATIN SMALL LETTER O, COMBINING LATIN SMALL LETTER E
+        '': 'ü',  # LATIN SMALL LETTER U, COMBINING LATIN SMALL LETTER E
         '\uF50E': ''  # U+F50E LATIN SMALL LETTER Q WITH ACUTE ACCENT
     }
     s = unicodedata.normalize('NFC', s)
+    s = unjoin_ligatures(s)
     for fr, to in equivalences.items():
         s = s.replace(fr, to)
     return s
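Note: the normalization is now split in two steps: unjoin_ligatures() resolves (mostly MUFI/private-use) ligature code points, and substitute_equivalences() applies the remaining GT-vs-OCR equivalences after NFC normalization. A small sketch, assuming the Unicode ligature ﬁ (U+FB01) and the combining small letter e (as in aͤ) are among the mapped characters:

    from qurator.dinglehopper.substitute_equivalences import substitute_equivalences

    # assumptions: 'ﬁ' → 'fi' via unjoin_ligatures(), 'aͤ' → 'ä' via the equivalence table
    print(substitute_equivalences('ﬁnden'))   # 'finden'
    print(substitute_equivalences('Jaͤger'))   # 'Jäger'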

@@ -1,14 +1,15 @@
 function find_diff_class(classes) {
-    return classes.split(/\s+/).find(x => x.match(/.diff\d.*/));
+    return $('.' + classes.split(/\s+/).find(x => x.match(/.diff\d.*/)));
 }
 
 $(document).ready(function() {
+    /* Enable Bootstrap tooltips */
+    $('[data-toggle="tooltip"]').tooltip();
+
     $('.diff').mouseover(function() {
-        let c = find_diff_class($(this).attr('class'))
-        $('.' + c).addClass('diff-highlight')
+        find_diff_class($(this).attr('class')).addClass('diff-highlight');
     });
     $('.diff').mouseout(function() {
-        let c = find_diff_class($(this).attr('class'))
-        $('.' + c).removeClass('diff-highlight')
+        find_diff_class($(this).attr('class')).removeClass('diff-highlight');
     });
 });

@@ -0,0 +1,68 @@
import unicodedata

import pytest
from uniseg.graphemecluster import grapheme_clusters
from collections import namedtuple

from .. import seq_align, ExtractedText


def test_text():
    test1 = ExtractedText(None, [
        ExtractedText('s0', None, None, 'foo'),
        ExtractedText('s1', None, None, 'bar'),
        ExtractedText('s2', None, None, 'bazinga')
    ], ' ', None)

    assert test1.text == 'foo bar bazinga'
    assert test1.segment_id_for_pos(0) == 's0'
    assert test1.segment_id_for_pos(3) is None
    assert test1.segment_id_for_pos(10) == 's2'


def test_normalization_check():
    with pytest.raises(ValueError, match=r'.*is not in NFC.*'):
        ExtractedText('foo', None, None, unicodedata.normalize('NFD', 'Schlyñ'))
    assert ExtractedText('foo', None, None, unicodedata.normalize('NFC', 'Schlyñ'))


AlignmentElement = namedtuple('AlignmentElement', 'left right left_id right_id')


def test_align():
    """
    Test aligning by character while retaining segment id info

    The difficulty here is that aligning should work on grapheme clusters,
    not Python characters.
    """
    test1 = ExtractedText(None, [
        ExtractedText('s0', None, None, 'foo'),
        ExtractedText('s1', None, None, 'bar'),
        ExtractedText('s2', None, None, 'batzinga')
    ], ' ', None)
    test2 = ExtractedText(None, [
        ExtractedText('x0', None, None, 'foo'),
        ExtractedText('x1', None, None, 'bar'),
        ExtractedText('x2', None, None, '.'),  # extra .
        ExtractedText('x3', None, None, 'bazim̃ga'),  # deletion + different grapheme cluster, m̃ also is two Python characters
    ], ' ', None)

    left_pos = 0; right_pos = 0; alignment = []
    for left, right in seq_align(grapheme_clusters(test1.text), grapheme_clusters(test2.text)):
        left_id = test1.segment_id_for_pos(left_pos) if left is not None else None
        right_id = test2.segment_id_for_pos(right_pos) if right is not None else None
        el = AlignmentElement(left, right, left_id, right_id)
        alignment.append(el)
        if left is not None:
            left_pos += len(left)
        if right is not None:
            right_pos += len(right)

    print('test1: {}'.format(test1.text))
    print('test2: {}'.format(test2.text))

    assert alignment[0] == ('f', 'f', 's0', 'x0')
    assert alignment[8] == (None, '.', None, 'x2')
    assert alignment[12] == ('t', None, 's2', None)
    assert alignment[15] == ('n', 'm̃', 's2', 'x3')

@@ -78,7 +78,8 @@ def test_lines():
 def test_lines_similar():
-    """Test comparing list of lines while using a "weaker equivalence".
+    """
+    Test comparing list of lines while using a "weaker equivalence".
 
     This mainly serves as documentation.
     """

@@ -88,7 +89,14 @@ def test_lines_similar():
             self._string = string
 
         def __eq__(self, other):
-            return distance(self._string, other._string) < 2    # XXX NOT the final version
+            # Just an example!
+            min_len = min(len(self._string), len(other._string))
+            if min_len > 0:
+                normalized_distance = distance(self._string, other._string)/min_len
+                similar = normalized_distance < 0.1
+            else:
+                similar = False
+            return similar
 
         def __ne__(self, other):
             return not self.__eq__(other)

@@ -106,3 +114,6 @@ def test_lines_similar():
     left, right = unzip(result)
     assert list(left) == [SimilarString('This is a line.'), SimilarString('This is another'), None, SimilarString('And the last line')]
     assert list(right) == [SimilarString('This is a ljne.'), SimilarString('This is another'), SimilarString('J u n k'), SimilarString('And the last line')]
+
+    # Test __eq__ (i.e. is it a substitution or a similar string?)
+    assert list(left)[0] == list(right)[0]

@@ -13,11 +13,15 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 @pytest.mark.integration
 def test_align_page_files():
     # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi.
-    # → 4 elements in the alignment should be different.
+    # → 2 elements in the alignment should be different, the ligature is
+    # (currently) not counted due to normalization.
     # NOTE: In this example, it doesn't matter that we work with "characters", not grapheme clusters.
     gt = page_text(ET.parse(os.path.join(data_dir, 'test-gt.page2018.xml')))
     ocr = page_text(ET.parse(os.path.join(data_dir, 'test-fake-ocr.page2018.xml')))
 
     result = list(align(gt, ocr))
-    assert sum(left != right for left, right in result) == 4
+    for left, right in result:
+        if left != right:
+            print(left, right)
+    assert sum(left != right for left, right in result) == 2

@@ -4,6 +4,7 @@ import os
 
 import pytest
 from lxml import etree as ET
+from uniseg.graphemecluster import grapheme_clusters
 
 from .. import character_error_rate, page_text, alto_text

@@ -13,9 +14,14 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 @pytest.mark.integration
 def test_character_error_rate_between_page_files():
     # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi.
+    # The fi ligature does not count.
     gt = page_text(ET.parse(os.path.join(data_dir, 'test-gt.page2018.xml')))
     ocr = page_text(ET.parse(os.path.join(data_dir, 'test-fake-ocr.page2018.xml')))
-    assert character_error_rate(gt, ocr) == 4/(470 + 1 + 311)  # 2 TextRegions, 1 \n
+
+    gt_len = len(list(grapheme_clusters(gt)))
+    expected_cer = 2/gt_len
+    assert character_error_rate(gt, ocr) == expected_cer
 
 
 @pytest.mark.integration

@@ -1,4 +1,3 @@
-import os
 import json
 
 import pytest

@@ -10,14 +9,17 @@ from ..cli import process
 def test_cli_json(tmp_path):
     """Test that the cli/process() yields a loadable JSON report"""
-    # XXX Path.__str__() is necessary for Python 3.5
     with working_directory(str(tmp_path)):
         with open('gt.txt', 'w') as gtf:
             gtf.write('AAAAA')
         with open('ocr.txt', 'w') as ocrf:
             ocrf.write('AAAAB')
 
+        with open('gt.txt', 'r') as gtf:
+            print(gtf.read())
         process('gt.txt', 'ocr.txt', 'report')
+        with open('report.json', 'r') as jsonf:
+            print(jsonf.read())
 
         with open('report.json', 'r') as jsonf:
             j = json.load(jsonf)
             assert j['cer'] == pytest.approx(0.2)

@@ -26,7 +28,6 @@ def test_cli_json(tmp_path):
 def test_cli_json_cer_is_infinity(tmp_path):
     """Test that the cli/process() yields a loadable JSON report when CER == inf"""
-    # XXX Path.__str__() is necessary for Python 3.5
     with working_directory(str(tmp_path)):
         with open('gt.txt', 'w') as gtf:
             gtf.write('')  # Empty to yield CER == inf

@@ -13,9 +13,11 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 @pytest.mark.integration
 def test_distance_between_page_files():
     # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi.
+    # Due to normalization, we don't count the ligature.
+    # → 2 differences
     gt = page_text(ET.parse(os.path.join(data_dir, 'test-gt.page2018.xml')))
     ocr = page_text(ET.parse(os.path.join(data_dir, 'test-fake-ocr.page2018.xml')))
-    assert distance(gt, ocr) == 4
+    assert distance(gt, ocr) == 2
 
 
 @pytest.mark.integration

@@ -1,12 +1,10 @@
 import os
-import re
 import shutil
 import json
 import sys
 from pathlib import Path
 
 from click.testing import CliRunner
-import pytest
 
 from .util import working_directory

@@ -18,8 +16,6 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 def test_ocrd_cli(tmp_path):
     """Test OCR-D interface"""
 
-    # XXX Path.str() is necessary for Python 3.5
-
     # Copy test workspace
     test_workspace_dir_source = Path(data_dir) / 'actevedef_718448162'
     test_workspace_dir = tmp_path / 'test_ocrd_cli'

@@ -12,14 +12,15 @@ data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
 @pytest.mark.integration
 def test_word_error_rate_between_page_files():
-    # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi. → 3 changed words
+    # In the fake OCR file, we changed 2 characters and replaced a fi ligature with fi. So we have 3 changed words,
+    # the ligature does not count → 2 errors
     gt = page_text(ET.parse(os.path.join(data_dir, 'test-gt.page2018.xml')))
 
     gt_word_count = 7+6+5+8+7+6+7+8+6+7+7+5+6+8+8+7+7+6+5+4  # Manually verified word count per line
     assert len(list(words(gt))) == gt_word_count
 
     ocr = page_text(ET.parse(os.path.join(data_dir, 'test-fake-ocr.page2018.xml')))
-    assert word_error_rate(gt, ocr) == 3/gt_word_count
+    assert word_error_rate(gt, ocr) == 2/gt_word_count
 
 
 @pytest.mark.integration

@@ -6,7 +6,8 @@ import textwrap
 
 import pytest
 
-from .. import alto_namespace, alto_text, page_namespace, page_text, text
+from .util import working_directory
+from .. import alto_namespace, alto_text, page_namespace, page_text, plain_text, text
 
 data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')

@@ -49,27 +50,51 @@ def test_page_namespace():
 def test_page_test():
     tree = ET.parse(os.path.join(data_dir, 'test.page2018.xml'))
     result = page_text(tree)
+
+    # We are currently normalizing on extraction, so the text is normalized.
+    #
+    # expected = textwrap.dedent("""\
+    #     ber die vielen Sorgen wegen deelben vergaß
+    #     Hartkopf, der Frau Amtmnnin das ver⸗
+    #     ſproene zu berliefern. — Ein Erpreer
+    #     wurde an ihn abgeſit, um ihn ums Him⸗
+    #     melswien zu ſagen, daß er das Verſproene
+    #     glei den Augenbli berbringen mte, die
+    #     Frau Amtmnnin htte auf ihn verlaen,
+    #     und nun wßte e nit, was e anfangen
+    #     ſote. Den Augenbli ſote er kommen,
+    #     ſon vergieng e in ihrer Ang. — Die
+    #     Ge wren ſon angekommen, und es fehlte
+    #     ihr do no an aem. —
+    #     Hartkopf mußte er bennen, und
+    #     endli na langem Nadenken fiel es ihm er
+    #     wieder ein. — Er langte den Zettel aus dem
+    #     Accisbue heraus, und ſagte ſeiner Frau, daß
+    #     e das, was da wre, herbeyſaffen mte.
+    #     Jndeß mangelten do einige Generalia, die
+    #     alſo wegfielen. — Hartkopf gieng ſelb
+    #     mit und berbrate es. —""")
     expected = textwrap.dedent("""\
-        ber die vielen Sorgen wegen deelben vergaß
-        Hartkopf, der Frau Amtmnnin das ver⸗
-        ſproene zu berliefern. — Ein Erpreer
-        wurde an ihn abgeſit, um ihn ums Him⸗
-        melswien zu ſagen, daß er das Verſproene
-        glei den Augenbli berbringen mte, die
-        Frau Amtmnnin htte auf ihn verlaen,
-        und nun wßte e nit, was e anfangen
-        ſote. Den Augenbli ſote er kommen,
-        ſon vergieng e in ihrer Ang. — Die
-        Ge wren ſon angekommen, und es fehlte
-        ihr do no an aem. —
-        Hartkopf mußte er bennen, und
-        endli na langem Nadenken fiel es ihm er
-        wieder ein. — Er langte den Zettel aus dem
-        Accisbue heraus, und ſagte ſeiner Frau, daß
-        e das, was da wre, herbeyſaffen mte.
-        Jndeß mangelten do einige Generalia, die
-        alſo wegfielen. — Hartkopf gieng ſelb
-        mit und berbrate es. —""")
+        über die vielen Sorgen wegen deſſelben vergaß
+        Hartkopf, der Frau Amtmännin das ver-
+        ſprochene zu überliefern. Ein Erpreſſer
+        wurde an ihn abgeſchickt, um ihn ums Him-
+        melswillen zu ſagen, daß er das Verſprochene
+        gleich den Augenblick überbringen möchte, die
+        Frau Amtmännin hätte ſich auf ihn verlaſſen,
+        und nun wüßte ſie nicht, was ſie anfangen
+        ſollte. Den Augenblick ſollte er kommen,
+        ſonſt vergieng ſie in ihrer Angſt. Die
+        Gäſte wären ſchon angekommen, und es fehlte
+        ihr doch noch an allem.
+        Hartkopf mußte ſich erſt beſinnen, und
+        endlich nach langem Nachdenken fiel es ihm erſt
+        wieder ein. Er langte den Zettel aus dem
+        Accisbuche heraus, und ſagte ſeiner Frau, daß
+        ſie das, was da wäre, herbeyſchaffen möchte.
+        Jndeß mangelten doch einige Generalia, die
+        alſo wegfielen. Hartkopf gieng ſelbſt
+        mit und überbrachte es. """)
     assert result == expected

@@ -92,7 +117,8 @@ def test_page_order():
     tree = ET.parse(os.path.join(data_dir, 'order.page.xml'))
     result = page_text(tree)
 
-    assert re.search(r'Herr Konfrater.*75.*Etwas f.r Wittwen.*Ein gewi.er Lord.*76\. Die', result, re.DOTALL)
+    print(result)
+    assert re.search(r'Herr Konfrater.*75.*Etwas f.r Wittwen.*Ein gewi.{1,2}er Lord.*76\. Die', result, re.DOTALL)
 
 
 def test_page_mixed_regions():

@@ -106,5 +132,15 @@ def test_page_mixed_regions():
 def test_text():
     assert "being erected at the Broadway stock" in text(os.path.join(data_dir, 'test.alto1.xml'))
     assert "wieder ein. Er langte den Zettel aus dem" in text(os.path.join(data_dir, 'test.page2018.xml'))
     assert "Lorem ipsum" in text(os.path.join(data_dir, 'test.txt'))
+
+
+def test_plain(tmp_path):
+    with working_directory(str(tmp_path)):
+        with open('ocr.txt', 'w') as ocrf:
+            ocrf.write('AAAAB')
+
+        result = plain_text('ocr.txt')
+        expected = 'AAAAB'
+        assert result == expected

@@ -21,8 +21,8 @@ def diffprint(x, y):
         _diffprint(x, y)
 
 
-def unzip(l):
-    return zip(*l)
+def unzip(an_iterable_of_tuples):
+    return zip(*an_iterable_of_tuples)
 
 
 class working_directory:

@@ -1,14 +1,19 @@
 from __future__ import division
 
 import unicodedata
-from typing import Tuple
+from typing import Tuple, Iterable
 
+from multimethod import multimethod
 import uniseg.wordbreak
 
 from .edit_distance import levenshtein
+from . import ExtractedText
 
 
-def words(s):
+@multimethod
+def words(s: str):
+    """Extract words from a string"""
+
     # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also
     # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt
     old_word_break = uniseg.wordbreak.word_break

@@ -41,17 +46,37 @@ def words(s):
             yield word
 
 
-def words_normalized(s):
+@multimethod
+def words(s: ExtractedText):
+    return words(s.text)
+
+
+@multimethod
+def words_normalized(s: str):
     return words(unicodedata.normalize('NFC', s))
 
 
-def word_error_rate_n(reference, compared) -> Tuple[float, int]:
-    if isinstance(reference, str):
-        reference_seq = list(words_normalized(reference))
-        compared_seq = list(words_normalized(compared))
-    else:
-        reference_seq = list(reference)
-        compared_seq = list(compared)
+@multimethod
+def words_normalized(s: ExtractedText):
+    return words_normalized(s.text)
+
+
+@multimethod
+def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
+    reference_seq = list(words_normalized(reference))
+    compared_seq = list(words_normalized(compared))
+    return word_error_rate_n(reference_seq, compared_seq)
+
+
+@multimethod
+def word_error_rate_n(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
+    return word_error_rate_n(reference.text, compared.text)
+
+
+@multimethod
+def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
+    reference_seq = list(reference)
+    compared_seq = list(compared)
 
     d = levenshtein(reference_seq, compared_seq)
     n = len(reference_seq)
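Note: word_error_rate_n() is now split into three overloads: str/str tokenizes with words_normalized() first, ExtractedText delegates via .text, and Iterable/Iterable compares already tokenized sequences. A rough sketch (example strings invented):

    from qurator.dinglehopper.word_error_rate import word_error_rate_n

    # str overload tokenizes, then falls through to the Iterable overload
    wer, n = word_error_rate_n('ein kleiner Test', 'ein kleiner Tost')   # 1 of 3 words differs

    # Iterable overload takes pre-tokenized word sequences directly
    assert word_error_rate_n(['ein', 'kleiner', 'Test'],
                             ['ein', 'kleiner', 'Tost']) == (wer, n)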

@@ -6,3 +6,5 @@ numpy
 colorama
 MarkupSafe
 ocrd >= 2.13.1
+attrs
+multimethod == 1.3  # latest version to officially support Python 3.5
