training: remove data_gen in favor of tf.data pipelines

Instead of looping over file pairs indefinitely and yielding
NumPy arrays, re-use `keras.utils.image_dataset_from_directory`
here as well, but with image/label datasets zipped together.

(Thus, everything will already be loaded/prefetched on the GPU.)
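
For reference, the pattern this switches to looks roughly like the following
minimal sketch (illustrative only, not the committed code; the directory
paths, image size, batch size and class count are placeholder assumptions):

    import tensorflow as tf
    from tensorflow.keras.utils import image_dataset_from_directory

    # load images and label masks as two unlabelled datasets with identical ordering
    kwargs = dict(labels=None, label_mode=None, batch_size=1, shuffle=False,
                  image_size=(448, 448), interpolation='nearest')
    imgs = image_dataset_from_directory("train/images", **kwargs)  # placeholder path
    labs = image_dataset_from_directory("train/labels", **kwargs)  # placeholder path

    # turn RGB-encoded masks into one-hot class maps (3 classes assumed here)
    def to_categorical(seg):
        seg = tf.squeeze(tf.cast(tf.image.rgb_to_grayscale(seg), tf.int32), axis=-1)
        return tf.one_hot(seg, 3)
    labs = labs.map(to_categorical, num_parallel_calls=tf.data.AUTOTUNE)

    # zip, re-batch and prefetch so batches are staged ahead of the training step
    ds = (tf.data.Dataset.zip((imgs, labs))
          .unbatch()
          .batch(8, drop_remainder=True)
          .prefetch(tf.data.AUTOTUNE))
    # ds can now be passed directly to model.fit(ds, ...)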
Robert Sachunsky 2026-02-08 04:42:44 +01:00
parent 83c2408192
commit 7888fa5968
2 changed files with 32 additions and 67 deletions

View file

@@ -13,6 +13,7 @@ from tensorflow.keras.models import load_model
 from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
 from tensorflow.keras.layers import StringLookup
 from tensorflow.keras.utils import image_dataset_from_directory
+from tensorflow.keras.backend import one_hot
 from sacred import Experiment
 from sacred.config import create_captured_function
@@ -36,7 +37,6 @@ from .models import (
     RESNET50_WEIGHTS_URL
 )
 from .utils import (
-    data_gen,
     generate_arrays_from_folder_reading_order,
     get_one_hot,
     preprocess_imgs,
@@ -435,43 +435,46 @@ def run(_config,
                              sparse_y_true=False,
                              sparse_y_pred=False)])
-        # generating train and evaluation data
-        gen_kwargs = dict(batch_size=n_batch,
-                          input_height=input_height,
-                          input_width=input_width,
-                          n_classes=n_classes,
-                          task=task)
-        train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, **gen_kwargs)
-        val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, **gen_kwargs)
-
-        ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
-        ##score_best=[]
-        ##score_best.append(0)
+        def get_dataset(dir_imgs, dir_labs, shuffle=None):
+            gen_kwargs = dict(labels=None,
+                              label_mode=None,
+                              batch_size=1, # batch after zip below
+                              image_size=(input_height, input_width),
+                              color_mode='rgb',
+                              shuffle=shuffle is not None,
+                              seed=shuffle,
+                              interpolation='nearest',
+                              crop_to_aspect_ratio=False,
+                              # Keras 3 only...
+                              #pad_to_aspect_ratio=False,
+                              #data_format='channel_last',
+                              #verbose=False,
+                              )
+            img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
+            lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
+            if task in ["segmentation", "binarization"]:
+                @tf.function
+                def to_categorical(seg):
+                    seg = tf.image.rgb_to_grayscale(seg)
+                    seg = tf.cast(seg, tf.int8)
+                    seg = tf.squeeze(seg, axis=-1)
+                    return one_hot(seg, n_classes)
+                lab_gen = lab_gen.map(to_categorical)
+            return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
+        train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
+        val_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
         callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                      SaveWeightsAfterSteps(0, dir_output, _config)]
         if save_interval:
             callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
-        steps_train = len(os.listdir(dir_flow_train_imgs)) // n_batch # - 1
-        steps_val = len(os.listdir(dir_flow_eval_imgs)) // n_batch
-        _log.info("training on %d batches in %d epochs", steps_train, n_epochs)
-        _log.info("validating on %d batches", steps_val)
         model.fit(
-            train_gen,
-            steps_per_epoch=steps_train,
-            validation_data=val_gen,
-            #validation_steps=1, # rs: only one batch??
-            validation_steps=steps_val,
+            train_gen.prefetch(tf.data.AUTOTUNE), # .repeat()??
+            validation_data=val_gen.prefetch(tf.data.AUTOTUNE),
             epochs=n_epochs,
             callbacks=callbacks,
             initial_epoch=index_start)
-        #os.system('rm -rf '+dir_train_flowing)
-        #os.system('rm -rf '+dir_eval_flowing)
-        #model.save(dir_output+'/'+'model'+'.h5')
 
     elif task=="cnn-rnn-ocr":
         dir_img, dir_lab = get_dirs_or_files(dir_train)
@@ -524,7 +527,7 @@ def run(_config,
             drop_remainder=True,
             #num_parallel_calls=tf.data.AUTOTUNE,
         )
-        train_ds = train_ds.repeat().shuffle().prefetch(20)
+        train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
         #initial_learning_rate = 1e-4
         #decay_steps = int (n_epochs * ( len_dataset / n_batch ))

View file

@@ -596,44 +596,6 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, n_bat
                 ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
                 batchcount = 0
 
-def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
-    c = 0
-    n = [f for f in os.listdir(img_folder) if not f.startswith('.')]  # os.listdir(img_folder) #List of training images
-    random.shuffle(n)
-    img = np.zeros((batch_size, input_height, input_width, 3), dtype=float)
-    mask = np.zeros((batch_size, input_height, input_width, n_classes), dtype=float)
-
-    while True:
-        for i in range(c, c + batch_size): # initially from 0 to 16, c = 0.
-            try:
-                filename = os.path.splitext(n[i])[0]
-                train_img = cv2.imread(img_folder + '/' + n[i]) / 255.
-                train_img = cv2.resize(train_img, (input_width, input_height),
-                                       interpolation=cv2.INTER_NEAREST)  # Read an image from folder and resize
-                img[i - c, :] = train_img  # add to array - img[0], img[1], and so on.
-                if task == "segmentation" or task=="binarization":
-                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
-                    train_mask = resize_image(train_mask, input_height, input_width)
-                    train_mask = get_one_hot(train_mask, input_height, input_width, n_classes)
-                elif task == "enhancement":
-                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png')/255.
-                    train_mask = resize_image(train_mask, input_height, input_width)
-                    # train_mask = train_mask.reshape(224, 224, 1)  # Add extra dimension for parity with train_img size [512 * 512 * 3]
-                mask[i - c, :] = train_mask
-            except Exception as e:
-                print(str(e))
-                img[i - c, :] = 1.
-                mask[i - c, :] = 0.
-        c += batch_size
-        if c + batch_size >= len(os.listdir(img_folder)):
-            c = 0
-            random.shuffle(n)
-        yield img, mask
-
 # TODO: Use otsu_copy from utils
 def otsu_copy(img):