From c6d9dd7945e745ed5fde140982d08e1fb7e15c39 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Fri, 27 Feb 2026 12:57:47 +0100 Subject: [PATCH] training: use mixed precision and XLA (commented; does not work, yet) --- src/eynollah/training/metrics.py | 7 ++++--- src/eynollah/training/train.py | 3 +++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/eynollah/training/metrics.py b/src/eynollah/training/metrics.py index 56dc732..5955888 100644 --- a/src/eynollah/training/metrics.py +++ b/src/eynollah/training/metrics.py @@ -8,7 +8,9 @@ from tensorflow.keras.initializers import Zeros import numpy as np -def focal_loss(gamma=2., alpha=4.): +EPS = K.epsilon() + +def focal_loss(gamma=2., alpha=4., epsilon=EPS): gamma = float(gamma) alpha = float(alpha) @@ -32,7 +34,6 @@ def focal_loss(gamma=2., alpha=4.): Returns: [tensor] -- loss. """ - epsilon = 1.e-9 y_true = tf.convert_to_tensor(y_true, tf.float32) y_pred = tf.convert_to_tensor(y_pred, tf.float32) @@ -153,7 +154,7 @@ def generalized_dice_loss(y_true, y_pred): # TODO: document where this is from -def soft_dice_loss(y_true, y_pred, epsilon=1e-6): +def soft_dice_loss(y_true, y_pred, epsilon=EPS): """ Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions. Assumes the `channels_last` format. diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 63f7717..4d997e5 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -76,6 +76,8 @@ def configuration(): try: for device in tf.config.list_physical_devices('GPU'): tf.config.experimental.set_memory_growth(device, True) + #tf.keras.mixed_precision.set_global_policy('mixed_float16') + #tf.keras.backend.set_epsilon(1e-4) # avoid NaN from smaller defaults except: print("no GPU device available", file=sys.stderr) @@ -614,6 +616,7 @@ def run(_config, else: # task == "enhancement" loss = 'mean_squared_error' model.compile(loss=loss, + #jit_compile=True, optimizer=Adam(learning_rate=learning_rate), metrics=metrics)