training.models for cnn-rnn-ocr: avoid Conv1D(..channels_first..)

This commit is contained in:
Robert Sachunsky 2026-06-16 17:28:10 +02:00
parent dfa651ef8a
commit eb4cae9dee

View file

@@ -21,6 +21,7 @@ from tensorflow.keras.layers import (
LSTM, LSTM,
MaxPooling2D, MaxPooling2D,
MultiHeadAttention, MultiHeadAttention,
Permute,
Reshape, Reshape,
UpSampling2D, UpSampling2D,
ZeroPadding2D, ZeroPadding2D,
@@ -70,7 +71,7 @@ class CTCDecoder(Layer):
## but Keras greedy sometimes removes arbitrary letters ## but Keras greedy sometimes removes arbitrary letters
# outputs, logits = tf.keras.backend.ctc_decode(inputs, # outputs, logits = tf.keras.backend.ctc_decode(inputs,
# lengths, # lengths,
# beam_width=20 # beam_width=20,
# greedy=False, # True, # greedy=False, # True,
# # backend does not allow these kwargs # # backend does not allow these kwargs
# #merge_repeated=False, # #merge_repeated=False,
@@ -530,10 +531,12 @@ def cnn_rnn_ocr_model(input_height=None, input_width=None, n_classes=None, max_l
addition_rnn = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(addition) addition_rnn = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(addition)
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn) #out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
out = Permute((2, 1))(addition_rnn)
out = Conv1D(max_len, 1, data_format="channels_last")(out)
out = Permute((2, 1))(out)
out = BatchNormalization(name="bn9")(out) out = BatchNormalization(name="bn9")(out)
out = Activation("relu", name="relu9")(out) out = Activation("relu", name="relu9")(out)
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
out = Dense(n_classes, activation="softmax", name="dense2")(out) out = Dense(n_classes, activation="softmax", name="dense2")(out)