mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
training: make data pipeline in 7888fa5 more efficient
This commit is contained in:
parent
f8dd5a328c
commit
1cff937e72
1 changed file with 6 additions and 5 deletions
|
|
@@ -633,7 +633,7 @@ def run(_config,
|
|||
def get_dataset(dir_imgs, dir_labs, shuffle=None):
|
||||
gen_kwargs = dict(labels=None,
|
||||
label_mode=None,
|
||||
- batch_size=1, # batch after zip below
|
||||
+ batch_size=None, # batch after zip below
|
||||
image_size=(input_height, input_width),
|
||||
color_mode='rgb',
|
||||
shuffle=shuffle is not None,
|
||||
|
|
@@ -647,11 +647,12 @@ def run(_config,
|
|||
)
|
||||
img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
|
||||
lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
|
||||
- img_gen = img_gen.map(_to_cv2float)
|
||||
- lab_gen = lab_gen.map(_to_cv2float)
|
||||
+ img_gen = img_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
+ lab_gen = lab_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
if task in ["segmentation", "binarization"]:
|
||||
- lab_gen = lab_gen.map(_to_categorical)
|
||||
- return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
|
||||
+ lab_gen = lab_gen.map(_to_categorical, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
+ ds = tf.data.Dataset.zip(img_gen, lab_gen)
|
||||
+ return ds.batch(n_batch, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
|
||||
valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
||||
train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue