From 7888fa5968d12bf5d485705b90c805f922997d89 Mon Sep 17 00:00:00 2001
From: Robert Sachunsky
Date: Sun, 8 Feb 2026 04:42:44 +0100
Subject: [PATCH] training: remove `data_gen` in favor of tf.data pipelines

Instead of looping over file pairs indefinitely and yielding NumPy arrays,
re-use `keras.utils.image_dataset_from_directory` here as well, but with the
image and label datasets zipped together (thus, everything will already be
loaded/prefetched on the GPU).
---
 src/eynollah/training/train.py | 61 ++++++++++++++++++----------------
 src/eynollah/training/utils.py | 38 ---------------------
 2 files changed, 32 insertions(+), 67 deletions(-)

diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 73d5e0b..05a7346 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -13,6 +13,7 @@ from tensorflow.keras.models import load_model
 from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
 from tensorflow.keras.layers import StringLookup
 from tensorflow.keras.utils import image_dataset_from_directory
+from tensorflow.keras.backend import one_hot
 
 from sacred import Experiment
 from sacred.config import create_captured_function
@@ -36,7 +37,6 @@ from .models import (
     RESNET50_WEIGHTS_URL
 )
 from .utils import (
-    data_gen,
     generate_arrays_from_folder_reading_order,
     get_one_hot,
     preprocess_imgs,
@@ -435,43 +435,46 @@ def run(_config,
                         sparse_y_true=False,
                         sparse_y_pred=False)])
 
-        # generating train and evaluation data
-        gen_kwargs = dict(batch_size=n_batch,
-                          input_height=input_height,
-                          input_width=input_width,
-                          n_classes=n_classes,
-                          task=task)
-        train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, **gen_kwargs)
-        val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, **gen_kwargs)
-
-        ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
-        ##score_best=[]
-        ##score_best.append(0)
+        def get_dataset(dir_imgs, dir_labs, shuffle=None):
+            gen_kwargs = dict(labels=None,
+                              label_mode=None,
+                              batch_size=1, # batch after zip below
+                              image_size=(input_height, input_width),
+                              color_mode='rgb',
+                              shuffle=shuffle is not None,
+                              seed=shuffle,
+                              interpolation='nearest',
+                              crop_to_aspect_ratio=False,
+                              # Keras 3 only...
+                              #pad_to_aspect_ratio=False,
+                              #data_format='channel_last',
+                              #verbose=False,
+                              )
+            img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
+            lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
+            if task in ["segmentation", "binarization"]:
+                @tf.function
+                def to_categorical(seg):
+                    seg = tf.image.rgb_to_grayscale(seg)
+                    seg = tf.cast(seg, tf.int8)
+                    seg = tf.squeeze(seg, axis=-1)
+                    return one_hot(seg, n_classes)
+                lab_gen = lab_gen.map(to_categorical)
+            return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
+        train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
+        val_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
 
         callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                      SaveWeightsAfterSteps(0, dir_output, _config)]
         if save_interval:
             callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
-
-        steps_train = len(os.listdir(dir_flow_train_imgs)) // n_batch # - 1
-        steps_val = len(os.listdir(dir_flow_eval_imgs)) // n_batch
-        _log.info("training on %d batches in %d epochs", steps_train, n_epochs)
-        _log.info("validating on %d batches", steps_val)
 
         model.fit(
-            train_gen,
-            steps_per_epoch=steps_train,
-            validation_data=val_gen,
-            #validation_steps=1, # rs: only one batch??
-            validation_steps=steps_val,
+            train_gen.prefetch(tf.data.AUTOTUNE), # .repeat()??
+            validation_data=val_gen.prefetch(tf.data.AUTOTUNE),
             epochs=n_epochs,
             callbacks=callbacks,
             initial_epoch=index_start)
-        #os.system('rm -rf '+dir_train_flowing)
-        #os.system('rm -rf '+dir_eval_flowing)
-
-        #model.save(dir_output+'/'+'model'+'.h5')
-
     elif task=="cnn-rnn-ocr":
         dir_img, dir_lab = get_dirs_or_files(dir_train)
@@ -524,7 +527,7 @@ def run(_config,
             drop_remainder=True,
             #num_parallel_calls=tf.data.AUTOTUNE,
         )
-        train_ds = train_ds.repeat().shuffle().prefetch(20)
+        train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
 
         #initial_learning_rate = 1e-4
         #decay_steps = int (n_epochs * ( len_dataset / n_batch ))
diff --git a/src/eynollah/training/utils.py b/src/eynollah/training/utils.py
index a03d539..f2f4bdc 100644
--- a/src/eynollah/training/utils.py
+++ b/src/eynollah/training/utils.py
@@ -596,44 +596,6 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, n_bat
             ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
             batchcount = 0
 
-def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
-    c = 0
-    n = [f for f in os.listdir(img_folder) if not f.startswith('.')] # List of training images
-    random.shuffle(n)
-    img = np.zeros((batch_size, input_height, input_width, 3), dtype=float)
-    mask = np.zeros((batch_size, input_height, input_width, n_classes), dtype=float)
-    while True:
-        for i in range(c, c + batch_size): # initially from 0 to 16, c = 0.
-            try:
-                filename = os.path.splitext(n[i])[0]
-
-                train_img = cv2.imread(img_folder + '/' + n[i]) / 255.
-                train_img = cv2.resize(train_img, (input_width, input_height),
-                                       interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize
-
-                img[i - c, :] = train_img # add to array - img[0], img[1], and so on.
-                if task == "segmentation" or task=="binarization":
-                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
-                    train_mask = resize_image(train_mask, input_height, input_width)
-                    train_mask = get_one_hot(train_mask, input_height, input_width, n_classes)
-                elif task == "enhancement":
-                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png')/255.
-                    train_mask = resize_image(train_mask, input_height, input_width)
-
-                # train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3]
-
-                mask[i - c, :] = train_mask
-            except Exception as e:
-                print(str(e))
-                img[i - c, :] = 1.
-                mask[i - c, :] = 0.
-
-        c += batch_size
-        if c + batch_size >= len(os.listdir(img_folder)):
-            c = 0
-            random.shuffle(n)
-        yield img, mask
-
 # TODO: Use otsu_copy from utils
 def otsu_copy(img):
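
Note (illustration, not part of the patch): the train.py hunk above builds the training input as two `image_dataset_from_directory` datasets - one over the image directory, one over the label directory, created with the same seed so the pairs stay aligned - zipped into (image, one-hot label) pairs and re-batched, so the whole pipeline stays in tf.data and can be prefetched. Below is a minimal standalone sketch of that idea; the directory names, image size, class count, the tuple form of `Dataset.zip`, the `unbatch().batch()` re-batching, and the label decoding via the first channel are assumptions for illustration, not the repository's actual code.

    # sketch_tfdata_pairs.py -- illustrative only, assumes TF >= 2.9
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.utils import image_dataset_from_directory

    INPUT_HEIGHT, INPUT_WIDTH = 448, 448   # hypothetical patch size
    N_CLASSES = 3                          # hypothetical number of classes
    N_BATCH = 4

    def get_dataset(dir_imgs, dir_labs, shuffle_seed=None):
        # Read image and label files with identical settings (and the same seed),
        # so the i-th image and the i-th label still correspond after shuffling.
        common = dict(labels=None,           # no class subfolders, just files
                      label_mode=None,
                      batch_size=1,          # re-batch after zipping
                      image_size=(INPUT_HEIGHT, INPUT_WIDTH),
                      color_mode='rgb',
                      shuffle=shuffle_seed is not None,
                      seed=shuffle_seed,
                      interpolation='nearest')
        imgs = image_dataset_from_directory(dir_imgs, **common)
        labs = image_dataset_from_directory(dir_labs, **common)

        def to_categorical(lab):
            # assume the label PNGs store the class index in every channel:
            # take one channel, round, and one-hot encode it
            idx = tf.cast(tf.round(lab[..., 0]), tf.int32)
            return tf.one_hot(idx, N_CLASSES)

        labs = labs.map(to_categorical, num_parallel_calls=tf.data.AUTOTUNE)
        # pair images with labels, then regroup the single-sample batches into
        # training batches (the one-step equivalent is rebatch(n, drop_remainder=True))
        return (tf.data.Dataset.zip((imgs, labs))
                .unbatch()
                .batch(N_BATCH, drop_remainder=True)
                .prefetch(tf.data.AUTOTUNE))

    if __name__ == '__main__':
        # hypothetical directory layout: train_imgs/ and train_labels/ hold
        # files with matching basenames, as prepared by the training script
        train_ds = get_dataset('train_imgs', 'train_labels',
                               shuffle_seed=np.random.randint(1_000_000))
        val_ds = get_dataset('eval_imgs', 'eval_labels')
        for x, y in train_ds.take(1):
            print(x.shape, y.shape)  # (N_BATCH, H, W, 3), (N_BATCH, H, W, N_CLASSES)

As in the patch, only the training pipeline gets a shuffle seed; the validation pipeline is built the same way without one, and `model.fit` consumes both datasets directly instead of needing `steps_per_epoch`/`validation_steps`.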