mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
training.models for cnn-rnn-ocr: fix config names for height/width…
- rename `image_height` → `input_height` - rename `image_width` → `input_width`
This commit is contained in:
parent
4181e03bc9
commit
9d2412080f
1 changed files with 11 additions and 11 deletions
|
|
@ -92,16 +92,16 @@ class CTCDecoder(Layer):
|
||||||
# get top path for all sequences in batch
|
# get top path for all sequences in batch
|
||||||
decoded = decoded[0]
|
decoded = decoded[0]
|
||||||
logits = logits[:, 0] - logits[:, 1]
|
logits = logits[:, 0] - logits[:, 1]
|
||||||
|
probs = tf.exp(-logits)
|
||||||
# convert to dense
|
# convert to dense
|
||||||
outputs = tf.SparseTensor(decoded.indices, decoded.values,
|
outputs = tf.SparseTensor(decoded.indices, decoded.values,
|
||||||
(n_samples, n_steps))
|
(n_samples, n_steps))
|
||||||
outputs = tf.sparse.to_dense(sp_input=outputs, default_value=-1)
|
outputs = tf.sparse.to_dense(sp_input=outputs, default_value=n_classes-1)
|
||||||
# # drop non-tokens (-1) and OOV (0)
|
# # drop non-tokens (-1) and OOV (0)
|
||||||
# result = []
|
# result = []
|
||||||
# for output in outputs:
|
# for output in outputs:
|
||||||
# result.append(tf.gather(output, tf.where(output > 0)))
|
# result.append(tf.gather(output, tf.where(output > 0)))
|
||||||
# outputs = tf.stack(result)
|
# outputs = tf.stack(result)
|
||||||
probs = tf.exp(-logits)
|
|
||||||
return outputs, probs
|
return outputs, probs
|
||||||
|
|
||||||
def mlp(x, hidden_units, dropout_rate):
|
def mlp(x, hidden_units, dropout_rate):
|
||||||
|
|
@ -468,8 +468,8 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False, characters_txt_file=None):
|
def cnn_rnn_ocr_model(input_height=None, input_width=None, n_classes=None, max_len=None, inference=False, characters_txt_file=None):
|
||||||
inputs = Input(shape=(image_height, image_width, 3), name="image")
|
inputs = Input(shape=(input_height, input_width, 3), name="image")
|
||||||
labels = Input(name="label", shape=(None,))
|
labels = Input(name="label", shape=(None,))
|
||||||
|
|
||||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(inputs)
|
x = Conv2D(64,kernel_size=(3,3),padding="same")(inputs)
|
||||||
|
|
@ -496,10 +496,10 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
||||||
x = Activation("relu", name="relu6")(x)
|
x = Activation("relu", name="relu6")(x)
|
||||||
x = MaxPooling2D(pool_size=(2,2),strides=(2,2))(x)
|
x = MaxPooling2D(pool_size=(2,2),strides=(2,2))(x)
|
||||||
|
|
||||||
x = Conv2D(image_width,kernel_size=(3,3),padding="same")(x)
|
x = Conv2D(input_width,kernel_size=(3,3),padding="same")(x)
|
||||||
x = BatchNormalization(name="bn7")(x)
|
x = BatchNormalization(name="bn7")(x)
|
||||||
x = Activation("relu", name="relu7")(x)
|
x = Activation("relu", name="relu7")(x)
|
||||||
x = Conv2D(image_width,kernel_size=(16,1))(x)
|
x = Conv2D(input_width,kernel_size=(16,1))(x)
|
||||||
x = BatchNormalization(name="bn8")(x)
|
x = BatchNormalization(name="bn8")(x)
|
||||||
x = Activation("relu", name="relu8")(x)
|
x = Activation("relu", name="relu8")(x)
|
||||||
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
||||||
|
|
@ -513,9 +513,9 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
||||||
x2d = Reshape(new_shape2, name="reshape2")(x2d)
|
x2d = Reshape(new_shape2, name="reshape2")(x2d)
|
||||||
x4d = Reshape(new_shape4, name="reshape4")(x4d)
|
x4d = Reshape(new_shape4, name="reshape4")(x4d)
|
||||||
|
|
||||||
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
|
xrnnorg = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(x)
|
||||||
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
|
xrnn2d = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(x2d)
|
||||||
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
|
xrnn4d = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(x4d)
|
||||||
|
|
||||||
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
||||||
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
||||||
|
|
@ -528,7 +528,7 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
||||||
|
|
||||||
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
|
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
|
||||||
|
|
||||||
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
|
addition_rnn = Bidirectional(LSTM(input_width, return_sequences=True, dropout=0.25))(addition)
|
||||||
|
|
||||||
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
||||||
out = BatchNormalization(name="bn9")(out)
|
out = BatchNormalization(name="bn9")(out)
|
||||||
|
|
@ -539,7 +539,7 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_l
|
||||||
|
|
||||||
if inference:
|
if inference:
|
||||||
# add second path for binarization
|
# add second path for binarization
|
||||||
inputs_bin = Input(shape=(image_height, image_width, 3), name="image_bin")
|
inputs_bin = Input(shape=(input_height, input_width, 3), name="image_bin")
|
||||||
out_bin = Model(inputs, out)(inputs_bin)
|
out_bin = Model(inputs, out)(inputs_bin)
|
||||||
# ensemble raw results
|
# ensemble raw results
|
||||||
out = 0.5 * (out + out_bin)
|
out = 0.5 * (out + out_bin)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue