cnn-rnn-ocr: move CTC decoder and string decoder to inference model…

- ModelZoo: drop `num_to_char` and `characters` model types,
  also drop `_load_characters()` and `_load_num_to_char()` loaders
- `ModelZoo.load_models()`: use Predictor for `ocr` models, too
- `ModelZoo.load_model()`: delegate runtime/inference conversion of
  OCR models to `eynollah.training.models.cnn_rnn_ocr_model4inference`
- `training.models`: add (purely functional) Keras layer `CTCDecoder`
  for inference on top of softmax output, but using TF backend
  function instead of (broken) `Keras.backend.ctc_decode()`, while
  switching to beam search (instead of greedy) and also returning
  decoded path probability
- `training.models.cnn_rnn_ocr_model()` w/ `inference=True`:
  * add kwarg `characters_txt_file` for file path of character set
  * configure secondary tensor path on OCR graph for binarized input
    (additional input `image_bin`, averaging softmax outputs)
  * use new `CTCDecoder` layer and inverse `StringLookup` layer to
    decode from softmax output to tf.string; so inference models
    now have 2 inputs (RGB, binarized) and 2 outputs (text, prob)
  * since `np.dtype=object` cannot be handled by SharedMemory (as
    needed by Predictor queues), also replace tf.string by tf.uint8
    arrays
  * use this for `training convert` for OCR models w/ `--rebuild`
- `training.models.cnn_rnn_ocr_model4inference`:
  * new function which does the same but loads an existing OCR model
    in training configuration (i.e. without prior `inference=True`)
  * use this for `training convert` for OCR models w/o `--rebuild`
This commit is contained in:
Robert Sachunsky 2026-06-02 20:26:42 +02:00
parent 13f2f81c45
commit c79b73dcc8
4 changed files with 100 additions and 63 deletions

View file

@ -208,22 +208,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
type='Keras', type='Keras',
), ),
EynollahModelSpec(
category="num_to_char",
variant='',
filename="characters_org.txt",
dist_url=dist_url("ocr"),
type='decoder',
),
EynollahModelSpec(
category="characters",
variant='',
filename="characters_org.txt",
dist_url=dist_url("ocr"),
type='List[str]',
),
EynollahModelSpec( EynollahModelSpec(
category="ocr", category="ocr",
variant='tr', variant='tr',

View file

@ -123,12 +123,8 @@ class EynollahModelZoo:
model_category = model_category[:-8] model_category = model_category[:-8]
load_kwargs["patched"] = True load_kwargs["patched"] = True
if model_category == 'ocr': if model_category == 'ocr' and model_variant == 'tr':
model = self._load_ocr_model(variant=model_variant, device=device) model = self._load_ocr_model(variant=model_variant, device=device)
elif model_category == 'num_to_char':
model = self._load_num_to_char()
elif model_category == 'characters':
model = self._load_characters()
elif model_category == 'trocr_processor': elif model_category == 'trocr_processor':
from transformers import TrOCRProcessor from transformers import TrOCRProcessor
model_path = self.model_path(model_category, model_variant) model_path = self.model_path(model_category, model_variant)
@ -232,9 +228,12 @@ class EynollahModelZoo:
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
from tensorflow.keras.models import Model as KerasModel from tensorflow.keras.models import Model as KerasModel
from ..training.models import cnn_rnn_ocr_model4inference
self._configure_tf_device(model_category, device=device) self._configure_tf_device(model_category, device=device)
model = load_model(model_path, compile=False) model = load_model(model_path, compile=False)
assert isinstance(model, KerasModel)
# from ..patch_encoder import ( # from ..patch_encoder import (
# wrap_layout_model_patched, # wrap_layout_model_patched,
@ -249,15 +248,7 @@ class EynollahModelZoo:
if model_category == 'ocr': if model_category == 'ocr':
# cnn-rnn-ocr task model may not be in inference mode, yet # cnn-rnn-ocr task model may not be in inference mode, yet
try: model = cnn_rnn_ocr_model4inference(model, model_path)
model.get_layer(name='ctc_loss')
except ValueError:
pass
else:
model = KerasModel(
model.get_layer(name="image").input, # type: ignore
model.get_layer(name="dense2").output, # type: ignore
)
model.make_predict_function() model.make_predict_function()
@ -369,29 +360,6 @@ class EynollahModelZoo:
return self.load_model('ocr', model_variant=variant, device=device) return self.load_model('ocr', model_variant=variant, device=device)
def _load_characters(self) -> List[str]:
    """
    Load the character set (encoding) for OCR.

    Reads the file registered under the ``num_to_char`` model category
    and parses it as JSON.

    Returns:
        The parsed JSON content (annotated as a list of characters).
    """
    # JSON text is UTF-8; do not depend on the locale's default encoding,
    # which would garble non-ASCII characters on some platforms.
    with open(self.model_path('num_to_char'), "r", encoding="utf-8") as config_file:
        return json.load(config_file)
def _load_num_to_char(self) -> 'StringLookup':
    """
    Build the OCR decoder: a lookup layer mapping integers to characters.

    Returns:
        An inverted ``StringLookup`` layer whose vocabulary is the one
        reported by a forward ``StringLookup`` built from the character set.
    """
    # set ahead of the TensorFlow import below: avoid Keras 3 after TF 2.15
    os.environ['TF_USE_LEGACY_KERAS'] = '1'
    from ocrd_utils import tf_disable_interactive_logs
    tf_disable_interactive_logs()
    from tensorflow.keras.layers import StringLookup

    # Forward mapping (characters to integers), only needed to obtain
    # the vocabulary exactly as the lookup layer reports it.
    encoder = StringLookup(vocabulary=self._load_characters(), mask_token=None)
    vocabulary = encoder.get_vocabulary()
    # Inverse mapping (integers back to original characters).
    return StringLookup(vocabulary=vocabulary, mask_token=None, invert=True)
def __str__(self): def __str__(self):
return tabulate( return tabulate(
[ [

View file

@ -2,6 +2,7 @@ import os
from pathlib import Path from pathlib import Path
from shutil import copy2 from shutil import copy2
import logging import logging
import json
import click import click
@ -74,18 +75,13 @@ def convert_cli(rebuild, format_, in_, out):
model = get_model(config, logging.root) model = get_model(config, logging.root)
model.load_weights(model_path).assert_existing_objects_matched().expect_partial() model.load_weights(model_path).assert_existing_objects_matched().expect_partial()
else: else:
from .models import cnn_rnn_ocr_model4inference
model = load_model(model_path, compile=False) model = load_model(model_path, compile=False)
if isinstance(model, KerasModel): if isinstance(model, KerasModel):
# cnn-rnn-ocr task deviates between training and inference # cnn-rnn-ocr task deviates between training and inference
try: model = cnn_rnn_ocr_model4inference(model, model_path)
model.get_layer(name='ctc_loss')
except ValueError:
pass
else:
model = KerasModel(
model.get_layer(name='image').input,
model.get_layer(name='dense2').output)
if format_ in ["hdf5", "keras", "tf"]: if format_ in ["hdf5", "keras", "tf"]:
kwargs = {"save_format": {"hdf5": "h5"}.get(format_, format_)} kwargs = {"save_format": {"hdf5": "h5"}.get(format_, format_)}

View file

@ -1,4 +1,5 @@
import os import os
import json
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf import tensorflow as tf
@ -23,6 +24,7 @@ from tensorflow.keras.layers import (
Reshape, Reshape,
UpSampling2D, UpSampling2D,
ZeroPadding2D, ZeroPadding2D,
StringLookup,
add, add,
concatenate concatenate
) )
@ -57,6 +59,50 @@ class CTCLayer(Layer):
# At test time, just return the computed predictions. # At test time, just return the computed predictions.
return y_pred return y_pred
class CTCDecoder(Layer):
    """
    Keras layer decoding softmax output into label sequences by CTC beam search.

    Uses ``tf.nn.ctc_beam_search_decoder`` instead of
    ``tf.keras.backend.ctc_decode``: the Keras beam search seems to mess
    with double letters, the Keras greedy search sometimes removes arbitrary
    letters, and the backend function does not accept the ``merge_repeated``
    or ``mask_index`` kwargs.
    """

    def call(self, inputs):
        """
        Decode a batch of per-step class scores.

        Args:
            inputs: batch-major tensor indexed as (sample, step, class);
                presumably the softmax output of the OCR model (TODO confirm).
                The number of steps is taken from the static shape.

        Returns:
            Tuple ``(outputs, probs)``:
            ``outputs`` is the dense (n_samples, n_steps) tensor of labels of
            the top path, filled with -1 where the path has no label (these
            fill values are not removed here);
            ``probs`` is ``exp(-(s0 - s1))`` where ``s0`` and ``s1`` are the
            decoder's scores for the best and second-best path.
            NOTE(review): whether this is a calibrated probability depends on
            the scale of the decoder's scores; confirm with consumers.
        """
        n_samples = tf.shape(inputs)[0]
        n_steps = inputs.shape[1]
        # every sequence in the batch spans the full number of steps
        lengths = tf.ones(n_samples, dtype=tf.int32) * n_steps
        # tf.nn.ctc_*_decoder (in contrast to tf.keras.backend.ctc_decode)
        # needs logits instead of probs and time-major (batch 2nd dim)
        inputs = tf.math.log(
            tf.transpose(inputs, perm=[1, 0, 2]) + tf.keras.backend.epsilon()
        )
        # tf.nn.ctc_greedy_decoder() is not as precise
        # tf.compat.v1.nn.ctc_beam_search_decoder() also needs merge_repeated=False
        decoded, logits = tf.nn.ctc_beam_search_decoder(
            inputs,
            lengths,
            beam_width=10,
            top_paths=2,
        )
        # get top path for all sequences in batch
        decoded = decoded[0]
        # score margin between best and second-best path
        logits = logits[:, 0] - logits[:, 1]
        # convert to dense
        outputs = tf.SparseTensor(decoded.indices, decoded.values,
                                  (n_samples, n_steps))
        outputs = tf.sparse.to_dense(sp_input=outputs, default_value=-1)
        probs = tf.exp(-logits)
        return outputs, probs
def mlp(x, hidden_units, dropout_rate): def mlp(x, hidden_units, dropout_rate):
for units in hidden_units: for units in hidden_units:
@ -422,7 +468,7 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
return model return model
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False): def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False, characters_txt_file=None):
inputs = Input(shape=(image_height, image_width, 3), name="image") inputs = Input(shape=(image_height, image_width, 3), name="image")
labels = Input(name="label", shape=(None,)) labels = Input(name="label", shape=(None,))
@ -492,13 +538,56 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
out = Dense(n_classes, activation="softmax", name="dense2")(out) out = Dense(n_classes, activation="softmax", name="dense2")(out)
if inference: if inference:
return Model(inputs, out) # add second path for binarization
inputs_bin = Input(shape=(image_height, image_width, 3), name="image_bin")
out_bin = Model(inputs, out)(inputs_bin)
# ensemble raw results
out = 0.5 * (out + out_bin)
# get tf.string batch
out, prob = CTCDecoder()(out)
# decode int to str
with open(characters_txt_file, "r") as voc_file:
voc = json.load(voc_file)
char2num = StringLookup(vocabulary=voc)
voc = char2num.get_vocabulary()
num2char = StringLookup(vocabulary=voc, invert=True)
output = num2char(out)
# avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed)
output = tf.io.decode_raw(output, tf.uint8, fixed_length=max(map(len, voc)))
return Model((inputs, inputs_bin), (output, prob))
# Add CTC layer for calculating CTC loss at each step. # Add CTC layer for calculating CTC loss at each step.
out = CTCLayer(name="ctc_loss")(labels, out) out = CTCLayer(name="ctc_loss")(labels, out)
return Model((inputs, labels), out) return Model((inputs, labels), out)
def cnn_rnn_ocr_model4inference(model, model_path):
    """
    Convert a trained cnn-rnn-ocr model to an inference model post-hoc.

    If ``model`` has no layer named ``ctc_loss``, it is assumed to be
    converted already and is returned unchanged. Otherwise a new model is
    built on top of the ``image`` input and ``dense2`` output which
    - adds a second input ``image_bin`` run through the same network,
    - averages both outputs and decodes them with ``CTCDecoder``,
    - maps the decoded integers to strings via an inverted ``StringLookup``
      using the vocabulary in ``characters_org.txt``,
    - converts the strings to fixed-length ``tf.uint8`` arrays.

    Args:
        model: loaded Keras model (training or inference configuration).
        model_path: directory (``str`` or path-like) which contains
            ``characters_org.txt``, a JSON file with the character set.

    Returns:
        The unchanged ``model`` if already converted, else a model with
        inputs ``(image, image_bin)`` and outputs ``(text bytes, prob)``.

    Raises:
        OSError: if ``characters_org.txt`` cannot be opened.
    """
    try:
        model.get_layer(name='ctc_loss')
    except ValueError:
        # no CTC loss layer: likely already converted
        return model

    inputs = model.get_layer(name='image').input
    output = model.get_layer(name='dense2').output
    # second path through the same network for the binarized image
    inputs_bin = Input(inputs.shape[1:], name='image_bin')
    output_bin = Model(inputs, output)(inputs_bin)
    # ensemble raw results
    output = 0.5 * (output + output_bin)
    output, prob = CTCDecoder()(output)
    # os.path.join accepts both str and path-like objects
    voc_path = os.path.join(model_path, "characters_org.txt")
    # JSON text is UTF-8; do not depend on the locale's default encoding
    with open(voc_path, "r", encoding="utf-8") as voc_file:
        voc = json.load(voc_file)
    # forward lookup only serves to obtain the vocabulary as the layer reports it
    char2num = StringLookup(vocabulary=voc)
    voc = char2num.get_vocabulary()
    num2char = StringLookup(vocabulary=voc, invert=True)
    output = num2char(output)
    # avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed);
    # decode_raw counts bytes, so size the buffer by encoded length
    # (character count would truncate long multi-byte tokens)
    max_bytes = max(len(token.encode("utf-8")) for token in voc)
    output = tf.io.decode_raw(output, tf.uint8, fixed_length=max_bytes)
    return Model((inputs, inputs_bin), (output, prob))
def get_model(config, logger): def get_model(config, logger):
from sacred.config import create_captured_function from sacred.config import create_captured_function