From c79b73dcc8f31e8c27655e165a61413aca5fedb0 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Tue, 2 Jun 2026 20:26:42 +0200 Subject: [PATCH] =?UTF-8?q?cnn-rnn-ocr:=20move=20CTC=20decoder=20and=20str?= =?UTF-8?q?ing=20decoder=20to=20inference=20model=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ModelZoo: drop `num_to_char` and `characters` model types, also drop `_load_characters()` and `_load_num_to_char()` loaders - `ModelZoo.load_models()`: use Predictor for `ocr` models, too - `ModelZoo.load_model()`: delegate runtime/inference conversion of OCR models to `eynollah.training.models.cnn_rnn_ocr_model4inference` - `training.models`: add (purely functional) Keras layer `CTCDecoder` for inference on top of softmax output, but using TF backend function instead of (broken) `Keras.backend.ctc_decode()`, while switching to beam search (instead of greedy) and also returning decoded path probability - `training.models.cnn_rnn_ocr_model()` w/ `inference=True`: * add kwarg `characters_txt_file` for file path of character set * configure secondary tensor path on OCR graph for binarized input (additional input `image_bin`, averaging softmax outputs) * use new `CTCDecoder` layer and inverse `StringLookup` layer to decode from softmax output to tf.string; so inference models now have 2 inputs (RGB, binarized) and 2 outputs (text, prob) * since `np.dtype=object` cannot be handled by SharedMemory (as needed by Predictor queues), also replace tf.string by tf.uint8 arrays * use this for `training convert` for OCR models w/ `--rebuild` - `training.models.cnn_rnn_ocr_model4inference`: * new function which does the same but loads an existing OCR model in training configuration (i.e. without prior `inference=True`) * use this for `training convert` for OCR models w/o `--rebuild` --- src/eynollah/model_zoo/default_specs.py | 16 ----- src/eynollah/model_zoo/model_zoo.py | 42 ++--------- src/eynollah/training/convert.py | 12 ++-- src/eynollah/training/models.py | 93 ++++++++++++++++++++++++- 4 files changed, 100 insertions(+), 63 deletions(-) diff --git a/src/eynollah/model_zoo/default_specs.py b/src/eynollah/model_zoo/default_specs.py index dc725e4..170d944 100644 --- a/src/eynollah/model_zoo/default_specs.py +++ b/src/eynollah/model_zoo/default_specs.py @@ -208,22 +208,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([ type='Keras', ), - EynollahModelSpec( - category="num_to_char", - variant='', - filename="characters_org.txt", - dist_url=dist_url("ocr"), - type='decoder', - ), - - EynollahModelSpec( - category="characters", - variant='', - filename="characters_org.txt", - dist_url=dist_url("ocr"), - type='List[str]', - ), - EynollahModelSpec( category="ocr", variant='tr', diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 0a68203..e7d21aa 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -123,12 +123,8 @@ class EynollahModelZoo: model_category = model_category[:-8] load_kwargs["patched"] = True - if model_category == 'ocr': + if model_category == 'ocr' and model_variant == 'tr': model = self._load_ocr_model(variant=model_variant, device=device) - elif model_category == 'num_to_char': - model = self._load_num_to_char() - elif model_category == 'characters': - model = self._load_characters() elif model_category == 'trocr_processor': from transformers import TrOCRProcessor model_path = self.model_path(model_category, model_variant) @@ -232,9 +228,12 @@ class EynollahModelZoo: from tensorflow.keras.models import load_model from tensorflow.keras.models import Model as KerasModel + from ..training.models import cnn_rnn_ocr_model4inference + self._configure_tf_device(model_category, device=device) model = load_model(model_path, compile=False) + assert isinstance(model, KerasModel) # from ..patch_encoder import ( # wrap_layout_model_patched, @@ -249,15 +248,7 @@ class EynollahModelZoo: if model_category == 'ocr': # cnn-rnn-ocr task model may not be in inference mode, yet - try: - model.get_layer(name='ctc_loss') - except ValueError: - pass - else: - model = KerasModel( - model.get_layer(name="image").input, # type: ignore - model.get_layer(name="dense2").output, # type: ignore - ) + model = cnn_rnn_ocr_model4inference(model, model_path) model.make_predict_function() @@ -369,29 +360,6 @@ class EynollahModelZoo: return self.load_model('ocr', model_variant=variant, device=device) - def _load_characters(self) -> List[str]: - """ - Load encoding for OCR - """ - with open(self.model_path('num_to_char'), "r") as config_file: - return json.load(config_file) - - def _load_num_to_char(self) -> 'StringLookup': - """ - Load decoder for OCR - """ - os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 - from ocrd_utils import tf_disable_interactive_logs - tf_disable_interactive_logs() - - from tensorflow.keras.layers import StringLookup - - characters = self._load_characters() - # Mapping characters to integers. - char_to_num = StringLookup(vocabulary=characters, mask_token=None) - # Mapping integers back to original characters. - return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True) - def __str__(self): return tabulate( [ diff --git a/src/eynollah/training/convert.py b/src/eynollah/training/convert.py index dd4271f..140079e 100644 --- a/src/eynollah/training/convert.py +++ b/src/eynollah/training/convert.py @@ -2,6 +2,7 @@ import os from pathlib import Path from shutil import copy2 import logging +import json import click @@ -74,18 +75,13 @@ def convert_cli(rebuild, format_, in_, out): model = get_model(config, logging.root) model.load_weights(model_path).assert_existing_objects_matched().expect_partial() else: + from .models import cnn_rnn_ocr_model4inference + model = load_model(model_path, compile=False) if isinstance(model, KerasModel): # cnn-rnn-ocr task deviates between training and inference - try: - model.get_layer(name='ctc_loss') - except ValueError: - pass - else: - model = KerasModel( - model.get_layer(name='image').input, - model.get_layer(name='dense2').output) + model = cnn_rnn_ocr_model4inference(model, model_path) if format_ in ["hdf5", "keras", "tf"]: kwargs = {"save_format": {"hdf5": "h5"}.get(format_, format_)} diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index 83058ee..528c848 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -1,4 +1,5 @@ import os +import json os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 import tensorflow as tf @@ -23,6 +24,7 @@ from tensorflow.keras.layers import ( Reshape, UpSampling2D, ZeroPadding2D, + StringLookup, add, concatenate ) @@ -57,6 +59,50 @@ class CTCLayer(Layer): # At test time, just return the computed predictions. return y_pred + +class CTCDecoder(Layer): + def call(self, inputs): + n_samples = tf.shape(inputs)[0] + n_steps = inputs.shape[1] + n_classes = inputs.shape[2] + lengths = tf.ones(n_samples, dtype=tf.int32) * n_steps + ## Keras beam search seems to mess with double letters + ## but Keras greedy sometimes removes arbitrary letters + # outputs, logits = tf.keras.backend.ctc_decode(inputs, + # lengths, + # beam_width=20 + # greedy=False, # True, + # # backend does not allow these kwargs + # #merge_repeated=False, + # #mask_index=inputs.shape[2]-1, + # ) + # tf.nn.ctc_*_decoder (in contrast to tf.keras.backend.ctc_decode) + # needs logits instead of probs and time-major (batch 2nd dim) + inputs = tf.math.log( + tf.transpose(inputs, perm=[1, 0, 2]) + tf.keras.backend.epsilon() + ) + # tf.nn.ctc_greedy_decoder() is not as precise + # tf.compat.v1.nn.ctc_beam_search_decoder() also needs merge_repeated=False + decoded, logits = tf.nn.ctc_beam_search_decoder( + inputs, + lengths, + beam_width=10, + top_paths=2, + ) + # get top path for all sequences in batch + decoded = decoded[0] + logits = logits[:, 0] - logits[:, 1] + # convert to dense + outputs = tf.SparseTensor(decoded.indices, decoded.values, + (n_samples, n_steps)) + outputs = tf.sparse.to_dense(sp_input=outputs, default_value=-1) + # # drop non-tokens (-1) and OOV (0) + # result = [] + # for output in outputs: + # result.append(tf.gather(output, tf.where(output > 0))) + # outputs = tf.stack(result) + probs = tf.exp(-logits) + return outputs, probs def mlp(x, hidden_units, dropout_rate): for units in hidden_units: @@ -422,7 +468,7 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224 return model -def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False): +def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False, characters_txt_file=None): inputs = Input(shape=(image_height, image_width, 3), name="image") labels = Input(name="label", shape=(None,)) @@ -492,13 +538,56 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l out = Dense(n_classes, activation="softmax", name="dense2")(out) if inference: - return Model(inputs, out) + # add second path for binarization + inputs_bin = Input(shape=(image_height, image_width, 3), name="image_bin") + out_bin = Model(inputs, out)(inputs_bin) + # ensemble raw results + out = 0.5 * (out + out_bin) + # get tf.string batch + out, prob = CTCDecoder()(out) + # decode int to str + with open(characters_txt_file, "r") as voc_file: + voc = json.load(voc_file) + char2num = StringLookup(vocabulary=voc) + voc = char2num.get_vocabulary() + num2char = StringLookup(vocabulary=voc, invert=True) + output = num2char(out) + # avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed) + output = tf.io.decode_raw(output, tf.uint8, fixed_length=max(map(len, voc))) + + return Model((inputs, inputs_bin), (output, prob)) # Add CTC layer for calculating CTC loss at each step. out = CTCLayer(name="ctc_loss")(labels, out) return Model((inputs, labels), out) +def cnn_rnn_ocr_model4inference(model, model_path): + """convert trained cnn-rnn-ocr model to inference model post-hoc""" + try: + model.get_layer(name='ctc_loss') + except ValueError: + # likely already converted + return model + else: + inputs = model.get_layer(name='image').input + output = model.get_layer(name='dense2').output + inputs_bin = Input(inputs.shape[1:], name='image_bin') + output_bin = Model(inputs, output)(inputs_bin) + output = 0.5 * (output + output_bin) + output, prob = CTCDecoder()(output) + with open(model_path / "characters_org.txt", "r") as voc_file: + voc = json.load(voc_file) + char2num = StringLookup(vocabulary=voc) + voc = char2num.get_vocabulary() + num2char = StringLookup(vocabulary=voc, invert=True) + output = num2char(output) + # avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed) + output = tf.io.decode_raw(output, tf.uint8, fixed_length=max(map(len, voc))) + inputs = (inputs, inputs_bin) + outputs = (output, prob) + return Model(inputs, outputs) + def get_model(config, logger): from sacred.config import create_captured_function