mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training.train: use std Keras data loader for classification
(much more efficient, works with std F1 metric)
This commit is contained in:
parent
f03124f747
commit
5d0c26b629
1 changed files with 8 additions and 11 deletions
|
|
@ -23,8 +23,6 @@ from eynollah.training.models import (
|
|||
from eynollah.training.utils import (
|
||||
data_gen,
|
||||
generate_arrays_from_folder_reading_order,
|
||||
generate_data_from_folder_evaluation,
|
||||
generate_data_from_folder_training,
|
||||
get_one_hot,
|
||||
preprocess_imgs,
|
||||
return_number_of_total_training_data
|
||||
|
|
@ -37,6 +35,7 @@ from tensorflow.keras.optimizers import SGD, Adam
|
|||
from tensorflow.keras.metrics import MeanIoU, F1Score
|
||||
from tensorflow.keras.models import load_model
|
||||
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
||||
from tensorflow.keras.utils import image_dataset_from_directory
|
||||
from sacred import Experiment
|
||||
from sacred.config import create_captured_function
|
||||
from tqdm import tqdm
|
||||
|
|
@ -430,13 +429,13 @@ def run(_config,
|
|||
metrics=['accuracy', F1Score(average='macro', name='f1')])
|
||||
|
||||
list_classes = list(classification_classes_name.values())
|
||||
trainXY = generate_data_from_folder(
|
||||
dir_train, n_batch, input_height, input_width, n_classes, list_classes, shuffle=True)
|
||||
testXY = generate_data_from_folder(
|
||||
dir_eval, n_batch, input_height, input_width, n_classes, list_classes)
|
||||
epoch_size_train = return_number_of_total_training_data(dir_train)
|
||||
epoch_size_eval = return_number_of_total_training_data(dir_eval)
|
||||
|
||||
data_args = dict(label_mode="categorical",
|
||||
class_names=list_classes,
|
||||
batch_size=n_batch,
|
||||
image_size=(input_height, input_width),
|
||||
interpolation="nearest")
|
||||
trainXY = image_dataset_from_directory(dir_train, shuffle=True, **data_args)
|
||||
testXY = image_dataset_from_directory(dir_eval, shuffle=False, **data_args)
|
||||
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||
SaveWeightsAfterSteps(0, dir_output, _config,
|
||||
monitor='val_f1',
|
||||
|
|
@ -444,10 +443,8 @@ def run(_config,
|
|||
mode='max')]
|
||||
|
||||
history = model.fit(trainXY,
|
||||
steps_per_epoch=epoch_size_train // n_batch,
|
||||
#class_weight=weights)
|
||||
validation_data=testXY,
|
||||
validation_steps=epoch_size_eval // n_batch,
|
||||
verbose=1,
|
||||
epochs=n_epochs,
|
||||
callbacks=callbacks,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue