Merge branch 'master' into image-features

Gerber, Mike 2021-02-09 18:17:23 +01:00
commit c0902cdef5
12 changed files with 385 additions and 105 deletions

ocrd-tool.json

@@ -1,6 +1,6 @@
{
"git_url": "https://github.com/kba/ocrd_calamari",
"version": "0.0.3",
"git_url": "https://github.com/OCR-D/ocrd_calamari",
"version": "1.0.1",
"tools": {
"ocrd-calamari-recognize": {
"executable": "ocrd-calamari-recognize",
@@ -18,6 +18,10 @@
"OCR-D-OCR-CALAMARI"
],
"parameters": {
"checkpoint_dir": {
"description": "The directory containing calamari model files (*.ckpt.json). Uses all checkpoints in that directory",
"type": "string", "format": "file", "cacheable": true, "default": "qurator-gt4histocr-1.0"
},
"checkpoint": {
"description": "The calamari model files (*.ckpt.json)",
"type": "string", "format": "file", "cacheable": true
@@ -25,6 +29,18 @@
"voter": {
"description": "The voting algorithm to use",
"type": "string", "default": "confidence_voter_default_ctc"
},
"textequiv_level": {
"type": "string",
"enum": ["line", "word", "glyph"],
"default": "line",
"description": "Deepest PAGE XML hierarchy level to include TextEquiv results for"
},
"glyph_conf_cutoff": {
"type": "number",
"format": "float",
"default": 0.001,
"description": "Only include glyph alternatives with confidences above this threshold"
}
}
}
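
The new options compose as follows; a hypothetical parameter set for ocrd-calamari-recognize (values taken from the defaults above, purely illustrative):

# Sketch: one possible parameter set, per the tool descriptor above
params = {
    "checkpoint_dir": "qurator-gt4histocr-1.0",  # expanded to <dir>/*.ckpt.json
    "voter": "confidence_voter_default_ctc",
    "textequiv_level": "glyph",                  # also emit Word and Glyph elements
    "glyph_conf_cutoff": 0.001,                  # drop glyph alternatives below this confidence
}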

ocrd_calamari/recognize.py

@@ -1,33 +1,50 @@
from __future__ import absolute_import
import os
import itertools
from glob import glob
import numpy as np
from calamari_ocr import __version__ as calamari_version
from calamari_ocr.ocr import MultiPredictor
from calamari_ocr.ocr.voting import voter_from_proto
from calamari_ocr.proto import VoterParams
from ocrd import Processor
from ocrd_modelfactory import page_from_file
from ocrd_models.ocrd_page import to_xml
from ocrd_models.ocrd_page_generateds import TextEquivType
from ocrd_utils import getLogger, concat_padded, MIMETYPE_PAGE
from ocrd_models.ocrd_page import (
LabelType, LabelsType,
MetadataItemType,
TextEquivType,
WordType, GlyphType, CoordsType,
to_xml
)
from ocrd_utils import (
getLogger, concat_padded,
coordinates_for_segment, points_from_polygon, polygon_from_x0y0x1y1,
make_file_id, assert_file_grp_cardinality,
MIMETYPE_PAGE
)
from ocrd_calamari.config import OCRD_TOOL, TF_CPP_MIN_LOG_LEVEL
log = getLogger('processor.CalamariRecognize')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = TF_CPP_MIN_LOG_LEVEL
from tensorflow import __version__ as tensorflow_version
TOOL = 'ocrd-calamari-recognize'
class CalamariRecognize(Processor):
def __init__(self, *args, **kwargs):
kwargs['ocrd_tool'] = OCRD_TOOL['tools']['ocrd-calamari-recognize']
kwargs['version'] = OCRD_TOOL['version']
kwargs['ocrd_tool'] = OCRD_TOOL['tools'][TOOL]
kwargs['version'] = '%s (calamari %s, tensorflow %s)' % (OCRD_TOOL['version'], calamari_version, tensorflow_version)
super(CalamariRecognize, self).__init__(*args, **kwargs)
def _init_calamari(self):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = TF_CPP_MIN_LOG_LEVEL
if not self.parameter.get('checkpoint', None) and self.parameter.get('checkpoint_dir', None):
resolved = self.resolve_resource(self.parameter['checkpoint_dir'])
self.parameter['checkpoint'] = '%s/*.ckpt.json' % resolved
checkpoints = glob(self.parameter['checkpoint'])
self.predictor = MultiPredictor(checkpoints=checkpoints)
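
A minimal sketch of how checkpoint_dir turns into a list of checkpoints (the model path is an assumption; it mirrors the glob logic above):

# Sketch (hypothetical path): the resolved checkpoint_dir is globbed for *.ckpt.json
from glob import glob
resolved = "/data/calamari-models/qurator-gt4histocr-1.0"  # assumed resolved location
checkpoints = glob(resolved + "/*.ckpt.json")              # each checkpoint acts as one voting fold
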
@@ -43,16 +60,14 @@ class CalamariRecognize(Processor):
voter_params.type = VoterParams.Type.Value(self.parameter['voter'].upper())
self.voter = voter_from_proto(voter_params)
def _make_file_id(self, input_file, n):
file_id = input_file.ID.replace(self.input_file_grp, self.output_file_grp)
if file_id == input_file.ID:
file_id = concat_padded(self.output_file_grp, n)
return file_id
def process(self):
"""
Performs the recognition.
"""
log = getLogger('processor.CalamariRecognize')
assert_file_grp_cardinality(self.input_file_grp, 1)
assert_file_grp_cardinality(self.output_file_grp, 1)
self._init_calamari()
@@ -71,44 +86,169 @@ class CalamariRecognize(Processor):
textlines = region.get_TextLine()
log.info("About to recognize %i lines of region '%s'", len(textlines), region.id)
line_images_np = []
for line in textlines:
log.debug("Recognizing line '%s' in region '%s'", line.id, region.id)
line_image, line_coords = self.workspace.image_from_segment(
line, region_image, region_coords, feature_selector=self.features)
if ('binarized' not in line_coords['features'] and
'grayscale_normalized' not in line_coords['features'] and
self.input_channels == 1):
line_image, line_coords = self.workspace.image_from_segment(line, region_image, region_coords, feature_selector=self.features)
if ('binarized' not in line_coords['features'] and 'grayscale_normalized' not in line_coords['features'] and self.input_channels == 1):
# We cannot use a feature selector for this since we don't
# know whether the model expects (has been trained on)
# binarized or grayscale images; but raw images are likely
# always inadequate:
log.warning("Using raw image for line '%s' in region '%s'",
line.id, region.id)
line_image_np = np.array(line_image, dtype=np.uint8)
log.warning("Using raw image for line '%s' in region '%s'", line.id, region.id)
line_image = line_image if all(line_image.size) else [[0]]
line_image_np = np.array(line_image, dtype=np.uint8)
line_images_np.append(line_image_np)
raw_results_all = self.predictor.predict_raw(line_images_np, progress_bar=False)
for line, raw_results in zip(textlines, raw_results_all):
raw_results = list(self.predictor.predict_raw([line_image_np], progress_bar=False))[0]
for i, p in enumerate(raw_results):
p.prediction.id = "fold_{}".format(i)
prediction = self.voter.vote_prediction_result(raw_results)
prediction.id = "voted"
line_text = prediction.sentence
line_conf = prediction.avg_char_probability
# Build line text on our own
#
# Calamari does whitespace post-processing on prediction.sentence, while it does not do the same
# on prediction.positions. Do it on our own to have consistency.
#
# XXX Check Calamari's built-in post-processing on prediction.sentence
def _sort_chars(p):
"""Filter and sort chars of prediction p"""
chars = p.chars
chars = [c for c in chars if c.char] # XXX Note that omission probabilities are not normalized?!
chars = [c for c in chars if c.probability >= self.parameter['glyph_conf_cutoff']]
chars = sorted(chars, key=lambda k: k.probability, reverse=True)
return chars
def _drop_leading_spaces(positions):
return list(itertools.dropwhile(lambda p: _sort_chars(p)[0].char == " ", positions))
def _drop_trailing_spaces(positions):
return list(reversed(_drop_leading_spaces(reversed(positions))))
def _drop_double_spaces(positions):
def _drop_double_spaces_generator(positions):
last_was_space = False
for p in positions:
if p.chars[0].char == " ":
if not last_was_space:
yield p
last_was_space = True
else:
yield p
last_was_space = False
return list(_drop_double_spaces_generator(positions))
positions = prediction.positions
positions = _drop_leading_spaces(positions)
positions = _drop_trailing_spaces(positions)
positions = _drop_double_spaces(positions)
positions = list(positions)
line_text = ''.join(_sort_chars(p)[0].char for p in positions)
if line_text != prediction.sentence:
log.warning("Our own line text is not the same as Calamari's: '%s' != '%s'",
line_text, prediction.sentence)
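
To see the whitespace handling in isolation, a standalone sketch on plain characters (Calamari position objects replaced by single chars; not code from the repository):

# Standalone sketch: strip leading/trailing spaces and collapse space runs,
# analogous to the position-based normalization above
import itertools

def drop_leading_spaces(chars):
    return list(itertools.dropwhile(lambda c: c == " ", chars))

def drop_trailing_spaces(chars):
    return list(reversed(drop_leading_spaces(list(reversed(chars)))))

def drop_double_spaces(chars):
    result, last_was_space = [], False
    for c in chars:
        if not (c == " " and last_was_space):
            result.append(c)
        last_was_space = (c == " ")
    return result

chars = list("  Es  war einmal ")
chars = drop_double_spaces(drop_trailing_spaces(drop_leading_spaces(chars)))
print("".join(chars))  # -> "Es war einmal"
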
# Delete existing results
if line.get_TextEquiv():
log.warning("Line '%s' already contained text results", line.id)
line.set_TextEquiv([TextEquivType(Unicode=line_text, conf=line_conf)])
line.set_TextEquiv([])
if line.get_Word():
log.warning("Line '%s' already contained word segmentation", line.id)
line.set_Word([])
# Save line results
line_conf = prediction.avg_char_probability
line.set_TextEquiv([TextEquivType(Unicode=line_text, conf=line_conf)])
# Save word results
#
# Calamari OCR does not provide word positions, so we infer word positions from a. text segmentation
# and b. the glyph positions. This is necessary because the PAGE XML format enforces a strict
# hierarchy of lines > words > glyphs.
def _words(s):
"""Split words based on spaces and include spaces as 'words'"""
spaces = None
word = ''
for c in s:
if c == ' ' and spaces is True:
word += c
elif c != ' ' and spaces is False:
word += c
else:
if word:
yield word
word = c
spaces = (c == ' ')
yield word
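
The splitting behavior of _words can be sketched equivalently with itertools.groupby (a standalone illustration, not a drop-in replacement):

# Standalone sketch: split a line into alternating runs of spaces and non-spaces
from itertools import groupby

def words(s):
    return ["".join(g) for _, g in groupby(s, key=lambda c: c == " ")]

print(words("ab  cd "))  # -> ['ab', '  ', 'cd', ' ']
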
if self.parameter['textequiv_level'] in ['word', 'glyph']:
word_no = 0
i = 0
for word_text in _words(line_text):
word_length = len(word_text)
if not all(c == ' ' for c in word_text):
word_positions = positions[i:i+word_length]
word_start = word_positions[0].global_start
word_end = word_positions[-1].global_end
polygon = polygon_from_x0y0x1y1([word_start, 0, word_end, line_image.height])
points = points_from_polygon(coordinates_for_segment(polygon, None, line_coords))
# XXX Crop to line polygon?
word = WordType(id='%s_word%04d' % (line.id, word_no), Coords=CoordsType(points))
word.add_TextEquiv(TextEquivType(Unicode=word_text))
if self.parameter['textequiv_level'] == 'glyph':
for glyph_no, p in enumerate(word_positions):
glyph_start = p.global_start
glyph_end = p.global_end
polygon = polygon_from_x0y0x1y1([glyph_start, 0, glyph_end, line_image.height])
points = points_from_polygon(coordinates_for_segment(polygon, None, line_coords))
glyph = GlyphType(id='%s_glyph%04d' % (word.id, glyph_no), Coords=CoordsType(points))
# Add predictions (= TextEquivs)
char_index_start = 1 # Must start with 1, see https://ocr-d.github.io/page#multiple-textequivs
for char_index, char in enumerate(_sort_chars(p), start=char_index_start):
glyph.add_TextEquiv(TextEquivType(Unicode=char.char, index=char_index, conf=char.probability))
word.add_Glyph(glyph)
line.add_Word(word)
word_no += 1
i += word_length
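
The coordinate inference above spans a rectangle from the first glyph's global_start to the last glyph's global_end over the full line height; coordinates_for_segment then maps it into absolute page coordinates. A simplified numeric sketch with hypothetical positions (no PAGE transform applied):

# Hypothetical glyph positions in line-image pixels: (global_start, global_end)
word_positions = [(12, 20), (21, 29), (30, 41)]
line_height = 48
x0, x1 = word_positions[0][0], word_positions[-1][1]
word_box = [[x0, 0], [x1, 0], [x1, line_height], [x0, line_height]]  # full line height
print(word_box)
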
_page_update_higher_textequiv_levels('line', pcgts)
file_id = self._make_file_id(input_file, n)
# Add metadata about this operation and its runtime parameters:
metadata = pcgts.get_Metadata() # ensured by from_file()
metadata.add_MetadataItem(
MetadataItemType(type_="processingStep",
name=self.ocrd_tool['steps'][0],
value=TOOL,
Labels=[LabelsType(
externalModel="ocrd-tool",
externalId="parameters",
Label=[LabelType(type_=name, value=self.parameter[name])
for name in self.parameter.keys()])]))
file_id = make_file_id(input_file, self.output_file_grp)
pcgts.set_pcGtsId(file_id)
self.workspace.add_file(
ID=file_id,
file_grp=self.output_file_grp,
@@ -151,3 +291,5 @@ def _page_update_higher_textequiv_levels(level, pcgts):
else u'' for line in lines)
region.set_TextEquiv(
[TextEquivType(Unicode=region_unicode)]) # remove old
# vim:tw=120:
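
For completeness, the region-level aggregation in the last hunk reduces, at textequiv_level 'line', to joining each line's top TextEquiv (presumably with newlines, as the full function body is not shown here); a rough sketch:

# Rough sketch of region text aggregation from line texts (newline separator assumed)
line_texts = ["Erſtes Kapitel.", "Es war einmal ein König,"]
region_unicode = "\n".join(line_texts)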