mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-27 23:39:15 +02:00
training.models for cnn-rnn-ocr: avoid Conv1D(..channels_first..)
This commit is contained in:
parent
dfa651ef8a
commit
eb4cae9dee
1 changed files with 6 additions and 3 deletions
|
|
@ -21,6 +21,7 @@ from tensorflow.keras.layers import (
|
|||
LSTM,
|
||||
MaxPooling2D,
|
||||
MultiHeadAttention,
|
||||
Permute,
|
||||
Reshape,
|
||||
UpSampling2D,
|
||||
ZeroPadding2D,
|
||||
|
|
@ -70,7 +71,7 @@ class CTCDecoder(Layer):
|
|||
## but Keras greedy sometimes removes arbitrary letters
|
||||
# outputs, logits = tf.keras.backend.ctc_decode(inputs,
|
||||
# lengths,
|
||||
# beam_width=20
|
||||
# beam_width=20,
|
||||
# greedy=False, # True,
|
||||
# # backend does not allow these kwargs
|
||||
# #merge_repeated=False,
|
||||
|
|
@ -530,10 +531,12 @@ def cnn_rnn_ocr_model(input_height=None, input_width=None, n_classes=None, max_l
|
|||
|
||||
addition_rnn = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(addition)
|
||||
|
||||
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
||||
#out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
||||
out = Permute((2, 1))(addition_rnn)
|
||||
out = Conv1D(max_len, 1, data_format="channels_last")(out)
|
||||
out = Permute((2, 1))(out)
|
||||
out = BatchNormalization(name="bn9")(out)
|
||||
out = Activation("relu", name="relu9")(out)
|
||||
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
|
||||
|
||||
out = Dense(n_classes, activation="softmax", name="dense2")(out)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue