mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training: remove data_gen in favor of tf.data pipelines
Instead of looping over file pairs indefinitely and yielding NumPy arrays, re-use `keras.utils.image_dataset_from_directory` here as well, with the image and label generators zipped together. This way, everything is already loaded and prefetched on the GPU.
This commit is contained in:
parent
83c2408192
commit
7888fa5968
2 changed files with 32 additions and 67 deletions
|
|
@ -13,6 +13,7 @@ from tensorflow.keras.models import load_model
|
|||
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
||||
from tensorflow.keras.layers import StringLookup
|
||||
from tensorflow.keras.utils import image_dataset_from_directory
|
||||
from tensorflow.keras.backend import one_hot
|
||||
from sacred import Experiment
|
||||
from sacred.config import create_captured_function
|
||||
|
||||
|
|
@ -36,7 +37,6 @@ from .models import (
|
|||
RESNET50_WEIGHTS_URL
|
||||
)
|
||||
from .utils import (
|
||||
data_gen,
|
||||
generate_arrays_from_folder_reading_order,
|
||||
get_one_hot,
|
||||
preprocess_imgs,
|
||||
|
|
@ -435,43 +435,46 @@ def run(_config,
|
|||
sparse_y_true=False,
|
||||
sparse_y_pred=False)])
|
||||
|
||||
# generating train and evaluation data
|
||||
gen_kwargs = dict(batch_size=n_batch,
|
||||
input_height=input_height,
|
||||
input_width=input_width,
|
||||
n_classes=n_classes,
|
||||
task=task)
|
||||
train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, **gen_kwargs)
|
||||
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, **gen_kwargs)
|
||||
|
||||
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
|
||||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
def get_dataset(dir_imgs, dir_labs, shuffle=None):
    """Build a tf.data pipeline of (image, label) batches from two mirrored
    directory trees.

    Both directories are read with ``image_dataset_from_directory`` (one file
    per element, re-batched to ``n_batch`` after zipping so image/label pairs
    stay aligned).  For segmentation/binarization tasks the label images are
    converted to one-hot class maps.

    :param dir_imgs: directory containing the input images
    :param dir_labs: directory containing the label images
                     (assumes same file order as dir_imgs — pairing relies on it)
    :param shuffle: None for deterministic order (validation), or an int seed;
                    the identical seed on both generators keeps pairs aligned
    :return: tf.data.Dataset yielding (img_batch, label_batch) tuples
    """
    gen_kwargs = dict(labels=None,
                      label_mode=None,
                      batch_size=1,  # batch after zip below
                      image_size=(input_height, input_width),
                      color_mode='rgb',
                      shuffle=shuffle is not None,
                      seed=shuffle,
                      interpolation='nearest',
                      crop_to_aspect_ratio=False,
                      # Keras 3 only...
                      #pad_to_aspect_ratio=False,
                      #data_format='channel_last',
                      #verbose=False,
                      )
    img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
    lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
    if task in ["segmentation", "binarization"]:
        @tf.function
        def to_categorical(seg):
            # collapse the RGB label image to one channel of class indices
            # (assumes labels are stored with equal R=G=B values — TODO confirm,
            # rgb_to_grayscale applies luma weighting otherwise)
            seg = tf.image.rgb_to_grayscale(seg)
            # int32, not int8: grayscale values above 127 would wrap to
            # negative indices under int8 and break one_hot silently
            seg = tf.cast(seg, tf.int32)
            seg = tf.squeeze(seg, axis=-1)
            return one_hot(seg, n_classes)
        lab_gen = lab_gen.map(to_categorical)
    # zip takes the datasets as a single tuple argument (portable across TF
    # versions; positional datasets only work on very recent TF).  Since the
    # element batches are size 1, unbatch+batch regroups them into n_batch —
    # equivalent to Dataset.rebatch, which is only public from TF 2.14 on.
    return (tf.data.Dataset.zip((img_gen, lab_gen))
            .unbatch()
            .batch(n_batch, drop_remainder=True))
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
val_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
||||
|
||||
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||
SaveWeightsAfterSteps(0, dir_output, _config)]
|
||||
if save_interval:
|
||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||
|
||||
steps_train = len(os.listdir(dir_flow_train_imgs)) // n_batch # - 1
|
||||
steps_val = len(os.listdir(dir_flow_eval_imgs)) // n_batch
|
||||
_log.info("training on %d batches in %d epochs", steps_train, n_epochs)
|
||||
_log.info("validating on %d batches", steps_val)
|
||||
model.fit(
|
||||
train_gen,
|
||||
steps_per_epoch=steps_train,
|
||||
validation_data=val_gen,
|
||||
#validation_steps=1, # rs: only one batch??
|
||||
validation_steps=steps_val,
|
||||
train_gen.prefetch(tf.data.AUTOTUNE), # .repeat()??
|
||||
validation_data=val_gen.prefetch(tf.data.AUTOTUNE),
|
||||
epochs=n_epochs,
|
||||
callbacks=callbacks,
|
||||
initial_epoch=index_start)
|
||||
|
||||
#os.system('rm -rf '+dir_train_flowing)
|
||||
#os.system('rm -rf '+dir_eval_flowing)
|
||||
|
||||
#model.save(dir_output+'/'+'model'+'.h5')
|
||||
|
||||
elif task=="cnn-rnn-ocr":
|
||||
|
||||
dir_img, dir_lab = get_dirs_or_files(dir_train)
|
||||
|
|
@ -524,7 +527,7 @@ def run(_config,
|
|||
drop_remainder=True,
|
||||
#num_parallel_calls=tf.data.AUTOTUNE,
|
||||
)
|
||||
train_ds = train_ds.repeat().shuffle().prefetch(20)
|
||||
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
#initial_learning_rate = 1e-4
|
||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||
|
|
|
|||
|
|
@ -596,44 +596,6 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, n_bat
|
|||
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
|
||||
batchcount = 0
|
||||
|
||||
def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
    """Endless batch generator for Keras' ``fit()``.

    Loads image/mask file pairs from two mirrored folders, resizes them to
    (input_height, input_width) and yields them as float batches.  Masks are
    one-hot encoded for segmentation/binarization and kept as 0..1-scaled RGB
    for enhancement.  The file list is reshuffled after each full pass.

    :param img_folder: directory with input images
    :param mask_folder: directory with ground-truth images (same basenames, .png)
    :param batch_size: number of samples per yielded batch
    :param input_height: target image height
    :param input_width: target image width
    :param n_classes: one-hot depth for segmentation/binarization masks
                      (NOTE(review): enhancement writes 3-channel masks, so it
                      presumably requires n_classes == 3 — confirm with callers)
    :param task: 'segmentation', 'binarization' or 'enhancement'
    :yields: (img, mask) arrays of shape (batch_size, H, W, 3) and
             (batch_size, H, W, n_classes); the same two arrays are reused
             and overwritten between yields, so consumers must not keep them
    """
    c = 0
    # skip hidden files (e.g. .DS_Store) when listing training images
    n = [f for f in os.listdir(img_folder) if not f.startswith('.')]
    random.shuffle(n)
    img = np.zeros((batch_size, input_height, input_width, 3), dtype=float)
    mask = np.zeros((batch_size, input_height, input_width, n_classes), dtype=float)
    while True:
        for i in range(c, c + batch_size):
            try:
                filename = os.path.splitext(n[i])[0]

                train_img = cv2.imread(img_folder + '/' + n[i]) / 255.
                train_img = cv2.resize(train_img, (input_width, input_height),
                                       interpolation=cv2.INTER_NEAREST)

                img[i - c, :] = train_img
                if task == "segmentation" or task == "binarization":
                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
                    train_mask = resize_image(train_mask, input_height, input_width)
                    train_mask = get_one_hot(train_mask, input_height, input_width, n_classes)
                elif task == "enhancement":
                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png') / 255.
                    train_mask = resize_image(train_mask, input_height, input_width)

                mask[i - c, :] = train_mask
            except Exception as e:
                # best-effort: a missing or corrupt pair must not abort the
                # epoch, so substitute a white image / empty mask and go on
                print(str(e))
                img[i - c, :] = 1.
                mask[i - c, :] = 0.

        c += batch_size
        # wrap around based on the *filtered* list n, not a fresh
        # os.listdir() call: the latter also counts hidden files and could
        # let the cursor run past the end of n
        if c + batch_size >= len(n):
            c = 0
            random.shuffle(n)
        yield img, mask
|
||||
|
||||
|
||||
# TODO: Use otsu_copy from utils
|
||||
def otsu_copy(img):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue