diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 1e2ab3e..344522a 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -34,6 +34,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 import tensorflow as tf from tensorflow.keras.optimizers import SGD, Adam +from tensorflow.keras.metrics import MeanIoU from tensorflow.keras.models import load_model from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard from sacred import Experiment @@ -374,7 +375,11 @@ def run(_config, loss = 'mean_squared_error' model.compile(loss=loss, optimizer=Adam(learning_rate=learning_rate), - metrics=['accuracy']) + metrics=['accuracy', MeanIoU(n_classes, + name='iou', + ignore_class=0, + sparse_y_true=False, + sparse_y_pred=False)]) # generating train and evaluation data gen_kwargs = dict(batch_size=n_batch,