mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-31 01:59:27 +02:00
models.cnn_rnn_ocr_model: add inference option, drop model name
This commit is contained in:
parent
c4a7eec5b3
commit
faef1967f8
1 changed file with 17 additions and 17 deletions
|
|
@@ -422,11 +422,11 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None):
|
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False):
|
||||||
input_img = Input(shape=(image_height, image_width, 3), name="image")
|
inputs = Input(shape=(image_height, image_width, 3), name="image")
|
||||||
labels = Input(name="label", shape=(None,))
|
labels = Input(name="label", shape=(None,))
|
||||||
|
|
||||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img)
|
x = Conv2D(64,kernel_size=(3,3),padding="same")(inputs)
|
||||||
x = BatchNormalization(name="bn1")(x)
|
x = BatchNormalization(name="bn1")(x)
|
||||||
x = Activation("relu", name="relu1")(x)
|
x = Activation("relu", name="relu1")(x)
|
||||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(x)
|
x = Conv2D(64,kernel_size=(3,3),padding="same")(x)
|
||||||
|
|
@@ -459,7 +459,6 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
|
||||||
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
||||||
x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d)
|
x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d)
|
||||||
|
|
||||||
|
|
||||||
new_shape = (x.shape[1]*x.shape[2], x.shape[3])
|
new_shape = (x.shape[1]*x.shape[2], x.shape[3])
|
||||||
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
|
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
|
||||||
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
|
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
|
||||||
|
|
@@ -475,7 +474,6 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
|
||||||
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
||||||
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
||||||
|
|
||||||
|
|
||||||
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
|
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
|
||||||
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
|
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
|
||||||
|
|
||||||
|
|
@@ -486,16 +484,18 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
|
||||||
|
|
||||||
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
|
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
|
||||||
|
|
||||||
out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
|
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
||||||
out = BatchNormalization(name="bn9")(out)
|
out = BatchNormalization(name="bn9")(out)
|
||||||
out = Activation("relu", name="relu9")(out)
|
out = Activation("relu", name="relu9")(out)
|
||||||
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
|
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
|
||||||
|
|
||||||
out = Dense(n_classes, activation="softmax", name="dense2")(out)
|
out = Dense(n_classes, activation="softmax", name="dense2")(out)
|
||||||
|
|
||||||
|
if inference:
|
||||||
|
return Model(inputs, out)
|
||||||
|
|
||||||
# Add CTC layer for calculating CTC loss at each step.
|
# Add CTC layer for calculating CTC loss at each step.
|
||||||
output = CTCLayer(name="ctc_loss")(labels, out)
|
out = CTCLayer(name="ctc_loss")(labels, out)
|
||||||
|
|
||||||
model = Model(inputs=(input_img, labels), outputs=output, name="handwriting_recognizer")
|
return Model((inputs, labels), out)
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue