From faef1967f87fc497fa988ea4c4d0a2e454d3f633 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Thu, 28 May 2026 17:32:02 +0200 Subject: [PATCH] models.cnn_rnn_ocr_model: add `inference` option, drop model name --- src/eynollah/training/models.py | 34 ++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index c5510f8..eb621c6 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -422,11 +422,11 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224 return model -def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None): - input_img = Input(shape=(image_height, image_width, 3), name="image") +def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False): + inputs = Input(shape=(image_height, image_width, 3), name="image") labels = Input(name="label", shape=(None,)) - x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img) + x = Conv2D(64,kernel_size=(3,3),padding="same")(inputs) x = BatchNormalization(name="bn1")(x) x = Activation("relu", name="relu1")(x) x = Conv2D(64,kernel_size=(3,3),padding="same")(x) @@ -458,44 +458,44 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s x = Activation("relu", name="relu8")(x) x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x) x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d) - new_shape = (x.shape[1]*x.shape[2], x.shape[3]) new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3]) new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3]) - + x = Reshape(new_shape, name="reshape")(x) x2d = Reshape(new_shape2, name="reshape2")(x2d) x4d = Reshape(new_shape4, name="reshape4")(x4d) - + xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x) xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d) xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d) - + xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d) xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d) - xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d) xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d) - + xrnn2dup = Reshape((xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup) xrnn4dup = Reshape((xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup) addition = Add()([xrnnorg, xrnn2dup, xrnn4dup]) - + addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition) - - out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn) + + out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn) out = BatchNormalization(name="bn9")(out) out = Activation("relu", name="relu9")(out) #out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out) out = Dense(n_classes, activation="softmax", name="dense2")(out) - # Add CTC layer for calculating CTC loss at each step. - output = CTCLayer(name="ctc_loss")(labels, out) - - model = Model(inputs=(input_img, labels), outputs=output, name="handwriting_recognizer") + if inference: + return Model(inputs, out) + + # Add CTC layer for calculating CTC loss at each step. + out = CTCLayer(name="ctc_loss")(labels, out) + + return Model((inputs, labels), out) - return model