mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-27 23:39:15 +02:00
training.models for cnn-rnn-ocr: avoid Conv1D(..channels_first..)
This commit is contained in:
parent
dfa651ef8a
commit
eb4cae9dee
1 changed file with 6 additions and 3 deletions
|
|
@ -21,6 +21,7 @@ from tensorflow.keras.layers import (
|
||||||
LSTM,
|
LSTM,
|
||||||
MaxPooling2D,
|
MaxPooling2D,
|
||||||
MultiHeadAttention,
|
MultiHeadAttention,
|
||||||
|
Permute,
|
||||||
Reshape,
|
Reshape,
|
||||||
UpSampling2D,
|
UpSampling2D,
|
||||||
ZeroPadding2D,
|
ZeroPadding2D,
|
||||||
|
|
@ -70,7 +71,7 @@ class CTCDecoder(Layer):
|
||||||
## but Keras greedy sometimes removes arbitrary letters
|
## but Keras greedy sometimes removes arbitrary letters
|
||||||
# outputs, logits = tf.keras.backend.ctc_decode(inputs,
|
# outputs, logits = tf.keras.backend.ctc_decode(inputs,
|
||||||
# lengths,
|
# lengths,
|
||||||
# beam_width=20
|
# beam_width=20,
|
||||||
# greedy=False, # True,
|
# greedy=False, # True,
|
||||||
# # backend does not allow these kwargs
|
# # backend does not allow these kwargs
|
||||||
# #merge_repeated=False,
|
# #merge_repeated=False,
|
||||||
|
|
@ -530,10 +531,12 @@ def cnn_rnn_ocr_model(input_height=None, input_width=None, n_classes=None, max_l
|
||||||
|
|
||||||
addition_rnn = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(addition)
|
addition_rnn = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(addition)
|
||||||
|
|
||||||
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
#out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
||||||
|
out = Permute((2, 1))(addition_rnn)
|
||||||
|
out = Conv1D(max_len, 1, data_format="channels_last")(out)
|
||||||
|
out = Permute((2, 1))(out)
|
||||||
out = BatchNormalization(name="bn9")(out)
|
out = BatchNormalization(name="bn9")(out)
|
||||||
out = Activation("relu", name="relu9")(out)
|
out = Activation("relu", name="relu9")(out)
|
||||||
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
|
|
||||||
|
|
||||||
out = Dense(n_classes, activation="softmax", name="dense2")(out)
|
out = Dense(n_classes, activation="softmax", name="dense2")(out)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue