training: make data pipeline in 7888fa5 more efficient

Robert Sachunsky 2026-02-27 12:53:09 +01:00
parent f8dd5a328c
commit 1cff937e72


@@ -633,7 +633,7 @@ def run(_config,
     def get_dataset(dir_imgs, dir_labs, shuffle=None):
         gen_kwargs = dict(labels=None,
                           label_mode=None,
-                          batch_size=1, # batch after zip below
+                          batch_size=None, # batch after zip below
                           image_size=(input_height, input_width),
                           color_mode='rgb',
                           shuffle=shuffle is not None,
@@ -647,11 +647,12 @@ def run(_config,
         )
         img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
         lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
-        img_gen = img_gen.map(_to_cv2float)
-        lab_gen = lab_gen.map(_to_cv2float)
+        img_gen = img_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE)
+        lab_gen = lab_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE)
         if task in ["segmentation", "binarization"]:
-            lab_gen = lab_gen.map(_to_categorical)
-        return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
+            lab_gen = lab_gen.map(_to_categorical, num_parallel_calls=tf.data.AUTOTUNE)
+        ds = tf.data.Dataset.zip(img_gen, lab_gen)
+        return ds.batch(n_batch, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
     train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
     valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
     train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
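
For reference, below is a minimal, self-contained sketch of the tf.data pattern the new code follows: load unbatched elements (batch_size=None), run the per-element maps with num_parallel_calls=tf.data.AUTOTUNE, zip the image and label streams, and batch once at the end, instead of loading size-1 batches and regrouping them with rebatch. The image size and batch size are placeholder values, _to_cv2float is a stand-in for the repo's helper of the same name, and Dataset.zip is called with a tuple here, which older TF releases also accept.

import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory

def _to_cv2float(img):
    # Stand-in for the repo's helper: scale uint8 [0, 255] to float32 [0, 1].
    return tf.cast(img, tf.float32) / 255.0

def get_dataset(dir_imgs, dir_labs, input_height=224, input_width=224, n_batch=8):
    gen_kwargs = dict(labels=None,
                      label_mode=None,
                      batch_size=None,  # yield single images; batch once after zip
                      image_size=(input_height, input_width),
                      color_mode='rgb',
                      shuffle=False)
    img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
    lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
    # Apply the per-element transform to several elements concurrently.
    img_gen = img_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE)
    lab_gen = lab_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE)
    # Pair each image with its label, then assemble batches in parallel too.
    ds = tf.data.Dataset.zip((img_gen, lab_gen))
    return ds.batch(n_batch, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)

Batching once at the end avoids materializing size-1 batches only to regroup them, and AUTOTUNE lets tf.data choose the parallelism of the map and batch stages at runtime.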