mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
cnn-rnn-ocr: move CTC decoder and string decoder to inference model…
- ModelZoo: drop `num_to_char` and `characters` model types,
also drop `_load_characters()` and `_load_num_to_char()` loaders
- `ModelZoo.load_models()`: use Predictor for `ocr` models, too
- `ModelZoo.load_model()`: delegate runtime/inference conversion of
OCR models to `eynollah.training.models.cnn_rnn_ocr_model4inference`
- `training.models`: add (purely functional) Keras layer `CTCDecoder`
for inference on top of softmax output, but using TF backend
function instead of (broken) `Keras.backend.ctc_decode()`, while
switching to beam search (instead of greedy) and also returning
decoded path probability
- `training.models.cnn_rnn_ocr_model()` w/ `inference=True`:
* add kwarg `characters_txt_file` for file path of character set
* configure secondary tensor path on OCR graph for binarized input
(additional input `image_bin`, averaging softmax outputs)
* use new `CTCDecoder` layer and inverse `StringLookup` layer to
decode from softmax output to tf.string; so inference models
now have 2 inputs (RGB, binarized) and 2 outputs (text, prob)
* since `np.dtype=object` cannot be handled by SharedMemory (as
needed by Predictor queues), also replace tf.string by tf.uint8
arrays
* use this for `training convert` for OCR models w/ `--rebuild`
- `training.models.cnn_rnn_ocr_model4inference`:
* new function which does the same but loads an existing OCR model
in training configuration (i.e. without prior `inference=True`)
* use this for `training convert` for OCR models w/o `--rebuild`
This commit is contained in:
parent
13f2f81c45
commit
c79b73dcc8
4 changed files with 100 additions and 63 deletions
|
|
@ -208,22 +208,6 @@ DEFAULT_MODEL_SPECS = EynollahModelSpecSet([
|
|||
type='Keras',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="num_to_char",
|
||||
variant='',
|
||||
filename="characters_org.txt",
|
||||
dist_url=dist_url("ocr"),
|
||||
type='decoder',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="characters",
|
||||
variant='',
|
||||
filename="characters_org.txt",
|
||||
dist_url=dist_url("ocr"),
|
||||
type='List[str]',
|
||||
),
|
||||
|
||||
EynollahModelSpec(
|
||||
category="ocr",
|
||||
variant='tr',
|
||||
|
|
|
|||
|
|
@ -123,12 +123,8 @@ class EynollahModelZoo:
|
|||
model_category = model_category[:-8]
|
||||
load_kwargs["patched"] = True
|
||||
|
||||
if model_category == 'ocr':
|
||||
if model_category == 'ocr' and model_variant == 'tr':
|
||||
model = self._load_ocr_model(variant=model_variant, device=device)
|
||||
elif model_category == 'num_to_char':
|
||||
model = self._load_num_to_char()
|
||||
elif model_category == 'characters':
|
||||
model = self._load_characters()
|
||||
elif model_category == 'trocr_processor':
|
||||
from transformers import TrOCRProcessor
|
||||
model_path = self.model_path(model_category, model_variant)
|
||||
|
|
@ -232,9 +228,12 @@ class EynollahModelZoo:
|
|||
from tensorflow.keras.models import load_model
|
||||
from tensorflow.keras.models import Model as KerasModel
|
||||
|
||||
from ..training.models import cnn_rnn_ocr_model4inference
|
||||
|
||||
self._configure_tf_device(model_category, device=device)
|
||||
|
||||
model = load_model(model_path, compile=False)
|
||||
assert isinstance(model, KerasModel)
|
||||
|
||||
# from ..patch_encoder import (
|
||||
# wrap_layout_model_patched,
|
||||
|
|
@ -249,15 +248,7 @@ class EynollahModelZoo:
|
|||
|
||||
if model_category == 'ocr':
|
||||
# cnn-rnn-ocr task model may not be in inference mode, yet
|
||||
try:
|
||||
model.get_layer(name='ctc_loss')
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
model = KerasModel(
|
||||
model.get_layer(name="image").input, # type: ignore
|
||||
model.get_layer(name="dense2").output, # type: ignore
|
||||
)
|
||||
model = cnn_rnn_ocr_model4inference(model, model_path)
|
||||
|
||||
model.make_predict_function()
|
||||
|
||||
|
|
@ -369,29 +360,6 @@ class EynollahModelZoo:
|
|||
|
||||
return self.load_model('ocr', model_variant=variant, device=device)
|
||||
|
||||
def _load_characters(self) -> List[str]:
|
||||
"""
|
||||
Load encoding for OCR
|
||||
"""
|
||||
with open(self.model_path('num_to_char'), "r") as config_file:
|
||||
return json.load(config_file)
|
||||
|
||||
def _load_num_to_char(self) -> 'StringLookup':
|
||||
"""
|
||||
Load decoder for OCR
|
||||
"""
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
tf_disable_interactive_logs()
|
||||
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
|
||||
characters = self._load_characters()
|
||||
# Mapping characters to integers.
|
||||
char_to_num = StringLookup(vocabulary=characters, mask_token=None)
|
||||
# Mapping integers back to original characters.
|
||||
return StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)
|
||||
|
||||
def __str__(self):
|
||||
return tabulate(
|
||||
[
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import os
|
|||
from pathlib import Path
|
||||
from shutil import copy2
|
||||
import logging
|
||||
import json
|
||||
|
||||
import click
|
||||
|
||||
|
|
@ -74,18 +75,13 @@ def convert_cli(rebuild, format_, in_, out):
|
|||
model = get_model(config, logging.root)
|
||||
model.load_weights(model_path).assert_existing_objects_matched().expect_partial()
|
||||
else:
|
||||
from .models import cnn_rnn_ocr_model4inference
|
||||
|
||||
model = load_model(model_path, compile=False)
|
||||
|
||||
if isinstance(model, KerasModel):
|
||||
# cnn-rnn-ocr task deviates between training and inference
|
||||
try:
|
||||
model.get_layer(name='ctc_loss')
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
model = KerasModel(
|
||||
model.get_layer(name='image').input,
|
||||
model.get_layer(name='dense2').output)
|
||||
model = cnn_rnn_ocr_model4inference(model, model_path)
|
||||
|
||||
if format_ in ["hdf5", "keras", "tf"]:
|
||||
kwargs = {"save_format": {"hdf5": "h5"}.get(format_, format_)}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import json
|
||||
|
||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||
import tensorflow as tf
|
||||
|
|
@ -23,6 +24,7 @@ from tensorflow.keras.layers import (
|
|||
Reshape,
|
||||
UpSampling2D,
|
||||
ZeroPadding2D,
|
||||
StringLookup,
|
||||
add,
|
||||
concatenate
|
||||
)
|
||||
|
|
@ -57,6 +59,50 @@ class CTCLayer(Layer):
|
|||
|
||||
# At test time, just return the computed predictions.
|
||||
return y_pred
|
||||
|
||||
class CTCDecoder(Layer):
|
||||
def call(self, inputs):
|
||||
n_samples = tf.shape(inputs)[0]
|
||||
n_steps = inputs.shape[1]
|
||||
n_classes = inputs.shape[2]
|
||||
lengths = tf.ones(n_samples, dtype=tf.int32) * n_steps
|
||||
## Keras beam search seems to mess with double letters
|
||||
## but Keras greedy sometimes removes arbitrary letters
|
||||
# outputs, logits = tf.keras.backend.ctc_decode(inputs,
|
||||
# lengths,
|
||||
# beam_width=20
|
||||
# greedy=False, # True,
|
||||
# # backend does not allow these kwargs
|
||||
# #merge_repeated=False,
|
||||
# #mask_index=inputs.shape[2]-1,
|
||||
# )
|
||||
# tf.nn.ctc_*_decoder (in contrast to tf.keras.backend.ctc_decode)
|
||||
# needs logits instead of probs and time-major (batch 2nd dim)
|
||||
inputs = tf.math.log(
|
||||
tf.transpose(inputs, perm=[1, 0, 2]) + tf.keras.backend.epsilon()
|
||||
)
|
||||
# tf.nn.ctc_greedy_decoder() is not as precise
|
||||
# tf.compat.v1.nn.ctc_beam_search_decoder() also needs merge_repeated=False
|
||||
decoded, logits = tf.nn.ctc_beam_search_decoder(
|
||||
inputs,
|
||||
lengths,
|
||||
beam_width=10,
|
||||
top_paths=2,
|
||||
)
|
||||
# get top path for all sequences in batch
|
||||
decoded = decoded[0]
|
||||
logits = logits[:, 0] - logits[:, 1]
|
||||
# convert to dense
|
||||
outputs = tf.SparseTensor(decoded.indices, decoded.values,
|
||||
(n_samples, n_steps))
|
||||
outputs = tf.sparse.to_dense(sp_input=outputs, default_value=-1)
|
||||
# # drop non-tokens (-1) and OOV (0)
|
||||
# result = []
|
||||
# for output in outputs:
|
||||
# result.append(tf.gather(output, tf.where(output > 0)))
|
||||
# outputs = tf.stack(result)
|
||||
probs = tf.exp(-logits)
|
||||
return outputs, probs
|
||||
|
||||
def mlp(x, hidden_units, dropout_rate):
|
||||
for units in hidden_units:
|
||||
|
|
@ -422,7 +468,7 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
|||
|
||||
return model
|
||||
|
||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False):
|
||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False, characters_txt_file=None):
|
||||
inputs = Input(shape=(image_height, image_width, 3), name="image")
|
||||
labels = Input(name="label", shape=(None,))
|
||||
|
||||
|
|
@ -492,13 +538,56 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
|||
out = Dense(n_classes, activation="softmax", name="dense2")(out)
|
||||
|
||||
if inference:
|
||||
return Model(inputs, out)
|
||||
# add second path for binarization
|
||||
inputs_bin = Input(shape=(image_height, image_width, 3), name="image_bin")
|
||||
out_bin = Model(inputs, out)(inputs_bin)
|
||||
# ensemble raw results
|
||||
out = 0.5 * (out + out_bin)
|
||||
# get tf.string batch
|
||||
out, prob = CTCDecoder()(out)
|
||||
# decode int to str
|
||||
with open(characters_txt_file, "r") as voc_file:
|
||||
voc = json.load(voc_file)
|
||||
char2num = StringLookup(vocabulary=voc)
|
||||
voc = char2num.get_vocabulary()
|
||||
num2char = StringLookup(vocabulary=voc, invert=True)
|
||||
output = num2char(out)
|
||||
# avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed)
|
||||
output = tf.io.decode_raw(output, tf.uint8, fixed_length=max(map(len, voc)))
|
||||
|
||||
return Model((inputs, inputs_bin), (output, prob))
|
||||
|
||||
# Add CTC layer for calculating CTC loss at each step.
|
||||
out = CTCLayer(name="ctc_loss")(labels, out)
|
||||
|
||||
return Model((inputs, labels), out)
|
||||
|
||||
def cnn_rnn_ocr_model4inference(model, model_path):
|
||||
"""convert trained cnn-rnn-ocr model to inference model post-hoc"""
|
||||
try:
|
||||
model.get_layer(name='ctc_loss')
|
||||
except ValueError:
|
||||
# likely already converted
|
||||
return model
|
||||
else:
|
||||
inputs = model.get_layer(name='image').input
|
||||
output = model.get_layer(name='dense2').output
|
||||
inputs_bin = Input(inputs.shape[1:], name='image_bin')
|
||||
output_bin = Model(inputs, output)(inputs_bin)
|
||||
output = 0.5 * (output + output_bin)
|
||||
output, prob = CTCDecoder()(output)
|
||||
with open(model_path / "characters_org.txt", "r") as voc_file:
|
||||
voc = json.load(voc_file)
|
||||
char2num = StringLookup(vocabulary=voc)
|
||||
voc = char2num.get_vocabulary()
|
||||
num2char = StringLookup(vocabulary=voc, invert=True)
|
||||
output = num2char(output)
|
||||
# avoid output tf.dtype=string → np.dtype=object (which cannot be shm-ed)
|
||||
output = tf.io.decode_raw(output, tf.uint8, fixed_length=max(map(len, voc)))
|
||||
inputs = (inputs, inputs_bin)
|
||||
outputs = (output, prob)
|
||||
return Model(inputs, outputs)
|
||||
|
||||
def get_model(config, logger):
|
||||
from sacred.config import create_captured_function
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue