mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training.models: simplify CTC loss layer
This commit is contained in:
parent
92fc2bd815
commit
b399db3c00
1 changed file with 3 additions and 6 deletions
|
|
@@ -29,6 +29,7 @@ from tensorflow.keras.layers import (
|
||||||
)
|
)
|
||||||
from tensorflow.keras.models import Model
|
from tensorflow.keras.models import Model
|
||||||
from tensorflow.keras.regularizers import l2
|
from tensorflow.keras.regularizers import l2
|
||||||
|
from tensorflow.keras.backend import ctc_batch_cost
|
||||||
|
|
||||||
from ..patch_encoder import Patches, PatchEncoder
|
from ..patch_encoder import Patches, PatchEncoder
|
||||||
|
|
||||||
|
|
@@ -45,10 +46,6 @@ MERGE_AXIS = -1
|
||||||
|
|
||||||
|
|
||||||
class CTCLayer(Layer):
|
class CTCLayer(Layer):
|
||||||
def __init__(self, name=None):
|
|
||||||
super().__init__(name=name)
|
|
||||||
self.loss_fn = tf.keras.backend.ctc_batch_cost
|
|
||||||
|
|
||||||
def call(self, y_true, y_pred):
|
def call(self, y_true, y_pred):
|
||||||
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
|
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
|
||||||
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
|
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
|
||||||
|
|
@@ -56,7 +53,7 @@ class CTCLayer(Layer):
|
||||||
|
|
||||||
input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
|
input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
|
||||||
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
|
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
|
||||||
loss = self.loss_fn(y_true, y_pred, input_length, label_length)
|
loss = ctc_batch_cost(y_true, y_pred, input_length, label_length)
|
||||||
self.add_loss(loss)
|
self.add_loss(loss)
|
||||||
|
|
||||||
# At test time, just return the computed predictions.
|
# At test time, just return the computed predictions.
|
||||||
|
|
@@ -505,6 +502,6 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
|
||||||
# Add CTC layer for calculating CTC loss at each step.
|
# Add CTC layer for calculating CTC loss at each step.
|
||||||
output = CTCLayer(name="ctc_loss")(labels, out)
|
output = CTCLayer(name="ctc_loss")(labels, out)
|
||||||
|
|
||||||
model = Model(inputs=[input_img, labels], outputs=output, name="handwriting_recognizer")
|
model = Model(inputs=(input_img, labels), outputs=output, name="handwriting_recognizer")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue