From 1cff937e72154a6440819f031a631ecacda16a39 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Fri, 27 Feb 2026 12:53:09 +0100 Subject: [PATCH] training: make data pipeline in 7888fa5 more efficient --- src/eynollah/training/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 30e30cb..92a2f49 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -633,7 +633,7 @@ def run(_config, def get_dataset(dir_imgs, dir_labs, shuffle=None): gen_kwargs = dict(labels=None, label_mode=None, - batch_size=1, # batch after zip below + batch_size=None, # batch after zip below image_size=(input_height, input_width), color_mode='rgb', shuffle=shuffle is not None, @@ -647,11 +647,12 @@ def run(_config, ) img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs) lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs) - img_gen = img_gen.map(_to_cv2float) - lab_gen = lab_gen.map(_to_cv2float) + img_gen = img_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE) + lab_gen = lab_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE) if task in ["segmentation", "binarization"]: - lab_gen = lab_gen.map(_to_categorical) - return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True) + lab_gen = lab_gen.map(_to_categorical, num_parallel_calls=tf.data.AUTOTUNE) + ds = tf.data.Dataset.zip(img_gen, lab_gen) + return ds.batch(n_batch, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE) train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6)) valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels) train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch