mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training.train: use std Keras data loader for classification
(much more efficient, works with std F1 metric)
This commit is contained in:
parent
f03124f747
commit
5d0c26b629
1 changed files with 8 additions and 11 deletions
|
|
@ -23,8 +23,6 @@ from eynollah.training.models import (
|
||||||
from eynollah.training.utils import (
|
from eynollah.training.utils import (
|
||||||
data_gen,
|
data_gen,
|
||||||
generate_arrays_from_folder_reading_order,
|
generate_arrays_from_folder_reading_order,
|
||||||
generate_data_from_folder_evaluation,
|
|
||||||
generate_data_from_folder_training,
|
|
||||||
get_one_hot,
|
get_one_hot,
|
||||||
preprocess_imgs,
|
preprocess_imgs,
|
||||||
return_number_of_total_training_data
|
return_number_of_total_training_data
|
||||||
|
|
@ -37,6 +35,7 @@ from tensorflow.keras.optimizers import SGD, Adam
|
||||||
from tensorflow.keras.metrics import MeanIoU, F1Score
|
from tensorflow.keras.metrics import MeanIoU, F1Score
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
||||||
|
from tensorflow.keras.utils import image_dataset_from_directory
|
||||||
from sacred import Experiment
|
from sacred import Experiment
|
||||||
from sacred.config import create_captured_function
|
from sacred.config import create_captured_function
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
@ -430,13 +429,13 @@ def run(_config,
|
||||||
metrics=['accuracy', F1Score(average='macro', name='f1')])
|
metrics=['accuracy', F1Score(average='macro', name='f1')])
|
||||||
|
|
||||||
list_classes = list(classification_classes_name.values())
|
list_classes = list(classification_classes_name.values())
|
||||||
trainXY = generate_data_from_folder(
|
data_args = dict(label_mode="categorical",
|
||||||
dir_train, n_batch, input_height, input_width, n_classes, list_classes, shuffle=True)
|
class_names=list_classes,
|
||||||
testXY = generate_data_from_folder(
|
batch_size=n_batch,
|
||||||
dir_eval, n_batch, input_height, input_width, n_classes, list_classes)
|
image_size=(input_height, input_width),
|
||||||
epoch_size_train = return_number_of_total_training_data(dir_train)
|
interpolation="nearest")
|
||||||
epoch_size_eval = return_number_of_total_training_data(dir_eval)
|
trainXY = image_dataset_from_directory(dir_train, shuffle=True, **data_args)
|
||||||
|
testXY = image_dataset_from_directory(dir_eval, shuffle=False, **data_args)
|
||||||
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||||
SaveWeightsAfterSteps(0, dir_output, _config,
|
SaveWeightsAfterSteps(0, dir_output, _config,
|
||||||
monitor='val_f1',
|
monitor='val_f1',
|
||||||
|
|
@ -444,10 +443,8 @@ def run(_config,
|
||||||
mode='max')]
|
mode='max')]
|
||||||
|
|
||||||
history = model.fit(trainXY,
|
history = model.fit(trainXY,
|
||||||
steps_per_epoch=epoch_size_train // n_batch,
|
|
||||||
#class_weight=weights)
|
#class_weight=weights)
|
||||||
validation_data=testXY,
|
validation_data=testXY,
|
||||||
validation_steps=epoch_size_eval // n_batch,
|
|
||||||
verbose=1,
|
verbose=1,
|
||||||
epochs=n_epochs,
|
epochs=n_epochs,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue