mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-31 01:59:27 +02:00
models.cnn_rnn_ocr_model: add inference option, drop model name
This commit is contained in:
parent
c4a7eec5b3
commit
faef1967f8
1 changed files with 17 additions and 17 deletions
|
|
@ -422,11 +422,11 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
|||
|
||||
return model
|
||||
|
||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None):
|
||||
input_img = Input(shape=(image_height, image_width, 3), name="image")
|
||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_len=None, inference=False):
|
||||
inputs = Input(shape=(image_height, image_width, 3), name="image")
|
||||
labels = Input(name="label", shape=(None,))
|
||||
|
||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(input_img)
|
||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(inputs)
|
||||
x = BatchNormalization(name="bn1")(x)
|
||||
x = Activation("relu", name="relu1")(x)
|
||||
x = Conv2D(64,kernel_size=(3,3),padding="same")(x)
|
||||
|
|
@ -458,44 +458,44 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
|
|||
x = Activation("relu", name="relu8")(x)
|
||||
x2d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x)
|
||||
x4d = MaxPooling2D(pool_size=(1,2),strides=(1,2))(x2d)
|
||||
|
||||
|
||||
new_shape = (x.shape[1]*x.shape[2], x.shape[3])
|
||||
new_shape2 = (x2d.shape[1]*x2d.shape[2], x2d.shape[3])
|
||||
new_shape4 = (x4d.shape[1]*x4d.shape[2], x4d.shape[3])
|
||||
|
||||
|
||||
x = Reshape(new_shape, name="reshape")(x)
|
||||
x2d = Reshape(new_shape2, name="reshape2")(x2d)
|
||||
x4d = Reshape(new_shape4, name="reshape4")(x4d)
|
||||
|
||||
|
||||
xrnnorg = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x)
|
||||
xrnn2d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x2d)
|
||||
xrnn4d = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(x4d)
|
||||
|
||||
|
||||
xrnn2d = Reshape((1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6")(xrnn2d)
|
||||
xrnn4d = Reshape((1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8")(xrnn4d)
|
||||
|
||||
|
||||
xrnn2dup = UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
|
||||
xrnn4dup = UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)
|
||||
|
||||
|
||||
xrnn2dup = Reshape((xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10")(xrnn2dup)
|
||||
xrnn4dup = Reshape((xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12")(xrnn4dup)
|
||||
|
||||
addition = Add()([xrnnorg, xrnn2dup, xrnn4dup])
|
||||
|
||||
|
||||
addition_rnn = Bidirectional(LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
|
||||
|
||||
out = Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
|
||||
|
||||
out = Conv1D(max_len, 1, data_format="channels_first")(addition_rnn)
|
||||
out = BatchNormalization(name="bn9")(out)
|
||||
out = Activation("relu", name="relu9")(out)
|
||||
#out = Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
|
||||
|
||||
out = Dense(n_classes, activation="softmax", name="dense2")(out)
|
||||
|
||||
# Add CTC layer for calculating CTC loss at each step.
|
||||
output = CTCLayer(name="ctc_loss")(labels, out)
|
||||
|
||||
model = Model(inputs=(input_img, labels), outputs=output, name="handwriting_recognizer")
|
||||
if inference:
|
||||
return Model(inputs, out)
|
||||
|
||||
# Add CTC layer for calculating CTC loss at each step.
|
||||
out = CTCLayer(name="ctc_loss")(labels, out)
|
||||
|
||||
return Model((inputs, labels), out)
|
||||
|
||||
return model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue