The cnn-rnn ocr model can be trained now

This commit is contained in:
vahidrezanezhad 2025-12-09 17:22:12 +01:00
parent 84a72a128b
commit 4fc3ff33cb
3 changed files with 138 additions and 57 deletions

View file

@ -138,8 +138,10 @@ def config_params():
scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
thetha = None # Rotate image by these angles for augmentation.
thetha_padd = None # List of angles used for rotation alongside padding
shuffle_indexes = None # List of shuffling indexes like [[0,2,1], [1,2,0], [1,0,2]]
pepper_indexes = None # List of pepper noise indexes like [0.01, 0.005]
white_padds = None # List of padding size in the case of white padding
skewing_amplitudes = None # List of skewing augmentation amplitudes like [5, 8]
blur_k = None # Blur image for augmentation.
scales = None # Scale patches for augmentation.
@ -181,14 +183,14 @@ def run(_config, n_classes, n_epochs, input_height,
brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes,
brightness, dir_train, data_is_provided, scaling_bluring,
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
thetha, scaling_flip, continue_training, transformer_projection_dim,
thetha, thetha_padd, scaling_flip, continue_training, transformer_projection_dim,
transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
transformer_patchsize_x, transformer_patchsize_y,
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds,
dir_rgb_foregrounds, characters_txt_file, color_padding_rotation, bin_deg, image_inversion, white_noise_strap, textline_skewing, textline_skewing_bin,
textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin,
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors, pepper_indexes, skewing_amplitudes, max_len):
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors, pepper_indexes, white_padds, skewing_amplitudes, max_len):
if dir_rgb_backgrounds:
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@ -433,7 +435,7 @@ def run(_config, n_classes, n_epochs, input_height,
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes)
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds)
len_dataset = aug_multip*len(ls_files_images)
@ -442,10 +444,41 @@ def run(_config, n_classes, n_epochs, input_height,
adding_rgb_foreground, adding_rgb_background, binarization, image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth,
textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin,
pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors,
shuffle_indexes, pepper_indexes, skewing_amplitudes, dir_img_bin)
pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors,
shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, char_to_num, list_all_possible_background_images, list_all_possible_foreground_rgbs,
dir_rgb_backgrounds, dir_rgb_foregrounds, white_padds, dir_img_bin)
print(len_dataset, 'len_dataset')
initial_learning_rate = 1e-4
decay_steps = int (n_epochs * ( len_dataset / n_batch ))
alpha = 0.01
lr_schedule = 1e-4#tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha)
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
model.compile(optimizer=opt)
if save_interval:
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
indexer_start = 0
for i in range(n_epochs):
if save_interval:
model.fit(
train_ds,
steps_per_epoch=len_dataset / n_batch,
epochs=1,
callbacks=[save_weights_callback]
)
else:
model.fit(
train_ds,
steps_per_epoch=len_dataset / n_batch,
epochs=1
)
if i >=0:
model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
elif task=='classification':
configuration()

View file

@ -9,11 +9,35 @@ from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
from tqdm import tqdm
import imutils
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from PIL import Image, ImageFile, ImageEnhance
ImageFile.LOAD_TRUNCATED_IMAGES = True
def vectorize_label(label, char_to_num, padding_token, max_len):
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
length = tf.shape(label)[0]
pad_amount = max_len - length
label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)
return label
def scale_padd_image_for_ocr(img, height, width):
ratio = height /float(img.shape[0])
w_ratio = int(ratio * img.shape[1])
if w_ratio<=width:
width_new = w_ratio
else:
width_new = width
img_res= resize_image (img, height, width_new)
img_fin = np.ones((height, width, 3))*255
img_fin[:,:width_new,:] = img_res[:,:,:]
return img_fin
def add_salt_and_pepper_noise(img, salt_prob, pepper_prob):
"""
Add salt-and-pepper noise to an image.
@ -1269,8 +1293,9 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
adding_rgb_foreground, adding_rgb_background, binarization, image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth,
textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin,
pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors,
shuffle_indexes, pepper_indexes, skewing_amplitudes, dir_img_bin=None):
pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors,
shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, char_to_num, list_all_possible_background_images,
list_all_possible_foreground_rgbs, dir_rgb_backgrounds, dir_rgb_foregrounds, white_padds, dir_img_bin=None):
random.shuffle(ls_files_images)
@ -1294,7 +1319,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1306,14 +1331,14 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
batchcount = 0
if color_padding_rotation:
for index, thetha in enumerate(thetha_padd):
for index, thetha_ind in enumerate(thetha_padd):
for padd_col in padd_colors:
img_out = rotation_not_90_func(do_padding(img, 1.2, padd_col), thetha)
img_out = rotation_not_90_func_single_image(do_padding_for_ocr(img, 1.2, padd_col), thetha_ind)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1325,12 +1350,12 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
batchcount = 0
if rotation_not_90:
for index, thetha in enumerate(thetha):
img_out = rotation_not_90_func(img, thetha)
for index, thetha_ind in enumerate(thetha):
img_out = rotation_not_90_func_single_image(img, thetha_ind)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1342,12 +1367,12 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
batchcount = 0
if blur_aug:
for index, blur_type in enumerate(blurs):
for index, blur_type in enumerate(blur_k):
img_out = bluring(img, blur_type)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1367,7 +1392,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1388,7 +1413,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1409,7 +1434,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1423,11 +1448,11 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
if padding_white:
for index, padding_size in enumerate(white_padds):
for padd_col in padd_colors:
img_out = do_padding(img, padding_size, padd_col)
img_out = do_padding_for_ocr(img, padding_size, padd_col)
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1451,7 +1476,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1472,7 +1497,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1487,7 +1512,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_bin_corr, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1503,7 +1528,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :, :, :] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1511,7 +1536,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
ret_x = ret_x/255.
yield {"image": ret_x, "label": ret_y}
ret_x = np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
ret_y = np.zeros((batch_size, max_len)).astype(np.int16)+padding_token
ret_y = np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
batchcount = 0
@ -1521,7 +1546,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1539,7 +1564,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_red_context, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1556,7 +1581,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1576,7 +1601,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1596,7 +1621,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img_bin_corr)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1616,7 +1641,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1636,7 +1661,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img_bin_corr)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1656,7 +1681,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1676,7 +1701,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img_bin_corr)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1696,7 +1721,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1716,7 +1741,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img_bin_corr)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1736,7 +1761,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1756,7 +1781,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = np.copy(img_bin_corr)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1773,7 +1798,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1791,7 +1816,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1809,7 +1834,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
img_out = scale_padd_image_for_ocr(img, input_height, input_width)
ret_x[batchcount, :,:,:] = img_out[:,:,:]
ret_y[batchcount, :] = vectorize_label(txt_inp)
ret_y[batchcount, :] = vectorize_label(txt_inp, char_to_num, padding_token, max_len)
batchcount+=1
@ -1823,7 +1848,7 @@ def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir
def return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug,
degrading, bin_deg, brightening, padding_white,adding_rgb_foreground, adding_rgb_background, binarization, image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes):
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, thetha_padd, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes, blur_k, white_padds):
aug_multip = 1
if augmentation:
@ -1870,7 +1895,7 @@ def return_multiplier_based_on_augmnentations(augmentation, color_padding_rotati
if channels_shuffling:
aug_multip = aug_multip + len(shuffle_indexes)
if blur_aug:
aug_multip = aug_multip + len(blurs)
aug_multip = aug_multip + len(blur_k)
if brightening:
aug_multip = aug_multip + len(brightness)
if padding_white: