training.train: use std Keras data loader for classification

(much more efficient, works with std F1 metric)
This commit is contained in:
Robert Sachunsky 2026-02-05 12:02:58 +01:00
parent f03124f747
commit 5d0c26b629

View file

@@ -23,8 +23,6 @@ from eynollah.training.models import (
 from eynollah.training.utils import (
     data_gen,
     generate_arrays_from_folder_reading_order,
-    generate_data_from_folder_evaluation,
-    generate_data_from_folder_training,
     get_one_hot,
     preprocess_imgs,
     return_number_of_total_training_data
@@ -37,6 +35,7 @@ from tensorflow.keras.optimizers import SGD, Adam
 from tensorflow.keras.metrics import MeanIoU, F1Score
 from tensorflow.keras.models import load_model
 from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
+from tensorflow.keras.utils import image_dataset_from_directory
 from sacred import Experiment
 from sacred.config import create_captured_function
 from tqdm import tqdm
@@ -430,13 +429,13 @@ def run(_config,
                   metrics=['accuracy', F1Score(average='macro', name='f1')])
     list_classes = list(classification_classes_name.values())
-    trainXY = generate_data_from_folder_training(
-        dir_train, n_batch, input_height, input_width, n_classes, list_classes, shuffle=True)
-    testXY = generate_data_from_folder_evaluation(
-        dir_eval, n_batch, input_height, input_width, n_classes, list_classes)
-    epoch_size_train = return_number_of_total_training_data(dir_train)
-    epoch_size_eval = return_number_of_total_training_data(dir_eval)
+    data_args = dict(label_mode="categorical",
+                     class_names=list_classes,
+                     batch_size=n_batch,
+                     image_size=(input_height, input_width),
+                     interpolation="nearest")
+    trainXY = image_dataset_from_directory(dir_train, shuffle=True, **data_args)
+    testXY = image_dataset_from_directory(dir_eval, shuffle=False, **data_args)
     callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                  SaveWeightsAfterSteps(0, dir_output, _config,
                                        monitor='val_f1',
@@ -444,10 +443,8 @@ def run(_config,
                                        mode='max')]
     history = model.fit(trainXY,
-                        steps_per_epoch=epoch_size_train // n_batch,
                         #class_weight=weights)
                         validation_data=testXY,
-                        validation_steps=epoch_size_eval // n_batch,
                         verbose=1,
                         epochs=n_epochs,
                         callbacks=callbacks,