mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training.train: fix data batching for OCR in 27f43c17
This commit is contained in:
parent
86b009bc31
commit
92fc2bd815
2 changed files with 63 additions and 65 deletions
|
|
@ -618,11 +618,6 @@ def run(_config,
|
||||||
padding_token = len(characters) + 5
|
padding_token = len(characters) + 5
|
||||||
# Mapping characters to integers.
|
# Mapping characters to integers.
|
||||||
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
|
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
|
||||||
|
|
||||||
# Mapping integers back to original characters.
|
|
||||||
##num_to_char = StringLookup(
|
|
||||||
##vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
|
|
||||||
##)
|
|
||||||
n_classes = len(char_to_num.get_vocabulary()) + 2
|
n_classes = len(char_to_num.get_vocabulary()) + 2
|
||||||
|
|
||||||
if continue_training:
|
if continue_training:
|
||||||
|
|
@ -649,21 +644,23 @@ def run(_config,
|
||||||
char_to_num=char_to_num,
|
char_to_num=char_to_num,
|
||||||
padding_token=padding_token
|
padding_token=padding_token
|
||||||
)
|
)
|
||||||
train_ds = tf.data.Dataset.from_generator(gen)
|
train_ds = (tf.data.Dataset.from_generator(gen, (tf.float32, tf.int64))
|
||||||
train_ds = train_ds.padded_batch(n_batch,
|
.padded_batch(n_batch,
|
||||||
padded_shapes=([input_height, input_width, 3], [None]),
|
padded_shapes=([input_height, input_width, 3], [None]),
|
||||||
padding_values=(0, padding_token),
|
padding_values=(None, tf.constant(padding_token, dtype=tf.int64)),
|
||||||
drop_remainder=True,
|
drop_remainder=True,
|
||||||
#num_parallel_calls=tf.data.AUTOTUNE,
|
#num_parallel_calls=tf.data.AUTOTUNE,
|
||||||
|
)
|
||||||
|
.map(lambda x, y: {"image": x, "label": y})
|
||||||
|
.prefetch(tf.data.AUTOTUNE)
|
||||||
)
|
)
|
||||||
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
|
|
||||||
|
|
||||||
#initial_learning_rate = 1e-4
|
#initial_learning_rate = 1e-4
|
||||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||||
#alpha = 0.01
|
#alpha = 0.01
|
||||||
#lr_schedule = 1e-4
|
#lr_schedule = 1e-4
|
||||||
#tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha)
|
#tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha)
|
||||||
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
|
opt = Adam(learning_rate=learning_rate)
|
||||||
model.compile(optimizer=opt) # rs: loss seems to be (ctc_batch_cost) in last layer
|
model.compile(optimizer=opt) # rs: loss seems to be (ctc_batch_cost) in last layer
|
||||||
|
|
||||||
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||||
|
|
|
||||||
|
|
@ -1073,58 +1073,59 @@ def preprocess_img(img,
|
||||||
scaler=sc_ind)
|
scaler=sc_ind)
|
||||||
|
|
||||||
def preprocess_img_ocr(
|
def preprocess_img_ocr(
|
||||||
img,
|
img,
|
||||||
img_name,
|
img_name,
|
||||||
lab,
|
lab,
|
||||||
char_to_num=None,
|
char_to_num=None,
|
||||||
padding_token=-1,
|
padding_token=-1,
|
||||||
max_len=500,
|
max_len=500,
|
||||||
n_batch=1,
|
n_batch=1,
|
||||||
input_height=None,
|
input_height=None,
|
||||||
input_width=None,
|
input_width=None,
|
||||||
augmentation=False,
|
augmentation=False,
|
||||||
color_padding_rotation=None,
|
color_padding_rotation=None,
|
||||||
thetha_padd=None,
|
thetha_padd=None,
|
||||||
padd_colors=None,
|
padd_colors=None,
|
||||||
rotation_not_90=None,
|
rotation_not_90=None,
|
||||||
thetha=None,
|
thetha=None,
|
||||||
padding_white=None,
|
padding_white=None,
|
||||||
white_padds=None,
|
white_padds=None,
|
||||||
degrading=False,
|
degrading=False,
|
||||||
bin_deg=None,
|
bin_deg=None,
|
||||||
degrade_scales=None,
|
degrade_scales=None,
|
||||||
blur_aug=False,
|
blur_aug=False,
|
||||||
blur_k=None,
|
blur_k=None,
|
||||||
brightening=False,
|
brightening=False,
|
||||||
brightness=None,
|
brightness=None,
|
||||||
binarization=False,
|
binarization=False,
|
||||||
image_inversion=False,
|
image_inversion=False,
|
||||||
channels_shuffling=False,
|
channels_shuffling=False,
|
||||||
shuffle_indexes=None,
|
shuffle_indexes=None,
|
||||||
white_noise_strap=False,
|
white_noise_strap=False,
|
||||||
textline_skewing=False,
|
textline_skewing=False,
|
||||||
textline_skewing_bin=False,
|
textline_skewing_bin=False,
|
||||||
skewing_amplitudes=None,
|
skewing_amplitudes=None,
|
||||||
textline_left_in_depth=False,
|
textline_left_in_depth=False,
|
||||||
textline_left_in_depth_bin=False,
|
textline_left_in_depth_bin=False,
|
||||||
textline_right_in_depth=False,
|
textline_right_in_depth=False,
|
||||||
textline_right_in_depth_bin=False,
|
textline_right_in_depth_bin=False,
|
||||||
textline_up_in_depth=False,
|
textline_up_in_depth=False,
|
||||||
textline_up_in_depth_bin=False,
|
textline_up_in_depth_bin=False,
|
||||||
textline_down_in_depth=False,
|
textline_down_in_depth=False,
|
||||||
textline_down_in_depth_bin=False,
|
textline_down_in_depth_bin=False,
|
||||||
pepper_aug=False,
|
pepper_aug=False,
|
||||||
pepper_bin_aug=False,
|
pepper_bin_aug=False,
|
||||||
pepper_indexes=None,
|
pepper_indexes=None,
|
||||||
dir_img_bin=None,
|
dir_img_bin=None,
|
||||||
add_red_textlines=False,
|
add_red_textlines=False,
|
||||||
adding_rgb_background=False,
|
adding_rgb_background=False,
|
||||||
dir_rgb_backgrounds=None,
|
dir_rgb_backgrounds=None,
|
||||||
adding_rgb_foreground=False,
|
adding_rgb_foreground=False,
|
||||||
dir_rgb_foregrounds=None,
|
dir_rgb_foregrounds=None,
|
||||||
number_of_backgrounds_per_image=None,
|
number_of_backgrounds_per_image=None,
|
||||||
list_all_possible_background_images=None,
|
list_all_possible_background_images=None,
|
||||||
list_all_possible_foreground_rgbs=None,
|
list_all_possible_foreground_rgbs=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
def scale_image(img):
|
def scale_image(img):
|
||||||
return scale_padd_image_for_ocr(img, input_height, input_width).astype(np.float32) / 255.
|
return scale_padd_image_for_ocr(img, input_height, input_width).astype(np.float32) / 255.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue