mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
training.models for cnn-rnn-ocr: fix config names for height/width…
- rename `image_height` → `input_height` - rename `image_width` → `input_width`
This commit is contained in:
parent
4181e03bc9
commit
9d2412080f
1 changed files with 11 additions and 11 deletions
|
|
@ -92,16 +92,16 @@ class CTCDecoder(Layer):
|
|||
# get top path for all sequences in batch
|
||||
decoded = decoded[0]
|
||||
logits = logits[:, 0] - logits[:, 1]
|
||||
probs = tf.exp(-logits)
|
||||
# convert to dense
|
||||
outputs = tf.SparseTensor(decoded.indices, decoded.values,
|
||||
(n_samples, n_steps))
|
||||
outputs = tf.sparse.to_dense(sp_input=outputs, default_value=-1)
|
||||
outputs = tf.sparse.to_dense(sp_input=outputs, default_value=n_classes-1)
|
||||
# # drop non-tokens (-1) and OOV (0)
|
||||
# result = []
|
||||
# for output in outputs:
|
||||
# result.append(tf.gather(output, tf.where(output > 0)))
|
||||
# outputs = tf.stack(result)
|
||||
probs = tf.exp(-logits)
|
||||
return outputs, probs
|
||||
|
||||
def mlp(x, hidden_units, dropout_rate):
|
||||
|
|
@ -468,8 +468,8 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
|||
|
||||
return model
|
||||
|
||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False, characters_txt_file=None):
|
||||
inputs = Input(shape=(image_height, image_width, 3), name="image")
|
||||
def cnn_rnn_ocr_model(input_height=None, input_width=None, n_classes=None, max_len=None, inference=False, characters_txt_file=None):
|
||||
inputs = Input(shape=(input_height, input_width, 3), name="image")
|
||||
labels = Input(name="label", shape=(None,))
|
||||
|
||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(inputs)
|
||||
|
|
@ -496,10 +496,10 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
|||
x = Activation("relu", name="relu6")(x)
|
||||
x = MaxPooling2D(pool_size=(2,2),strides=(2,2))(x)
|
||||
|
||||
x = Conv2D(image_width,kernel_size=(3,3),padding="same")(x)
|
||||
x = Conv2D(input_width,kernel_size=(3,3),padding="same")(x)
|
||||
x = BatchNormalization(name="bn7")(x)
|
||||
x = Activation("relu", name="relu7")(x)
|
||||
x = Conv2D(image_width,kernel_size=(16,1))(x)
|
||||
x = Conv2D(input_width,kernel_size=(16,1))(x)
|
||||
x = BatchNormalization(name="bn8")(x)
|
||||
x = Activation("relu", name="relu8")(x)
|
||||
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
||||
|
|
@ -513,9 +513,9 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
|||
x2d = Reshape(new_shape2, name="reshape2")(x2d)
|
||||
x4d = Reshape(new_shape4, name="reshape4")(x4d)
|
||||
|
||||
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
|
||||
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
|
||||
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
|
||||
xrnnorg = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(x)
|
||||
xrnn2d = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(x2d)
|
||||
xrnn4d = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(x4d)
|
||||
|
||||
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
||||
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
||||
|
|
@ -528,7 +528,7 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
|||
|
||||
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
|
||||
|
||||
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
|
||||
addition_rnn = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(addition)
|
||||
|
||||
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
||||
out = BatchNormalization(name="bn9")(out)
|
||||
|
|
@ -539,7 +539,7 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
|||
|
||||
if inference:
|
||||
# add second path for binarization
|
||||
inputs_bin = Input(shape=(image_height, image_width, 3), name="image_bin")
|
||||
inputs_bin = Input(shape=(input_height, input_width, 3), name="image_bin")
|
||||
out_bin = Model(inputs, out)(inputs_bin)
|
||||
# ensemble raw results
|
||||
out = 0.5 * (out + out_bin)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue