models.cnn_rnn_ocr_model: add inference option, drop model name

This commit is contained in:
Robert Sachunsky 2026-05-28 17:32:02 +02:00
parent c4a7eec5b3
commit faef1967f8

View file

@ -422,11 +422,11 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
return model return model
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None): def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False):
input_img = Input(shape=(image_height, image_width, 3), name="image") inputs = Input(shape=(image_height, image_width, 3), name="image")
labels = Input(name="label", shape=(None,)) labels = Input(name="label", shape=(None,))
x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img) x = Conv2D(64,kernel_size=(3,3),padding="same")(inputs)
x = BatchNormalization(name="bn1")(x) x = BatchNormalization(name="bn1")(x)
x = Activation("relu", name="relu1")(x) x = Activation("relu", name="relu1")(x)
x = Conv2D(64,kernel_size=(3,3),padding="same")(x) x = Conv2D(64,kernel_size=(3,3),padding="same")(x)
@ -459,7 +459,6 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x) x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d) x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d)
new_shape = (x.shape[1]*x.shape[2], x.shape[3]) new_shape = (x.shape[1]*x.shape[2], x.shape[3])
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3]) new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3]) new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
@ -475,7 +474,6 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d) xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d) xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d) xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d) xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
@ -486,16 +484,18 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition) addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn) out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
out = BatchNormalization(name="bn9")(out) out = BatchNormalization(name="bn9")(out)
out = Activation("relu", name="relu9")(out) out = Activation("relu", name="relu9")(out)
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out) #out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
out = Dense(n_classes, activation="softmax", name="dense2")(out) out = Dense(n_classes, activation="softmax", name="dense2")(out)
if inference:
return Model(inputs, out)
# Add CTC layer for calculating CTC loss at each step. # Add CTC layer for calculating CTC loss at each step.
output = CTCLayer(name="ctc_loss")(labels, out) out = CTCLayer(name="ctc_loss")(labels, out)
model = Model(inputs=(input_img, labels), outputs=output, name="handwriting_recognizer") return Model((inputs, labels), out)
return model