training.train: add validation data for OCR

This commit is contained in:
Robert Sachunsky 2026-02-25 00:10:43 +01:00
parent b399db3c00
commit 36e370aa45

View file

@@ -10,7 +10,7 @@ import tensorflow as tf
 from tensorflow.keras.optimizers import SGD, Adam
 from tensorflow.keras.metrics import MeanIoU, F1Score
 from tensorflow.keras.models import load_model
-from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
+from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
 from tensorflow.keras.layers import StringLookup
 from tensorflow.keras.utils import image_dataset_from_directory
 from tensorflow.keras.backend import one_hot
@@ -606,12 +606,12 @@ def run(_config,
 elif task=="cnn-rnn-ocr":
-dir_img, dir_lab = get_dirs_or_files(dir_train)
-dir_img_val, dir_lab_val = get_dirs_or_files(dir_eval)
-imgs_list = list(os.listdir(dir_img))
-labs_list = list(os.listdir(dir_lab))
-imgs_list_val = list(os.listdir(dir_img_val))
-labs_list_val = list(os.listdir(dir_lab_val))
+dir_img_train, dir_lab_train = get_dirs_or_files(dir_train)
+dir_img_valdn, dir_lab_valdn = get_dirs_or_files(dir_eval)
+imgs_list_train = list(os.listdir(dir_img_train))
+labs_list_train = list(os.listdir(dir_lab_train))
+imgs_list_valdn = list(os.listdir(dir_img_valdn))
+labs_list_valdn = list(os.listdir(dir_lab_valdn))
 with open(characters_txt_file, 'r') as char_txt_f:
 characters = json.load(char_txt_f)
@@ -631,7 +631,7 @@ def run(_config,
 #print(model.summary())
 # todo: use Dataset.map() on Dataset.list_files()
-# todo: test_ds
+def get_dataset(imgs_list, labs_list, dir_img, dir_lab):
 def gen():
 return preprocess_imgs(_config,
 imgs_list,
@@ -644,7 +644,7 @@ def run(_config,
 char_to_num=char_to_num,
 padding_token=padding_token
 )
-train_ds = (tf.data.Dataset.from_generator(gen, (tf.float32, tf.int64))
+return (tf.data.Dataset.from_generator(gen, (tf.float32, tf.int64))
 .padded_batch(n_batch,
 padded_shapes=([input_height, input_width, 3], [None]),
 padding_values=(None, tf.constant(padding_token, dtype=tf.int64)),
@@ -654,6 +654,14 @@ def run(_config,
 .map(lambda x, y: {"image": x, "label": y})
 .prefetch(tf.data.AUTOTUNE)
 )
+train_ds = get_dataset(imgs_list_train,
+                       labs_list_train,
+                       dir_img_train,
+                       dir_lab_train)
+valdn_ds = get_dataset(imgs_list_valdn,
+                       labs_list_valdn,
+                       dir_img_valdn,
+                       dir_lab_valdn)
 #initial_learning_rate = 1e-4
 #decay_steps = int (n_epochs * ( len_dataset / n_batch ))
@@ -669,7 +677,7 @@ def run(_config,
 callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
 model.fit(
 train_ds,
-#validation_data=test_ds,
+validation_data=valdn_ds,
 verbose=1,
 epochs=n_epochs,
 callbacks=callbacks,