diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 6d104dc..df3eac6 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -10,7 +10,7 @@ import tensorflow as tf
 from tensorflow.keras.optimizers import SGD, Adam
 from tensorflow.keras.metrics import MeanIoU, F1Score
 from tensorflow.keras.models import load_model
-from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
+from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
 from tensorflow.keras.layers import StringLookup
 from tensorflow.keras.utils import image_dataset_from_directory
 from tensorflow.keras.backend import one_hot
@@ -606,12 +606,12 @@ def run(_config,
 
 
     elif task=="cnn-rnn-ocr":
-        dir_img, dir_lab = get_dirs_or_files(dir_train)
-        dir_img_val, dir_lab_val = get_dirs_or_files(dir_eval)
-        imgs_list = list(os.listdir(dir_img))
-        labs_list = list(os.listdir(dir_lab))
-        imgs_list_val = list(os.listdir(dir_img_val))
-        labs_list_val = list(os.listdir(dir_lab_val))
+        dir_img_train, dir_lab_train = get_dirs_or_files(dir_train)
+        dir_img_valdn, dir_lab_valdn = get_dirs_or_files(dir_eval)
+        imgs_list_train = list(os.listdir(dir_img_train))
+        labs_list_train = list(os.listdir(dir_lab_train))
+        imgs_list_valdn = list(os.listdir(dir_img_valdn))
+        labs_list_valdn = list(os.listdir(dir_lab_valdn))
 
         with open(characters_txt_file, 'r') as char_txt_f:
             characters = json.load(char_txt_f)
@@ -631,20 +631,20 @@ def run(_config,
         #print(model.summary())
 
         # todo: use Dataset.map() on Dataset.list_files()
-        # todo: test_ds
-        def gen():
-            return preprocess_imgs(_config,
-                                   imgs_list,
-                                   labs_list,
-                                   dir_img,
-                                   dir_lab,
-                                   None, # no file I/O, but in-memory
-                                   None, # no file I/O, but in-memory
-                                   # extra+overrides
-                                   char_to_num=char_to_num,
-                                   padding_token=padding_token
-                                   )
-        train_ds = (tf.data.Dataset.from_generator(gen, (tf.float32, tf.int64))
+        def get_dataset(imgs_list, labs_list, dir_img, dir_lab):
+            def gen():
+                return preprocess_imgs(_config,
+                                       imgs_list,
+                                       labs_list,
+                                       dir_img,
+                                       dir_lab,
+                                       None, # no file I/O, but in-memory
+                                       None, # no file I/O, but in-memory
+                                       # extra+overrides
+                                       char_to_num=char_to_num,
+                                       padding_token=padding_token
+                                       )
+            return (tf.data.Dataset.from_generator(gen, (tf.float32, tf.int64))
                     .padded_batch(n_batch,
                                   padded_shapes=([input_height, input_width, 3], [None]),
                                   padding_values=(None, tf.constant(padding_token, dtype=tf.int64)),
@@ -653,7 +653,15 @@ def run(_config,
                                   )
                     .map(lambda x, y: {"image": x, "label": y})
                     .prefetch(tf.data.AUTOTUNE)
-                    )
+                )
+        train_ds = get_dataset(imgs_list_train,
+                               labs_list_train,
+                               dir_img_train,
+                               dir_lab_train)
+        valdn_ds = get_dataset(imgs_list_valdn,
+                               labs_list_valdn,
+                               dir_img_valdn,
+                               dir_lab_valdn)
 
         #initial_learning_rate = 1e-4
         #decay_steps = int (n_epochs * ( len_dataset / n_batch ))
@@ -669,7 +677,7 @@ def run(_config,
             callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
         model.fit(
             train_ds,
-            #validation_data=test_ds,
+            validation_data=valdn_ds,
             verbose=1,
             epochs=n_epochs,
             callbacks=callbacks,