training.train: fix data batching for OCR in 27f43c17

This commit is contained in:
Robert Sachunsky 2026-02-24 20:42:08 +01:00
parent 86b009bc31
commit 92fc2bd815
2 changed files with 63 additions and 65 deletions

View file

@@ -618,11 +618,6 @@ def run(_config,
padding_token = len(characters) + 5 padding_token = len(characters) + 5
# Mapping characters to integers. # Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
# Mapping integers back to original characters.
##num_to_char = StringLookup(
##vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
##)
n_classes = len(char_to_num.get_vocabulary()) + 2 n_classes = len(char_to_num.get_vocabulary()) + 2
if continue_training: if continue_training:
@@ -649,21 +644,23 @@ def run(_config,
char_to_num=char_to_num, char_to_num=char_to_num,
padding_token=padding_token padding_token=padding_token
) )
train_ds = tf.data.Dataset.from_generator(gen) train_ds = (tf.data.Dataset.from_generator(gen, (tf.float32, tf.int64))
train_ds = train_ds.padded_batch(n_batch, .padded_batch(n_batch,
padded_shapes=([input_height, input_width, 3], [None]), padded_shapes=([input_height, input_width, 3], [None]),
padding_values=(0, padding_token), padding_values=(None, tf.constant(padding_token, dtype=tf.int64)),
drop_remainder=True, drop_remainder=True,
#num_parallel_calls=tf.data.AUTOTUNE, #num_parallel_calls=tf.data.AUTOTUNE,
)
.map(lambda x, y: {"image": x, "label": y})
.prefetch(tf.data.AUTOTUNE)
) )
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
#initial_learning_rate = 1e-4 #initial_learning_rate = 1e-4
#decay_steps = int (n_epochs * ( len_dataset / n_batch )) #decay_steps = int (n_epochs * ( len_dataset / n_batch ))
#alpha = 0.01 #alpha = 0.01
#lr_schedule = 1e-4 #lr_schedule = 1e-4
#tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha) #tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha)
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate) opt = Adam(learning_rate=learning_rate)
model.compile(optimizer=opt) # rs: loss seems to be (ctc_batch_cost) in last layer model.compile(optimizer=opt) # rs: loss seems to be (ctc_batch_cost) in last layer
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False), callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),

View file

@@ -1073,58 +1073,59 @@ def preprocess_img(img,
scaler=sc_ind) scaler=sc_ind)
def preprocess_img_ocr( def preprocess_img_ocr(
img, img,
img_name, img_name,
lab, lab,
char_to_num=None, char_to_num=None,
padding_token=-1, padding_token=-1,
max_len=500, max_len=500,
n_batch=1, n_batch=1,
input_height=None, input_height=None,
input_width=None, input_width=None,
augmentation=False, augmentation=False,
color_padding_rotation=None, color_padding_rotation=None,
thetha_padd=None, thetha_padd=None,
padd_colors=None, padd_colors=None,
rotation_not_90=None, rotation_not_90=None,
thetha=None, thetha=None,
padding_white=None, padding_white=None,
white_padds=None, white_padds=None,
degrading=False, degrading=False,
bin_deg=None, bin_deg=None,
degrade_scales=None, degrade_scales=None,
blur_aug=False, blur_aug=False,
blur_k=None, blur_k=None,
brightening=False, brightening=False,
brightness=None, brightness=None,
binarization=False, binarization=False,
image_inversion=False, image_inversion=False,
channels_shuffling=False, channels_shuffling=False,
shuffle_indexes=None, shuffle_indexes=None,
white_noise_strap=False, white_noise_strap=False,
textline_skewing=False, textline_skewing=False,
textline_skewing_bin=False, textline_skewing_bin=False,
skewing_amplitudes=None, skewing_amplitudes=None,
textline_left_in_depth=False, textline_left_in_depth=False,
textline_left_in_depth_bin=False, textline_left_in_depth_bin=False,
textline_right_in_depth=False, textline_right_in_depth=False,
textline_right_in_depth_bin=False, textline_right_in_depth_bin=False,
textline_up_in_depth=False, textline_up_in_depth=False,
textline_up_in_depth_bin=False, textline_up_in_depth_bin=False,
textline_down_in_depth=False, textline_down_in_depth=False,
textline_down_in_depth_bin=False, textline_down_in_depth_bin=False,
pepper_aug=False, pepper_aug=False,
pepper_bin_aug=False, pepper_bin_aug=False,
pepper_indexes=None, pepper_indexes=None,
dir_img_bin=None, dir_img_bin=None,
add_red_textlines=False, add_red_textlines=False,
adding_rgb_background=False, adding_rgb_background=False,
dir_rgb_backgrounds=None, dir_rgb_backgrounds=None,
adding_rgb_foreground=False, adding_rgb_foreground=False,
dir_rgb_foregrounds=None, dir_rgb_foregrounds=None,
number_of_backgrounds_per_image=None, number_of_backgrounds_per_image=None,
list_all_possible_background_images=None, list_all_possible_background_images=None,
list_all_possible_foreground_rgbs=None, list_all_possible_foreground_rgbs=None,
**kwargs
): ):
def scale_image(img): def scale_image(img):
return scale_padd_image_for_ocr(img, input_height, input_width).astype(np.float32) / 255. return scale_padd_image_for_ocr(img, input_height, input_width).astype(np.float32) / 255.