CNN-RNN-OCR inference and adaptation of the CNN-RNN-OCR model to support inference on both CPU and GPU

This commit is contained in:
  parent 6ee79c7320
  commit 49261fa99b

2 changed files with 61 additions and 28 deletions
@@ -25,6 +25,9 @@ from .models import (
     Patches
 )
 
+from .utils import (scale_padd_image_for_ocr)
+from eynollah.utils.utils_ocr import (decode_batch_predictions)
+
 with warnings.catch_warnings():
     warnings.simplefilter("ignore")
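Note on the first new import: scale_padd_image_for_ocr is the OCR-specific preprocessing helper from .utils, and its body is not part of this diff. The usual pattern for this step, resizing a text-line crop to the model's fixed input height while preserving aspect ratio and padding the free width, is sketched below; this is an illustrative stand-in, not the repository's implementation:

    import cv2
    import numpy as np

    def scale_padd_image_for_ocr_sketch(img, input_height, input_width):
        # Scale the crop so the text height matches the model input height.
        scale = input_height / img.shape[0]
        new_w = min(input_width, max(1, int(img.shape[1] * scale)))
        resized = cv2.resize(img, (new_w, input_height))
        # Pad the remaining width with white so every line has a fixed shape.
        canvas = np.full((input_height, input_width, 3), 255, dtype=np.uint8)
        canvas[:, :new_w] = resized[:, :, :3]
        return canvas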
@@ -34,7 +37,7 @@ Tool to load model and predict for given image.
 """
 
 class sbb_predict:
-    def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
+    def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
         self.image=image
         self.dir_in=dir_in
         self.patches=patches
@@ -46,6 +49,7 @@ class sbb_predict:
         self.config_params_model=config_params_model
         self.xml_file = xml_file
         self.out = out
+        self.cpu = cpu
         if min_area:
             self.min_area = float(min_area)
         else:
@@ -157,25 +161,21 @@ class sbb_predict:
         return mIoU
 
     def start_new_session_and_model(self):
-        config = tf.compat.v1.ConfigProto()
-        config.gpu_options.allow_growth = True
-        session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-        tensorflow_backend.set_session(session)
-        #tensorflow.keras.layers.custom_layer = PatchEncoder
-        #tensorflow.keras.layers.custom_layer = Patches
-        self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
-        #config = tf.ConfigProto()
-        #config.gpu_options.allow_growth=True
-        #self.session = tf.InteractiveSession()
-        #keras.losses.custom_loss = self.weighted_categorical_crossentropy
-        #self.model = load_model(self.model_dir , compile=False)
-        ##if self.weights_dir!=None:
-            ##self.model.load_weights(self.weights_dir)
+        if self.task == "cnn-rnn-ocr":
+            if self.cpu:
+                os.environ['CUDA_VISIBLE_DEVICES']='-1'
+            self.model = load_model(self.model_dir)
+            self.model = tf.keras.models.Model(
+                self.model.get_layer(name = "image").input,
+                self.model.get_layer(name = "dense2").output)
+        else:
+            config = tf.compat.v1.ConfigProto()
+            config.gpu_options.allow_growth = True
+            session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
+            tensorflow_backend.set_session(session)
+            self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
 
         if self.task != 'classification' and self.task != 'reading_order':
             self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
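Note on the new CPU branch above: it combines two standard TensorFlow idioms. Setting CUDA_VISIBLE_DEVICES to -1 hides every GPU, but only takes effect if it happens before TensorFlow initializes its device context; and re-wrapping the loaded network between its named input ("image") and its logits layer ("dense2") drops any training-only head such as a CTC loss layer. A minimal self-contained sketch of the same pattern; only the layer names follow the commit, the toy architecture is illustrative:

    import os

    # Must be set before TensorFlow claims the GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    import tensorflow as tf

    # Toy stand-in for the saved CNN-RNN-OCR model: a named input
    # and a named projection layer, as the commit assumes.
    inp = tf.keras.Input(shape=(32, 128, 3), name="image")
    x = tf.keras.layers.Conv2D(8, 3, padding="same")(inp)
    x = tf.keras.layers.Reshape((128, -1))(x)
    out = tf.keras.layers.Dense(80, name="dense2")(x)
    full = tf.keras.Model(inp, out)

    # Truncate to the prediction head: everything after "dense2"
    # (e.g. a training-only CTC layer) is dropped for inference.
    infer_model = tf.keras.Model(
        full.get_layer(name="image").input,
        full.get_layer(name="dense2").output)
    print(infer_model.output_shape)  # (None, 128, 80)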
@@ -244,6 +244,30 @@ class sbb_predict:
             index_class = np.argmax(label_p_pred[0])
 
             print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
+        elif self.task == "cnn-rnn-ocr":
+            img=cv2.imread(image_dir)
+            img = scale_padd_image_for_ocr(img, self.config_params_model['input_height'], self.config_params_model['input_width'])
+
+            img = img / 255.
+
+            with open(os.path.join(self.model_dir, "characters_org.txt"), 'r') as char_txt_f:
+                characters = json.load(char_txt_f)
+
+            AUTOTUNE = tf.data.AUTOTUNE
+
+            # Mapping characters to integers.
+            char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
+
+            # Mapping integers back to original characters.
+            num_to_char = StringLookup(
+                vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
+            )
+            preds = self.model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]), verbose=0)
+            pred_texts = decode_batch_predictions(preds, num_to_char)
+            pred_texts = pred_texts[0].replace("[UNK]", "")
+            return pred_texts
+
+
         elif self.task == 'reading_order':
             img_height = self.config_params_model['input_height']
             img_width = self.config_params_model['input_width']
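Note on the decoding step above: decode_batch_predictions comes from eynollah.utils.utils_ocr and is not shown in this diff. The surrounding code (the StringLookup vocabularies and the [UNK] cleanup) matches the standard Keras CTC-OCR recipe, under which the decoder greedily collapses the per-timestep softmax with ctc_decode and joins the surviving indices back into a string. A sketch under that assumption:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import StringLookup

    # Vocabulary as it would be loaded from characters_org.txt.
    characters = list("abcdefghijklmnopqrstuvwxyz ")
    char_to_num = StringLookup(vocabulary=characters, mask_token=None)
    num_to_char = StringLookup(
        vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)

    def decode_batch_predictions_sketch(pred, num_to_char):
        # pred: (batch, timesteps, n_classes) softmax scores.
        input_len = np.ones(pred.shape[0]) * pred.shape[1]
        # Greedy CTC decoding: collapse repeats, drop the blank label.
        decoded = tf.keras.backend.ctc_decode(
            pred, input_length=input_len, greedy=True)[0][0]
        texts = []
        for seq in decoded:
            seq = tf.boolean_mask(seq, seq != -1)  # strip CTC padding
            texts.append(
                tf.strings.reduce_join(num_to_char(seq)).numpy().decode("utf-8"))
        return texts

The trailing .replace("[UNK]", "") in the commit is needed because the inverse StringLookup renders out-of-vocabulary indices as the literal token [UNK].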
@@ -569,6 +593,8 @@ class sbb_predict:
         elif self.task == 'enhancement':
             if self.save:
                 cv2.imwrite(self.save,res)
+        elif self.task == "cnn-rnn-ocr":
+            print(f"Detected text: {res}")
         else:
             img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
             if self.save:
@@ -592,6 +618,8 @@ class sbb_predict:
             elif self.task == 'enhancement':
                 self.save = os.path.join(self.out, f_name+'.png')
                 cv2.imwrite(self.save,res)
+            elif self.task == "cnn-rnn-ocr":
+                print(f"Detected text for file name {f_name} is: {res}")
             else:
                 img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
                 self.save = os.path.join(self.out, f_name+'_overlayed.png')
@@ -657,24 +685,29 @@ class sbb_predict:
     "-xml",
     help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
 )
+@click.option(
+    "--cpu",
+    "-cpu",
+    help="For OCR, the default device is the GPU. If this parameter is set to true, inference will be performed on the CPU",
+    is_flag=True,
+)
 @click.option(
     "--min_area",
     "-min",
     help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
 )
-def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
+def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
     assert image or dir_in, "Either a single image -i or a dir_in -di is required"
     with open(os.path.join(model,'config.json')) as f:
         config_params_model = json.load(f)
     task = config_params_model['task']
-    if task != 'classification' and task != 'reading_order':
+    if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
         if image and not save:
             print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
             sys.exit(1)
         if dir_in and not out:
             print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
             sys.exit(1)
-    x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
+    x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area)
     x.run()
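Note: with the wiring above, the programmatic equivalent of main with the new flag looks as follows; the paths are placeholders, and config_params_model is the dict read from the model's config.json exactly as in the hunk:

    import json, os

    model_dir = "/path/to/cnn_rnn_ocr_model"  # placeholder
    with open(os.path.join(model_dir, 'config.json')) as f:
        config_params_model = json.load(f)

    x = sbb_predict("line.png", None, model_dir, config_params_model['task'],
                    config_params_model, False, None, None, None, None,
                    cpu=True,  # the new flag: force OCR inference onto the CPU
                    out=None, min_area=None)
    x.run()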
@@ -843,7 +843,7 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq
 
     addition_rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
 
-    out = tf.keras.layers.Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
+    out = tf.keras.layers.Conv1D(max_seq, 1, data_format="channels_last")(addition_rnn)
     out = tf.keras.layers.BatchNormalization(name="bn9")(out)
     out = tf.keras.layers.Activation("relu", name="relu9")(out)
     #out = tf.keras.layers.Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)
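Note on the one-line model change above: it is the "CPU" half of the commit title applied to the architecture. On many TensorFlow builds the CPU convolution kernels only implement channels_last, so a Conv1D constructed with data_format="channels_first" fails with an UnimplementedError as soon as inference runs off-GPU. The swap also changes which axis the 1x1 convolution projects, as the shapes in this sketch (with 64 standing in for max_seq) show:

    import tensorflow as tf

    x = tf.random.normal((1, 100, 256))  # (batch, timesteps, rnn features)

    # channels_last: features are the channels; the 1x1 conv projects
    # them to max_seq per timestep -> (1, 100, 64).
    y = tf.keras.layers.Conv1D(64, 1, data_format="channels_last")(x)
    print(y.shape)

    # channels_first reads axis 1 (the timesteps) as channels instead,
    # giving (1, 64, 256) where supported; on CPU-only builds the kernel
    # is typically missing and the call raises UnimplementedError.
    try:
        y_cf = tf.keras.layers.Conv1D(64, 1, data_format="channels_first")(x)
        print(y_cf.shape)
    except tf.errors.UnimplementedError:
        print("no CPU kernel for channels_first Conv1D")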