mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training: remove data_gen in favor of tf.data pipelines
instead of looping over file pairs indefinitely, yielding Numpy arrays: re-use `keras.utils.image_dataset_from_directory` here as well, but with img/label generators zipped together (thus, everything will already be loaded/prefetched on the GPU)
This commit is contained in:
parent
83c2408192
commit
7888fa5968
2 changed files with 32 additions and 67 deletions
|
|
@ -13,6 +13,7 @@ from tensorflow.keras.models import load_model
|
||||||
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
||||||
from tensorflow.keras.layers import StringLookup
|
from tensorflow.keras.layers import StringLookup
|
||||||
from tensorflow.keras.utils import image_dataset_from_directory
|
from tensorflow.keras.utils import image_dataset_from_directory
|
||||||
|
from tensorflow.keras.backend import one_hot
|
||||||
from sacred import Experiment
|
from sacred import Experiment
|
||||||
from sacred.config import create_captured_function
|
from sacred.config import create_captured_function
|
||||||
|
|
||||||
|
|
@ -36,7 +37,6 @@ from .models import (
|
||||||
RESNET50_WEIGHTS_URL
|
RESNET50_WEIGHTS_URL
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
data_gen,
|
|
||||||
generate_arrays_from_folder_reading_order,
|
generate_arrays_from_folder_reading_order,
|
||||||
get_one_hot,
|
get_one_hot,
|
||||||
preprocess_imgs,
|
preprocess_imgs,
|
||||||
|
|
@ -435,43 +435,46 @@ def run(_config,
|
||||||
sparse_y_true=False,
|
sparse_y_true=False,
|
||||||
sparse_y_pred=False)])
|
sparse_y_pred=False)])
|
||||||
|
|
||||||
# generating train and evaluation data
|
def get_dataset(dir_imgs, dir_labs, shuffle=None):
|
||||||
gen_kwargs = dict(batch_size=n_batch,
|
gen_kwargs = dict(labels=None,
|
||||||
input_height=input_height,
|
label_mode=None,
|
||||||
input_width=input_width,
|
batch_size=1, # batch after zip below
|
||||||
n_classes=n_classes,
|
image_size=(input_height, input_width),
|
||||||
task=task)
|
color_mode='rgb',
|
||||||
train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, **gen_kwargs)
|
shuffle=shuffle is not None,
|
||||||
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, **gen_kwargs)
|
seed=shuffle,
|
||||||
|
interpolation='nearest',
|
||||||
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
|
crop_to_aspect_ratio=False,
|
||||||
##score_best=[]
|
# Keras 3 only...
|
||||||
##score_best.append(0)
|
#pad_to_aspect_ratio=False,
|
||||||
|
#data_format='channel_last',
|
||||||
|
#verbose=False,
|
||||||
|
)
|
||||||
|
img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
|
||||||
|
lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
|
||||||
|
if task in ["segmentation", "binarization"]:
|
||||||
|
@tf.function
|
||||||
|
def to_categorical(seg):
|
||||||
|
seg = tf.image.rgb_to_grayscale(seg)
|
||||||
|
seg = tf.cast(seg, tf.int8)
|
||||||
|
seg = tf.squeeze(seg, axis=-1)
|
||||||
|
return one_hot(seg, n_classes)
|
||||||
|
lab_gen = lab_gen.map(to_categorical)
|
||||||
|
return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
|
||||||
|
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
|
||||||
|
val_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
||||||
|
|
||||||
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||||
SaveWeightsAfterSteps(0, dir_output, _config)]
|
SaveWeightsAfterSteps(0, dir_output, _config)]
|
||||||
if save_interval:
|
if save_interval:
|
||||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||||
|
|
||||||
steps_train = len(os.listdir(dir_flow_train_imgs)) // n_batch # - 1
|
|
||||||
steps_val = len(os.listdir(dir_flow_eval_imgs)) // n_batch
|
|
||||||
_log.info("training on %d batches in %d epochs", steps_train, n_epochs)
|
|
||||||
_log.info("validating on %d batches", steps_val)
|
|
||||||
model.fit(
|
model.fit(
|
||||||
train_gen,
|
train_gen.prefetch(tf.data.AUTOTUNE), # .repeat()??
|
||||||
steps_per_epoch=steps_train,
|
validation_data=val_gen.prefetch(tf.data.AUTOTUNE),
|
||||||
validation_data=val_gen,
|
|
||||||
#validation_steps=1, # rs: only one batch??
|
|
||||||
validation_steps=steps_val,
|
|
||||||
epochs=n_epochs,
|
epochs=n_epochs,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
initial_epoch=index_start)
|
initial_epoch=index_start)
|
||||||
|
|
||||||
#os.system('rm -rf '+dir_train_flowing)
|
|
||||||
#os.system('rm -rf '+dir_eval_flowing)
|
|
||||||
|
|
||||||
#model.save(dir_output+'/'+'model'+'.h5')
|
|
||||||
|
|
||||||
elif task=="cnn-rnn-ocr":
|
elif task=="cnn-rnn-ocr":
|
||||||
|
|
||||||
dir_img, dir_lab = get_dirs_or_files(dir_train)
|
dir_img, dir_lab = get_dirs_or_files(dir_train)
|
||||||
|
|
@ -524,7 +527,7 @@ def run(_config,
|
||||||
drop_remainder=True,
|
drop_remainder=True,
|
||||||
#num_parallel_calls=tf.data.AUTOTUNE,
|
#num_parallel_calls=tf.data.AUTOTUNE,
|
||||||
)
|
)
|
||||||
train_ds = train_ds.repeat().shuffle().prefetch(20)
|
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
|
||||||
|
|
||||||
#initial_learning_rate = 1e-4
|
#initial_learning_rate = 1e-4
|
||||||
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
#decay_steps = int (n_epochs * ( len_dataset / n_batch ))
|
||||||
|
|
|
||||||
|
|
@ -596,44 +596,6 @@ def generate_arrays_from_folder_reading_order(classes_file_dir, modal_dir, n_bat
|
||||||
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
|
ret_y= np.zeros((n_batch, n_classes)).astype(np.int16)
|
||||||
batchcount = 0
|
batchcount = 0
|
||||||
|
|
||||||
def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
|
|
||||||
c = 0
|
|
||||||
n = [f for f in os.listdir(img_folder) if not f.startswith('.')] # os.listdir(img_folder) #List of training images
|
|
||||||
random.shuffle(n)
|
|
||||||
img = np.zeros((batch_size, input_height, input_width, 3), dtype=float)
|
|
||||||
mask = np.zeros((batch_size, input_height, input_width, n_classes), dtype=float)
|
|
||||||
while True:
|
|
||||||
for i in range(c, c + batch_size): # initially from 0 to 16, c = 0.
|
|
||||||
try:
|
|
||||||
filename = os.path.splitext(n[i])[0]
|
|
||||||
|
|
||||||
train_img = cv2.imread(img_folder + '/' + n[i]) / 255.
|
|
||||||
train_img = cv2.resize(train_img, (input_width, input_height),
|
|
||||||
interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize
|
|
||||||
|
|
||||||
img[i - c, :] = train_img # add to array - img[0], img[1], and so on.
|
|
||||||
if task == "segmentation" or task=="binarization":
|
|
||||||
train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
|
|
||||||
train_mask = resize_image(train_mask, input_height, input_width)
|
|
||||||
train_mask = get_one_hot(train_mask, input_height, input_width, n_classes)
|
|
||||||
elif task == "enhancement":
|
|
||||||
train_mask = cv2.imread(mask_folder + '/' + filename + '.png')/255.
|
|
||||||
train_mask = resize_image(train_mask, input_height, input_width)
|
|
||||||
|
|
||||||
# train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3]
|
|
||||||
|
|
||||||
mask[i - c, :] = train_mask
|
|
||||||
except Exception as e:
|
|
||||||
print(str(e))
|
|
||||||
img[i - c, :] = 1.
|
|
||||||
mask[i - c, :] = 0.
|
|
||||||
|
|
||||||
c += batch_size
|
|
||||||
if c + batch_size >= len(os.listdir(img_folder)):
|
|
||||||
c = 0
|
|
||||||
random.shuffle(n)
|
|
||||||
yield img, mask
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Use otsu_copy from utils
|
# TODO: Use otsu_copy from utils
|
||||||
def otsu_copy(img):
|
def otsu_copy(img):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue