mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-01 21:02:00 +01:00
training: make data pipeline in 7888fa5 more efficient
This commit is contained in:
parent
f8dd5a328c
commit
1cff937e72
1 changed files with 6 additions and 5 deletions
|
|
@ -633,7 +633,7 @@ def run(_config,
|
||||||
def get_dataset(dir_imgs, dir_labs, shuffle=None):
|
def get_dataset(dir_imgs, dir_labs, shuffle=None):
|
||||||
gen_kwargs = dict(labels=None,
|
gen_kwargs = dict(labels=None,
|
||||||
label_mode=None,
|
label_mode=None,
|
||||||
batch_size=1, # batch after zip below
|
batch_size=None, # batch after zip below
|
||||||
image_size=(input_height, input_width),
|
image_size=(input_height, input_width),
|
||||||
color_mode='rgb',
|
color_mode='rgb',
|
||||||
shuffle=shuffle is not None,
|
shuffle=shuffle is not None,
|
||||||
|
|
@ -647,11 +647,12 @@ def run(_config,
|
||||||
)
|
)
|
||||||
img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
|
img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
|
||||||
lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
|
lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
|
||||||
img_gen = img_gen.map(_to_cv2float)
|
img_gen = img_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
lab_gen = lab_gen.map(_to_cv2float)
|
lab_gen = lab_gen.map(_to_cv2float, num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
if task in ["segmentation", "binarization"]:
|
if task in ["segmentation", "binarization"]:
|
||||||
lab_gen = lab_gen.map(_to_categorical)
|
lab_gen = lab_gen.map(_to_categorical, num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
|
ds = tf.data.Dataset.zip(img_gen, lab_gen)
|
||||||
|
return ds.batch(n_batch, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
|
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
|
||||||
valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
||||||
train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
|
train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue