mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-12-15 15:44:13 +01:00
The CNN-RNN model can now be called; the model's input height and width are now dynamic; the data generator can also be called.
This commit is contained in:
parent
59e5a73654
commit
84a72a128b
4 changed files with 283 additions and 239 deletions
|
|
@ -13,6 +13,25 @@ resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_
|
||||||
IMAGE_ORDERING = 'channels_last'
|
IMAGE_ORDERING = 'channels_last'
|
||||||
MERGE_AXIS = -1
|
MERGE_AXIS = -1
|
||||||
|
|
||||||
|
|
||||||
|
class CTCLayer(tf.keras.layers.Layer):
    """Pass-through layer that registers a CTC loss on the model.

    The layer receives the ground-truth label batch and the prediction batch,
    computes the loss with ``tf.keras.backend.ctc_batch_cost``, attaches it via
    ``add_loss`` and returns the predictions unchanged.

    Args:
        name: Optional layer name.
        **kwargs: Any further keyword arguments are forwarded to
            ``tf.keras.layers.Layer`` (e.g. ``trainable``, ``dtype``).
            Accepting them keeps the constructor compatible with the keyword
            arguments Keras supplies when re-creating a layer from its config.
    """

    def __init__(self, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.loss_fn = tf.keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        """Add the CTC loss for this batch and return ``y_pred`` unchanged.

        Every sample is given the same lengths: the full size of axis 1 of
        ``y_pred`` as input length and the full size of axis 1 of ``y_true``
        as label length.
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        # Broadcast the scalar lengths to one (batch, 1) column per sample,
        # which is the form passed to the loss function below.
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions.
        return y_pred
|
||||||
|
|
||||||
def mlp(x, hidden_units, dropout_rate):
|
def mlp(x, hidden_units, dropout_rate):
|
||||||
for units in hidden_units:
|
for units in hidden_units:
|
||||||
x = layers.Dense(units, activation=tf.nn.gelu)(x)
|
x = layers.Dense(units, activation=tf.nn.gelu)(x)
|
||||||
|
|
@ -759,85 +778,84 @@ def machine_based_reading_order_model(n_classes,input_height=224,input_width=224
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def cnn_rnn_ocr_model(image_height=None, image_width=None, n_classes=None, max_seq=None):
    """Build the CNN + bidirectional-LSTM OCR model with an attached CTC loss.

    The network takes two inputs, ``"image"`` of shape
    ``(image_height, image_width, 3)`` and ``"label"`` of shape ``(None,)``.
    A convolutional stack feeds three sequence branches (the feature map, and
    the same map max-pooled once and twice along axis 2). Each branch goes
    through a bidirectional LSTM, the pooled branches are upsampled back and
    all three are summed. A further bidirectional LSTM, a 1x1 ``Conv1D`` with
    ``max_seq`` filters and a softmax ``Dense`` layer (``"dense2"``) produce
    the per-step class distribution, which is passed through ``CTCLayer``.

    Args:
        image_height: Height of the input image in pixels.
        image_width: Width of the input image in pixels. It is also used as
            the filter count of the last two convolutions and as the unit
            count of every LSTM.
        n_classes: Size of the softmax output.
        max_seq: Number of filters of the ``Conv1D`` layer applied with
            ``data_format="channels_first"``.

    Returns:
        A ``tf.keras.models.Model`` named ``"handwriting_recognizer"``.

    Raises:
        ValueError: If any argument is left at ``None``. All four are needed
            to construct the layers below; the ``None`` defaults exist only
            so the function can be called with keyword arguments.
    """
    # Fail fast with a readable message instead of an obscure error raised
    # from inside a layer constructor or from the ``shape`` arithmetic below.
    missing = [
        arg_name
        for arg_name, value in (
            ("image_height", image_height),
            ("image_width", image_width),
            ("n_classes", n_classes),
            ("max_seq", max_seq),
        )
        if value is None
    ]
    if missing:
        raise ValueError(
            "cnn_rnn_ocr_model requires explicit values for: " + ", ".join(missing)
        )

    input_img = tf.keras.Input(shape=(image_height, image_width, 3), name="image")
    labels = tf.keras.layers.Input(name="label", shape=(None,))

    x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), padding="same")(input_img)
    x = tf.keras.layers.BatchNormalization(name="bn1")(x)
    x = tf.keras.layers.Activation("relu", name="relu1")(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), padding="same")(x)
    x = tf.keras.layers.BatchNormalization(name="bn2")(x)
    x = tf.keras.layers.Activation("relu", name="relu2")(x)
    x = tf.keras.layers.MaxPool2D(pool_size=(1, 2), strides=(1, 2))(x)

    x = tf.keras.layers.Conv2D(128, kernel_size=(3, 3), padding="same")(x)
    x = tf.keras.layers.BatchNormalization(name="bn3")(x)
    x = tf.keras.layers.Activation("relu", name="relu3")(x)
    x = tf.keras.layers.Conv2D(128, kernel_size=(3, 3), padding="same")(x)
    x = tf.keras.layers.BatchNormalization(name="bn4")(x)
    x = tf.keras.layers.Activation("relu", name="relu4")(x)
    x = tf.keras.layers.MaxPool2D(pool_size=(1, 2), strides=(1, 2))(x)

    x = tf.keras.layers.Conv2D(256, kernel_size=(3, 3), padding="same")(x)
    x = tf.keras.layers.BatchNormalization(name="bn5")(x)
    x = tf.keras.layers.Activation("relu", name="relu5")(x)
    x = tf.keras.layers.Conv2D(256, kernel_size=(3, 3), padding="same")(x)
    x = tf.keras.layers.BatchNormalization(name="bn6")(x)
    x = tf.keras.layers.Activation("relu", name="relu6")(x)
    x = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)

    # NOTE(review): the filter count here and the LSTM unit count below are
    # tied to image_width (earlier revisions used a fixed 512) - confirm this
    # coupling is intended, since it changes the weight shapes with the width.
    x = tf.keras.layers.Conv2D(image_width, kernel_size=(3, 3), padding="same")(x)
    x = tf.keras.layers.BatchNormalization(name="bn7")(x)
    x = tf.keras.layers.Activation("relu", name="relu7")(x)
    # No padding argument is given for this (16, 1) convolution, so the layer's
    # default padding applies.
    x = tf.keras.layers.Conv2D(image_width, kernel_size=(16, 1))(x)
    x = tf.keras.layers.BatchNormalization(name="bn8")(x)
    x = tf.keras.layers.Activation("relu", name="relu8")(x)
    x2d = tf.keras.layers.MaxPool2D(pool_size=(1, 2), strides=(1, 2))(x)
    x4d = tf.keras.layers.MaxPool2D(pool_size=(1, 2), strides=(1, 2))(x2d)

    # Merge axes 1 and 2 into one sequence axis; axis 3 stays the feature axis.
    new_shape = (x.shape[1] * x.shape[2], x.shape[3])
    new_shape2 = (x2d.shape[1] * x2d.shape[2], x2d.shape[3])
    new_shape4 = (x4d.shape[1] * x4d.shape[2], x4d.shape[3])

    x = tf.keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x2d = tf.keras.layers.Reshape(target_shape=new_shape2, name="reshape2")(x2d)
    x4d = tf.keras.layers.Reshape(target_shape=new_shape4, name="reshape4")(x4d)

    xrnnorg = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25)
    )(x)
    xrnn2d = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25)
    )(x2d)
    xrnn4d = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25)
    )(x4d)

    # Add a leading axis of size 1 so UpSampling2D can stretch the sequence
    # axis of the pooled branches by 2 and 4 respectively.
    xrnn2d = tf.keras.layers.Reshape(
        target_shape=(1, xrnn2d.shape[1], xrnn2d.shape[2]), name="reshape6"
    )(xrnn2d)
    xrnn4d = tf.keras.layers.Reshape(
        target_shape=(1, xrnn4d.shape[1], xrnn4d.shape[2]), name="reshape8"
    )(xrnn4d)

    xrnn2dup = tf.keras.layers.UpSampling2D(size=(1, 2), interpolation="nearest")(xrnn2d)
    xrnn4dup = tf.keras.layers.UpSampling2D(size=(1, 4), interpolation="nearest")(xrnn4d)

    # Drop the helper axis again before summing the three branches.
    xrnn2dup = tf.keras.layers.Reshape(
        target_shape=(xrnn2dup.shape[2], xrnn2dup.shape[3]), name="reshape10"
    )(xrnn2dup)
    xrnn4dup = tf.keras.layers.Reshape(
        target_shape=(xrnn4dup.shape[2], xrnn4dup.shape[3]), name="reshape12"
    )(xrnn4dup)

    addition = tf.keras.layers.Add()([xrnnorg, xrnn2dup, xrnn4dup])

    addition_rnn = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(image_width, return_sequences=True, dropout=0.25)
    )(addition)

    out = tf.keras.layers.Conv1D(max_seq, 1, data_format="channels_first")(addition_rnn)
    out = tf.keras.layers.BatchNormalization(name="bn9")(out)
    out = tf.keras.layers.Activation("relu", name="relu9")(out)

    out = tf.keras.layers.Dense(
        n_classes, activation="softmax", name="dense2"
    )(out)

    # Add CTC layer for calculating CTC loss at each step.
    output = CTCLayer(name="ctc_loss")(labels, out)

    model = tf.keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="handwriting_recognizer"
    )

    return model
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,13 @@ from eynollah.training.models import (
|
||||||
resnet50_classifier,
|
resnet50_classifier,
|
||||||
resnet50_unet,
|
resnet50_unet,
|
||||||
vit_resnet50_unet,
|
vit_resnet50_unet,
|
||||||
vit_resnet50_unet_transformer_before_cnn
|
vit_resnet50_unet_transformer_before_cnn,
|
||||||
|
cnn_rnn_ocr_model
|
||||||
)
|
)
|
||||||
from eynollah.training.utils import (
|
from eynollah.training.utils import (
|
||||||
data_gen,
|
data_gen,
|
||||||
|
data_gen_ocr,
|
||||||
|
return_multiplier_based_on_augmnentations,
|
||||||
generate_arrays_from_folder_reading_order,
|
generate_arrays_from_folder_reading_order,
|
||||||
generate_data_from_folder_evaluation,
|
generate_data_from_folder_evaluation,
|
||||||
generate_data_from_folder_training,
|
generate_data_from_folder_training,
|
||||||
|
|
@ -36,6 +39,7 @@ from tensorflow.keras.models import load_model
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sklearn.metrics import f1_score
|
from sklearn.metrics import f1_score
|
||||||
from tensorflow.keras.callbacks import Callback
|
from tensorflow.keras.callbacks import Callback
|
||||||
|
from tensorflow.keras.layers import StringLookup
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
|
@ -62,6 +66,7 @@ class SaveWeightsAfterSteps(Callback):
|
||||||
print(f"saved model as steps {self.step_count} to {save_file}")
|
print(f"saved model as steps {self.step_count} to {save_file}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def configuration():
|
def configuration():
|
||||||
config = tf.compat.v1.ConfigProto()
|
config = tf.compat.v1.ConfigProto()
|
||||||
config.gpu_options.allow_growth = True
|
config.gpu_options.allow_growth = True
|
||||||
|
|
@ -89,6 +94,7 @@ def config_params():
|
||||||
input_width = 224 * 1 # Width of model's input in pixels.
|
input_width = 224 * 1 # Width of model's input in pixels.
|
||||||
weight_decay = 1e-6 # Weight decay of l2 regularization of model layers.
|
weight_decay = 1e-6 # Weight decay of l2 regularization of model layers.
|
||||||
n_batch = 1 # Number of batches at each iteration.
|
n_batch = 1 # Number of batches at each iteration.
|
||||||
|
max_len = None # max len for ocr output.
|
||||||
learning_rate = 1e-4 # Set the learning rate.
|
learning_rate = 1e-4 # Set the learning rate.
|
||||||
patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
|
patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
|
||||||
augmentation = False # To apply any kind of augmentation, this parameter must be set to true.
|
augmentation = False # To apply any kind of augmentation, this parameter must be set to true.
|
||||||
|
|
@ -132,7 +138,9 @@ def config_params():
|
||||||
scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
|
scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
|
||||||
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
|
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
|
||||||
thetha = None # Rotate image by these angles for augmentation.
|
thetha = None # Rotate image by these angles for augmentation.
|
||||||
shuffle_indexes = None
|
shuffle_indexes = None # List of shuffling indexes like [[0,2,1], [1,2,0], [1,0,2]]
|
||||||
|
pepper_indexes = None # List of pepper noise indexes like [0.01, 0.005]
|
||||||
|
skewing_amplitudes = None # List of skewing augmentation amplitudes like [5, 8]
|
||||||
blur_k = None # Blur image for augmentation.
|
blur_k = None # Blur image for augmentation.
|
||||||
scales = None # Scale patches for augmentation.
|
scales = None # Scale patches for augmentation.
|
||||||
padd_colors = None # padding colors. A list elements can be only white and black. like ["white", "black"] or only one of them ["white"]
|
padd_colors = None # padding colors. A list elements can be only white and black. like ["white", "black"] or only one of them ["white"]
|
||||||
|
|
@ -180,7 +188,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds,
|
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds,
|
||||||
dir_rgb_foregrounds, characters_txt_file, color_padding_rotation, bin_deg, image_inversion, white_noise_strap, textline_skewing, textline_skewing_bin,
|
dir_rgb_foregrounds, characters_txt_file, color_padding_rotation, bin_deg, image_inversion, white_noise_strap, textline_skewing, textline_skewing_bin,
|
||||||
textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin,
|
textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin,
|
||||||
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors):
|
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors, pepper_indexes, skewing_amplitudes, max_len):
|
||||||
|
|
||||||
if dir_rgb_backgrounds:
|
if dir_rgb_backgrounds:
|
||||||
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
|
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
|
||||||
|
|
@ -409,20 +417,35 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
|
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
|
||||||
|
|
||||||
# Mapping integers back to original characters.
|
# Mapping integers back to original characters.
|
||||||
num_to_char = StringLookup(
|
##num_to_char = StringLookup(
|
||||||
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
|
##vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
|
||||||
)
|
##)
|
||||||
|
|
||||||
padding_token = len(characters) + 5
|
padding_token = len(characters) + 5
|
||||||
ls_files_images = os.listdir(dir_img)
|
ls_files_images = os.listdir(dir_img)
|
||||||
|
|
||||||
train_ds = data_gen_ocr(padding_token, batchsize=n_batch, height=input_height, width=input_width, max_len=max_len, dir_ins=dir_train, ls_files_images,
|
n_classes = len(char_to_num.get_vocabulary()) + 2
|
||||||
augmentation, color_padding_rotation, rotation=rotation_not_90, bluring_aug=blurring, degrading, bin_deg, brightening, w_padding=padding_white,
|
|
||||||
rgb_fging=adding_rgb_foreground, rgb_bkding=adding_rgb_background, binarization, image_inversion, channel_shuffling=channels_shuffling, add_red_textline=add_red_textlines, white_noise_strap,
|
|
||||||
|
model = cnn_rnn_ocr_model(image_height=input_height, image_width=input_width, n_classes=n_classes, max_seq=max_len)
|
||||||
|
|
||||||
|
print(model.summary())
|
||||||
|
|
||||||
|
aug_multip = return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg,
|
||||||
|
brightening, padding_white, adding_rgb_foreground, adding_rgb_background, binarization,
|
||||||
|
image_inversion, channels_shuffling, add_red_textlines, white_noise_strap, textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes)
|
||||||
|
|
||||||
|
len_dataset = aug_multip*len(ls_files_images)
|
||||||
|
|
||||||
|
train_ds = data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir_train, ls_files_images,
|
||||||
|
augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg, brightening, padding_white,
|
||||||
|
adding_rgb_foreground, adding_rgb_background, binarization, image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
|
||||||
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth,
|
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth,
|
||||||
textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin,
|
textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin,
|
||||||
pepper_bin_aug, pepper_aug, deg_scales=degrade_scales, number_of_backgrounds_per_image, thethas=thetha, brightness, padd_colors,
|
pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors,
|
||||||
shuffle_indexes, pepper_indexes, skewing_amplitudes)
|
shuffle_indexes, pepper_indexes, skewing_amplitudes, dir_img_bin)
|
||||||
|
|
||||||
|
print(len_dataset, 'len_dataset')
|
||||||
|
|
||||||
elif task=='classification':
|
elif task=='classification':
|
||||||
configuration()
|
configuration()
|
||||||
|
|
|
||||||
|
|
@ -414,7 +414,7 @@ def generate_data_from_folder_evaluation(path_classes, height, width, n_classes,
|
||||||
|
|
||||||
return ret_x/255., ret_y
|
return ret_x/255., ret_y
|
||||||
|
|
||||||
def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes, list_classes):
|
def generate_data_from_folder_training(path_classes, n_batch, height, width, n_classes, list_classes):
|
||||||
#sub_classes = os.listdir(path_classes)
|
#sub_classes = os.listdir(path_classes)
|
||||||
#n_classes = len(sub_classes)
|
#n_classes = len(sub_classes)
|
||||||
|
|
||||||
|
|
@ -440,8 +440,8 @@ def generate_data_from_folder_training(path_classes, batchsize, height, width, n
|
||||||
shuffled_labels = np.array(labels)[ids]
|
shuffled_labels = np.array(labels)[ids]
|
||||||
shuffled_files = np.array(all_imgs)[ids]
|
shuffled_files = np.array(all_imgs)[ids]
|
||||||
categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ]
|
categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ]
|
||||||
ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16)
|
ret_x= np.zeros((n_batch, height,width, 3)).astype(np.int16)
|
||||||
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
|
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
while True:
|
while True:
|
||||||
for i in range(len(shuffled_files)):
|
for i in range(len(shuffled_files)):
|
||||||
|
|
@ -465,11 +465,11 @@ def generate_data_from_folder_training(path_classes, batchsize, height, width, n
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield ret_x, ret_y
|
yield ret_x, ret_y
|
||||||
ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16)
|
ret_x= np.zeros((n_batch, height,width, 3)).astype(np.int16)
|
||||||
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
|
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
def do_brightening(img_in_dir, factor):
|
def do_brightening(img_in_dir, factor):
|
||||||
|
|
@ -634,10 +634,10 @@ def IoU(Yi, y_predi):
|
||||||
#print("Mean IoU: {:4.3f}".format(mIoU))
|
#print("Mean IoU: {:4.3f}".format(mIoU))
|
||||||
return mIoU
|
return mIoU
|
||||||
|
|
||||||
def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batchsize, height, width, n_classes, thetha, augmentation=False):
|
def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, n_batch, height, width, n_classes, thetha, augmentation=False):
|
||||||
all_labels_files = os.listdir(classes_file_dir)
|
all_labels_files = os.listdir(classes_file_dir)
|
||||||
ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16)
|
ret_x= np.zeros((n_batch, height, width, 3))#.astype(np.int16)
|
||||||
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
|
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
while True:
|
while True:
|
||||||
for i in all_labels_files:
|
for i in all_labels_files:
|
||||||
|
|
@ -652,10 +652,10 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batch
|
||||||
|
|
||||||
ret_y[batchcount, :] = label_class
|
ret_y[batchcount, :] = label_class
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
yield ret_x, ret_y
|
yield ret_x, ret_y
|
||||||
ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16)
|
ret_x= np.zeros((n_batch, height, width, 3))#.astype(np.int16)
|
||||||
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
|
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if augmentation:
|
if augmentation:
|
||||||
|
|
@ -670,10 +670,10 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, batch
|
||||||
|
|
||||||
ret_y[batchcount, :] = label_class
|
ret_y[batchcount, :] = label_class
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
yield ret_x, ret_y
|
yield ret_x, ret_y
|
||||||
ret_x= np.zeros((batchsize, height, width, 3))#.astype(np.int16)
|
ret_x= np.zeros((n_batch, height, width, 3))#.astype(np.int16)
|
||||||
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
|
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
|
def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
|
||||||
|
|
@ -1264,42 +1264,45 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len=None, dir_ins=None, ls_files_images=None,
|
def data_gen_ocr(padding_token, n_batch, input_height, input_width, max_len, dir_train, ls_files_images,
|
||||||
augmentation=False, color_padding_rotation=False, rotation=False, bluring_aug=False, degrading=False, bin_deg=False, brightening=False, w_padding=False,
|
augmentation, color_padding_rotation, rotation_not_90, blur_aug, degrading, bin_deg, brightening, padding_white,
|
||||||
rgb_fging=False, rgb_bkding=False, binarization=False, image_inversion=False, channel_shuffling=False, add_red_textline=False, white_noise_strap=False,
|
adding_rgb_foreground, adding_rgb_background, binarization, image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
|
||||||
textline_skewing=False, textline_skewing_bin=False, textline_left_in_depth=False, textline_left_in_depth_bin=False, textline_right_in_depth=False,
|
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth,
|
||||||
textline_right_in_depth_bin=False, textline_up_in_depth=False, textline_up_in_depth_bin=False, textline_down_in_depth=False, textline_down_in_depth_bin=False,
|
textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin,
|
||||||
pepper_bin_aug=False, pepper_aug=False, deg_scales=None, number_of_backgrounds_per_image=None, thethas=None, brightness=None, padd_colors=None,
|
pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors,
|
||||||
shuffle_indexes=None, ):
|
shuffle_indexes, pepper_indexes, skewing_amplitudes, dir_img_bin=None):
|
||||||
|
|
||||||
random.shuffle(ls_files_images)
|
random.shuffle(ls_files_images)
|
||||||
|
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
while True:
|
while True:
|
||||||
for i in ls_files_images:
|
for i in ls_files_images:
|
||||||
f_name = i.split('.')[0]
|
f_name = i.split('.')[0]
|
||||||
|
|
||||||
txt_inp = open(os.path.join(dir_ins, "labels/"+f_name+'.txt'),'r').read().split('\n')[0]
|
txt_inp = open(os.path.join(dir_train, "labels/"+f_name+'.txt'),'r').read().split('\n')[0]
|
||||||
|
|
||||||
img = cv2.imread(os.path.join(dir_ins, "images/"+i) )
|
img = cv2.imread(os.path.join(dir_train, "images/"+i) )
|
||||||
img_bin_corr = cv2.imread(os.path.join(dir_ins, "images_bin/"+f_name+'.png') )
|
if dir_img_bin:
|
||||||
|
img_bin_corr = cv2.imread(os.path.join(dir_img_bin, f_name+'.png') )
|
||||||
|
else:
|
||||||
|
img_bin_corr = None
|
||||||
|
|
||||||
|
|
||||||
if augmentation:
|
if augmentation:
|
||||||
img_out = scale_padd_image_for_ocr(img, height, width)
|
img_out = scale_padd_image_for_ocr(img, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if color_padding_rotation:
|
if color_padding_rotation:
|
||||||
|
|
@ -1307,59 +1310,59 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
for padd_col in padd_colors:
|
for padd_col in padd_colors:
|
||||||
img_out = rotation_not_90_func(do_padding(img, 1.2, padd_col), thetha)
|
img_out = rotation_not_90_func(do_padding(img, 1.2, padd_col), thetha)
|
||||||
|
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if rotation:
|
if rotation_not_90:
|
||||||
for index, thetha in enumerate(thethas):
|
for index, thetha in enumerate(thetha):
|
||||||
img_out = rotation_not_90_func(img, thetha)
|
img_out = rotation_not_90_func(img, thetha)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if bluring_aug:
|
if blur_aug:
|
||||||
for index, blur_type in enumerate(blurs):
|
for index, blur_type in enumerate(blurs):
|
||||||
img_out = bluring(img, blur_type)
|
img_out = bluring(img, blur_type)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if degrading:
|
if degrading:
|
||||||
for index, deg_scale_ind in enumerate(deg_scales):
|
for index, deg_scale_ind in enumerate(degrade_scales):
|
||||||
try:
|
try:
|
||||||
img_out = do_degrading(img, deg_scale_ind)
|
img_out = do_degrading(img, deg_scale_ind)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img)
|
img_out = np.copy(img)
|
||||||
|
|
||||||
|
|
@ -1368,32 +1371,32 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if bin_deg:
|
if bin_deg:
|
||||||
for index, deg_scale_ind in enumerate(deg_scales):
|
for index, deg_scale_ind in enumerate(degrade_scales):
|
||||||
try:
|
try:
|
||||||
img_out = do_degrading(img_bin_corr, deg_scale_ind)
|
img_out = do_degrading(img_bin_corr, deg_scale_ind)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img_bin_corr)
|
img_out = np.copy(img_bin_corr)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1403,39 +1406,39 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
img_out = do_brightening(dir_img, bright_scale_ind)
|
img_out = do_brightening(dir_img, bright_scale_ind)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img)
|
img_out = np.copy(img)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if w_padding:
|
if padding_white:
|
||||||
for index, padding_size in enumerate(white_padds):
|
for index, padding_size in enumerate(white_padds):
|
||||||
for padd_col in padd_colors:
|
for padd_col in padd_colors:
|
||||||
img_out = do_padding(img, padding_size, padd_col)
|
img_out = do_padding(img, padding_size, padd_col)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if rgb_fging:
|
if adding_rgb_foreground:
|
||||||
for i_n in range(number_of_backgrounds_per_image):
|
for i_n in range(number_of_backgrounds_per_image):
|
||||||
background_image_chosen_name = random.choice(list_all_possible_background_images)
|
background_image_chosen_name = random.choice(list_all_possible_background_images)
|
||||||
foreground_rgb_chosen_name = random.choice(list_all_possible_foreground_rgbs)
|
foreground_rgb_chosen_name = random.choice(list_all_possible_foreground_rgbs)
|
||||||
|
|
@ -1445,130 +1448,130 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
img_with_overlayed_background = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
|
img_with_overlayed_background = return_binary_image_with_given_rgb_background_and_given_foreground_rgb(img_bin_corr, img_rgb_background_chosen, foreground_rgb_chosen)
|
||||||
|
|
||||||
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, height, width)
|
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if rgb_bkding:
|
if adding_rgb_background:
|
||||||
for i_n in range(number_of_backgrounds_per_image):
|
for i_n in range(number_of_backgrounds_per_image):
|
||||||
background_image_chosen_name = random.choice(list_all_possible_background_images)
|
background_image_chosen_name = random.choice(list_all_possible_background_images)
|
||||||
img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name)
|
img_rgb_background_chosen = cv2.imread(dir_rgb_backgrounds + '/' + background_image_chosen_name)
|
||||||
img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
|
img_with_overlayed_background = return_binary_image_with_given_rgb_background(img_bin_corr, img_rgb_background_chosen)
|
||||||
|
|
||||||
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, height, width)
|
img_out = scale_padd_image_for_ocr(img_with_overlayed_background, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if binarization:
|
if binarization:
|
||||||
img_out = scale_padd_image_for_ocr(img_bin_corr, height, width)
|
img_out = scale_padd_image_for_ocr(img_bin_corr, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if image_inversion:
|
if image_inversion:
|
||||||
img_out = invert_image(img_bin_corr)
|
img_out = invert_image(img_bin_corr)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :, :, :] = img_out[:,:,:]
|
ret_x[batchcount, :, :, :] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x = np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x = np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y = np.zeros((batch_size, max_len)).astype(np.int16)+padding_token
|
ret_y = np.zeros((batch_size, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if channel_shuffling:
|
if channels_shuffling:
|
||||||
for shuffle_index in shuffle_indexes:
|
for shuffle_index in shuffle_indexes:
|
||||||
img_out = return_shuffled_channels(img, shuffle_index)
|
img_out = return_shuffled_channels(img, shuffle_index)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if add_red_textline:
|
if add_red_textlines:
|
||||||
img_red_context = return_image_with_red_elements(img, img_bin_corr)
|
img_red_context = return_image_with_red_elements(img, img_bin_corr)
|
||||||
|
|
||||||
img_out = scale_padd_image_for_ocr(img_red_context, height, width)
|
img_out = scale_padd_image_for_ocr(img_red_context, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if white_noise_strap:
|
if white_noise_strap:
|
||||||
img_out = return_image_with_strapped_white_noises(img)
|
img_out = return_image_with_strapped_white_noises(img)
|
||||||
|
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if textline_skewing:
|
if textline_skewing:
|
||||||
for index, des_scale_ind in enumerate(skewing_amplitudes):
|
for index, des_scale_ind in enumerate(skewing_amplitudes):
|
||||||
try:
|
try:
|
||||||
img_out = do_deskewing(img, des_scale_ind)
|
img_out = do_deskewing(img, des_scale_ind)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img)
|
img_out = np.copy(img)
|
||||||
|
|
||||||
|
|
@ -1577,18 +1580,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if textline_skewing_bin:
|
if textline_skewing_bin:
|
||||||
for index, des_scale_ind in enumerate(skewing_amplitudes):
|
for index, des_scale_ind in enumerate(skewing_amplitudes):
|
||||||
try:
|
try:
|
||||||
img_out = do_deskewing(img_bin_corr, des_scale_ind)
|
img_out = do_deskewing(img_bin_corr, des_scale_ind)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img_bin_corr)
|
img_out = np.copy(img_bin_corr)
|
||||||
|
|
||||||
|
|
@ -1597,18 +1600,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if textline_left_in_depth:
|
if textline_left_in_depth:
|
||||||
try:
|
try:
|
||||||
img_out = do_left_in_depth(img)
|
img_out = do_left_in_depth(img)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img)
|
img_out = np.copy(img)
|
||||||
|
|
||||||
|
|
@ -1617,18 +1620,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if textline_left_in_depth_bin:
|
if textline_left_in_depth_bin:
|
||||||
try:
|
try:
|
||||||
img_out = do_left_in_depth(img_bin_corr)
|
img_out = do_left_in_depth(img_bin_corr)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img_bin_corr)
|
img_out = np.copy(img_bin_corr)
|
||||||
|
|
||||||
|
|
@ -1637,18 +1640,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if textline_right_in_depth:
|
if textline_right_in_depth:
|
||||||
try:
|
try:
|
||||||
img_out = do_right_in_depth(img)
|
img_out = do_right_in_depth(img)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img)
|
img_out = np.copy(img)
|
||||||
|
|
||||||
|
|
@ -1657,18 +1660,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if textline_right_in_depth_bin:
|
if textline_right_in_depth_bin:
|
||||||
try:
|
try:
|
||||||
img_out = do_right_in_depth(img_bin_corr)
|
img_out = do_right_in_depth(img_bin_corr)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img_bin_corr)
|
img_out = np.copy(img_bin_corr)
|
||||||
|
|
||||||
|
|
@ -1677,18 +1680,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if textline_up_in_depth:
|
if textline_up_in_depth:
|
||||||
try:
|
try:
|
||||||
img_out = do_up_in_depth(img)
|
img_out = do_up_in_depth(img)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img)
|
img_out = np.copy(img)
|
||||||
|
|
||||||
|
|
@ -1697,18 +1700,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if textline_up_in_depth_bin:
|
if textline_up_in_depth_bin:
|
||||||
try:
|
try:
|
||||||
img_out = do_up_in_depth(img_bin_corr)
|
img_out = do_up_in_depth(img_bin_corr)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img_bin_corr)
|
img_out = np.copy(img_bin_corr)
|
||||||
|
|
||||||
|
|
@ -1717,18 +1720,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if textline_down_in_depth:
|
if textline_down_in_depth:
|
||||||
try:
|
try:
|
||||||
img_out = do_down_in_depth(img)
|
img_out = do_down_in_depth(img)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img)
|
img_out = np.copy(img)
|
||||||
|
|
||||||
|
|
@ -1737,18 +1740,18 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if textline_down_in_depth_bin:
|
if textline_down_in_depth_bin:
|
||||||
try:
|
try:
|
||||||
img_out = do_down_in_depth(img_bin_corr)
|
img_out = do_down_in_depth(img_bin_corr)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
except:
|
except:
|
||||||
img_out = np.copy(img_bin_corr)
|
img_out = np.copy(img_bin_corr)
|
||||||
|
|
||||||
|
|
@ -1757,70 +1760,70 @@ def data_gen_ocr(padding_token, batchsize=None, height=None, width=None, max_len
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
if pepper_bin_aug:
|
if pepper_bin_aug:
|
||||||
for index, pepper_ind in enumerate(pepper_indexes):
|
for index, pepper_ind in enumerate(pepper_indexes):
|
||||||
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
|
img_out = add_salt_and_pepper_noise(img_bin_corr, pepper_ind, pepper_ind)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
if pepper_aug:
|
if pepper_aug:
|
||||||
for index, pepper_ind in enumerate(pepper_indexes):
|
for index, pepper_ind in enumerate(pepper_indexes):
|
||||||
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
|
img_out = add_salt_and_pepper_noise(img, pepper_ind, pepper_ind)
|
||||||
img_out = scale_padd_image_for_ocr(img_out, height, width)
|
img_out = scale_padd_image_for_ocr(img_out, input_height, input_width)
|
||||||
|
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
img_out = scale_padd_image_for_ocr(img, height, width)
|
img_out = scale_padd_image_for_ocr(img, input_height, input_width)
|
||||||
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
ret_x[batchcount, :,:,:] = img_out[:,:,:]
|
||||||
|
|
||||||
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
ret_y[batchcount, :] = vectorize_label(txt_inp)
|
||||||
|
|
||||||
batchcount+=1
|
batchcount+=1
|
||||||
|
|
||||||
if batchcount>=batchsize:
|
if batchcount>=n_batch:
|
||||||
ret_x = ret_x/255.
|
ret_x = ret_x/255.
|
||||||
yield {"image": ret_x, "label": ret_y}
|
yield {"image": ret_x, "label": ret_y}
|
||||||
ret_x= np.zeros((batchsize, height, width, 3)).astype(np.float32)
|
ret_x= np.zeros((n_batch, input_height, input_width, 3)).astype(np.float32)
|
||||||
ret_y= np.zeros((batchsize, max_len)).astype(np.int16)+padding_token
|
ret_y= np.zeros((n_batch, max_len)).astype(np.int16)+padding_token
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
|
|
||||||
def return_muliplier_based_on_augmnentations(augmentation=False, color_padding_rotation=False, rotation=False, bluring_aug=False,
|
def return_multiplier_based_on_augmnentations(augmentation, color_padding_rotation, rotation_not_90, blur_aug,
|
||||||
degrading=False, bin_deg=False, brightening=False, w_padding=False,rgb_fging=False, rgb_bkding=False, binarization=False, image_inversion=False, channel_shuffling=False, add_red_textline=False, white_noise_strap=False,
|
degrading, bin_deg, brightening, padding_white,adding_rgb_foreground, adding_rgb_background, binarization, image_inversion, channels_shuffling, add_red_textlines, white_noise_strap,
|
||||||
textline_skewing=False, textline_skewing_bin=False, textline_left_in_depth=False, textline_left_in_depth_bin=False, textline_right_in_depth=False, textline_right_in_depth_bin=False, textline_up_in_depth=False, textline_up_in_depth_bin=False, textline_down_in_depth=False, textline_down_in_depth_bin=False, pepper_bin_aug=False, pepper_aug=False, deg_scales=None, number_of_backgrounds_per_image=None, thethas=None, brightness=None, padd_colors=None):
|
textline_skewing, textline_skewing_bin, textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, degrade_scales, number_of_backgrounds_per_image, thetha, brightness, padd_colors, shuffle_indexes, pepper_indexes, skewing_amplitudes):
|
||||||
aug_multip = 1
|
aug_multip = 1
|
||||||
|
|
||||||
if augmentation:
|
if augmentation:
|
||||||
|
|
@ -1828,7 +1831,7 @@ def return_muliplier_based_on_augmnentations(augmentation=False, color_padding_r
|
||||||
aug_multip = aug_multip + 1
|
aug_multip = aug_multip + 1
|
||||||
if image_inversion:
|
if image_inversion:
|
||||||
aug_multip = aug_multip + 1
|
aug_multip = aug_multip + 1
|
||||||
if add_red_textline:
|
if add_red_textlines:
|
||||||
aug_multip = aug_multip + 1
|
aug_multip = aug_multip + 1
|
||||||
if white_noise_strap:
|
if white_noise_strap:
|
||||||
aug_multip = aug_multip + 1
|
aug_multip = aug_multip + 1
|
||||||
|
|
@ -1848,29 +1851,29 @@ def return_muliplier_based_on_augmnentations(augmentation=False, color_padding_r
|
||||||
aug_multip = aug_multip + 1
|
aug_multip = aug_multip + 1
|
||||||
if textline_down_in_depth_bin:
|
if textline_down_in_depth_bin:
|
||||||
aug_multip = aug_multip + 1
|
aug_multip = aug_multip + 1
|
||||||
if rgb_fging:
|
if adding_rgb_foreground:
|
||||||
aug_multip = aug_multip + number_of_backgrounds_per_image
|
aug_multip = aug_multip + number_of_backgrounds_per_image
|
||||||
if rgb_bkding:
|
if adding_rgb_background:
|
||||||
aug_multip = aug_multip + number_of_backgrounds_per_image
|
aug_multip = aug_multip + number_of_backgrounds_per_image
|
||||||
if bin_deg:
|
if bin_deg:
|
||||||
aug_multip = aug_multip + len(deg_scales)
|
aug_multip = aug_multip + len(degrade_scales)
|
||||||
if degrading:
|
if degrading:
|
||||||
aug_multip = aug_multip + len(deg_scales)
|
aug_multip = aug_multip + len(degrade_scales)
|
||||||
if rotation:
|
if rotation_not_90:
|
||||||
aug_multip = aug_multip + len(thethas)
|
aug_multip = aug_multip + len(thetha)
|
||||||
if textline_skewing:
|
if textline_skewing:
|
||||||
aug_multip = aug_multip + len(skewing_amplitudes)
|
aug_multip = aug_multip + len(skewing_amplitudes)
|
||||||
if textline_skewing_bin:
|
if textline_skewing_bin:
|
||||||
aug_multip = aug_multip + len(skewing_amplitudes)
|
aug_multip = aug_multip + len(skewing_amplitudes)
|
||||||
if color_padding_rotation:
|
if color_padding_rotation:
|
||||||
aug_multip = aug_multip + len(thetha_padd)*len(padd_colors)
|
aug_multip = aug_multip + len(thetha_padd)*len(padd_colors)
|
||||||
if channel_shuffling:
|
if channels_shuffling:
|
||||||
aug_multip = aug_multip + len(shuffle_indexes)
|
aug_multip = aug_multip + len(shuffle_indexes)
|
||||||
if bluring_aug:
|
if blur_aug:
|
||||||
aug_multip = aug_multip + len(blurs)
|
aug_multip = aug_multip + len(blurs)
|
||||||
if brightening:
|
if brightening:
|
||||||
aug_multip = aug_multip + len(brightness)
|
aug_multip = aug_multip + len(brightness)
|
||||||
if w_padding:
|
if padding_white:
|
||||||
aug_multip = aug_multip + len(white_padds)*len(padd_colors)
|
aug_multip = aug_multip + len(white_padds)*len(padd_colors)
|
||||||
if pepper_aug:
|
if pepper_aug:
|
||||||
aug_multip = aug_multip + len(pepper_indexes)
|
aug_multip = aug_multip + len(pepper_indexes)
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
||||||
{
|
{
|
||||||
"backbone_type" : "transformer",
|
"backbone_type" : "transformer",
|
||||||
"task": "segmentation",
|
"task": "cnn-rnn-ocr",
|
||||||
"n_classes" : 2,
|
"n_classes" : 2,
|
||||||
"max_len": 280,
|
"max_len": 280,
|
||||||
"n_epochs" : 0,
|
"n_epochs" : 0,
|
||||||
"input_height" : 448,
|
"input_height" : 32,
|
||||||
"input_width" : 448,
|
"input_width" : 512,
|
||||||
"weight_decay" : 1e-6,
|
"weight_decay" : 1e-6,
|
||||||
"n_batch" : 1,
|
"n_batch" : 1,
|
||||||
"learning_rate": 1e-4,
|
"learning_rate": 1e-4,
|
||||||
"patches" : false,
|
"patches" : false,
|
||||||
"pretraining" : true,
|
"pretraining" : true,
|
||||||
"augmentation" : true,
|
"augmentation" : false,
|
||||||
"flip_aug" : false,
|
"flip_aug" : false,
|
||||||
"blur_aug" : false,
|
"blur_aug" : false,
|
||||||
"scaling" : false,
|
"scaling" : false,
|
||||||
|
|
@ -49,12 +49,12 @@
|
||||||
"weighted_loss": false,
|
"weighted_loss": false,
|
||||||
"is_loss_soft_dice": false,
|
"is_loss_soft_dice": false,
|
||||||
"data_is_provided": false,
|
"data_is_provided": false,
|
||||||
"dir_train": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new",
|
"dir_train": "/home/vahid/extracted_lines/1919_bin/train",
|
||||||
"dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new",
|
"dir_eval": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/eval_new",
|
||||||
"dir_output": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/output_new",
|
"dir_output": "/home/vahid/extracted_lines/1919_bin/output",
|
||||||
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
|
"dir_rgb_backgrounds": "/home/vahid/Documents/1_2_test_eynollah/set_rgb_background",
|
||||||
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
|
"dir_rgb_foregrounds": "/home/vahid/Documents/1_2_test_eynollah/out_set_rgb_foreground",
|
||||||
"dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin",
|
"dir_img_bin": "/home/vahid/Documents/test/sbb_pixelwise_segmentation/test_label/pageextractor_test/train_new/images_bin",
|
||||||
"characters_txt_file":"dir_of_characters_txt_file_for_ocr"
|
"characters_txt_file":"/home/vahid/Downloads/models_eynollah/model_eynollah_ocr_cnnrnn_20250930/characters_org.txt"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue