Mirror of https://github.com/qurator-spk/eynollah.git
Fix filename stem extraction in binarization. Restore the CNN-RNN model to its previous version, as setting channels_last alone was insufficient for running on both CPU and GPU. Prevent errors caused by zero or negative values in image shape elements.
commit 6ae244bf9b
parent 30f39e7383
3 changed files with 9 additions and 5 deletions
@@ -19,7 +19,7 @@ from eynollah.model_zoo import EynollahModelZoo
 tf_disable_interactive_logs()
 import tensorflow as tf
 from tensorflow.python.keras import backend as tensorflow_backend
-
+from pathlib import Path
 from .utils import is_image_filename
 
 def resize_image(img_in, input_height, input_width):

@@ -347,7 +347,7 @@ class SbbBinarizer:
         self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in)
         for i, image_path in enumerate(ls_imgs):
             self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_path)
-            image_stem = image_path.split('.')[0]
+            image_stem = Path(image_path).stem
             image = cv2.imread(os.path.join(dir_in,image_path) )
             img_last = 0
             model_file, model = self.models

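Why the stem change matters: image_path.split('.')[0] truncates at the first dot anywhere in the filename, while Path(image_path).stem drops only the final suffix, so names containing extra dots keep their full stem. A minimal sketch with a made-up filename:

    from pathlib import Path

    fname = "page.0001.png"        # hypothetical filename containing an extra dot
    print(fname.split('.')[0])     # 'page'      -- truncated at the first dot
    print(Path(fname).stem)        # 'page.0001' -- only the trailing .png removed
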
@@ -843,7 +843,7 @@ def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_s
 
     addition_rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25))(addition)
 
-    out = tf.keras.layers.Conv1D(max_seq, 1, data_format="channels_last")(addition_rnn)
+    out = tf.keras.layers.Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
     out = tf.keras.layers.BatchNormalization(name="bn9")(out)
     out = tf.keras.layers.Activation("relu", name="relu9")(out)
     #out = tf.keras.layers.Conv1D(n_classes, 1, activation='relu', data_format="channels_last")(out)

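For context on the reverted line: data_format tells Conv1D which axis of its rank-3 input holds the channels, so the two settings build differently shaped layers, and weights saved for one layout cannot load against the other; restoring channels_first presumably keeps the layer compatible with the previously trained model. A small shape sketch with made-up dimensions:

    import tensorflow as tf

    inp = tf.keras.Input(shape=(96, 256))  # hypothetical (timesteps, features) from the BiLSTM

    # channels_last: the last axis is channels; the convolution slides over axis 1
    print(tf.keras.layers.Conv1D(128, 1, data_format="channels_last")(inp).shape)   # (None, 96, 128)

    # channels_first: axis 1 is read as channels and the last axis as length
    print(tf.keras.layers.Conv1D(128, 1, data_format="channels_first")(inp).shape)  # (None, 128, 256)
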
@@ -1,7 +1,7 @@
 import os
 import math
 import random
-
+from pathlib import Path
 import cv2
 import numpy as np
 import seaborn as sns

@@ -32,6 +32,9 @@ def scale_padd_image_for_ocr(img, height, width):
     else:
         width_new = width
 
+    if width_new <= 0:
+        width_new = width
+
     img_res= resize_image (img, height, width_new)
     img_fin = np.ones((height, width, 3))*255
 
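The new guard handles the case where the computed width collapses to a non-positive value. The computation of width_new sits outside the hunk, but assuming the usual aspect-ratio rescale to a fixed height, integer rounding can produce 0 for extremely narrow inputs, which a resize call rejects; the guard falls back to the full target width. A hypothetical reconstruction:

    import numpy as np

    def width_for_height(img, height):
        # assumed aspect-ratio scaling; not the exact code from the hunk
        h, w = img.shape[:2]
        return int(w * height / h)

    sliver = np.zeros((500, 1, 3), dtype=np.uint8)  # a one-pixel-wide column
    print(width_for_height(sliver, 32))             # 0 -> resizing to width 0 would fail
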
@@ -1304,7 +1307,8 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
     batchcount = 0
     while True:
         for i in ls_files_images:
-            f_name = i.split('.')[0]
+            print(i, 'i')
+            f_name = Path(i).stem#.split('.')[0]
 
             txt_inp = open(os.path.join(dir_train, "labels/"+f_name+'.txt'),'r').read().split('\n')[0]
 
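This is the same stem fix as in the binarizer, applied where the generator derives the label filename from the image filename; with split('.')[0], an image name containing an extra dot would point at the wrong label file. A sketch with hypothetical names:

    import os
    from pathlib import Path

    dir_train = "train"      # hypothetical training directory
    i = "line.0042.png"      # image filename with an extra dot

    print(os.path.join(dir_train, "labels", i.split('.')[0] + ".txt"))  # train/labels/line.txt (wrong file)
    print(os.path.join(dir_train, "labels", Path(i).stem + ".txt"))     # train/labels/line.0042.txt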