diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py
index ba61764..4652b07 100644
--- a/src/eynollah/training/models.py
+++ b/src/eynollah/training/models.py
@@ -29,6 +29,7 @@ from tensorflow.keras.layers import (
 )
 from tensorflow.keras.models import Model
 from tensorflow.keras.regularizers import l2
+from tensorflow.keras.backend import ctc_batch_cost
 
 from ..patch_encoder import Patches, PatchEncoder
 
@@ -45,10 +46,6 @@ MERGE_AXIS = -1
 
 
 class CTCLayer(Layer):
-    def __init__(self, name=None):
-        super().__init__(name=name)
-        self.loss_fn = tf.keras.backend.ctc_batch_cost
-
     def call(self, y_true, y_pred):
         batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
         input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
@@ -56,7 +53,7 @@ class CTCLayer(Layer):
         input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
         label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
 
-        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
+        loss = ctc_batch_cost(y_true, y_pred, input_length, label_length)
         self.add_loss(loss)
 
         # At test time, just return the computed predictions.
@@ -505,6 +502,6 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
     # Add CTC layer for calculating CTC loss at each step.
     output = CTCLayer(name="ctc_loss")(labels, out)
 
-    model = Model(inputs=[input_img, labels], outputs=output, name="handwriting_recognizer")
+    model = Model(inputs=(input_img, labels), outputs=output, name="handwriting_recognizer")
 
     return model