training: use mixed precision and XLA (commented; does not work, yet)

This commit is contained in:
Robert Sachunsky 2026-02-27 12:57:47 +01:00
parent c1d8a72edc
commit c6d9dd7945
2 changed files with 7 additions and 3 deletions

View file

@ -8,7 +8,9 @@ from tensorflow.keras.initializers import Zeros
import numpy as np
def focal_loss(gamma=2., alpha=4.):
EPS = K.epsilon()
def focal_loss(gamma=2., alpha=4., epsilon=EPS):
gamma = float(gamma)
alpha = float(alpha)
@ -32,7 +34,6 @@ def focal_loss(gamma=2., alpha=4.):
Returns:
[tensor] -- loss.
"""
epsilon = 1.e-9
y_true = tf.convert_to_tensor(y_true, tf.float32)
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
@ -153,7 +154,7 @@ def generalized_dice_loss(y_true, y_pred):
# TODO: document where this is from
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
def soft_dice_loss(y_true, y_pred, epsilon=EPS):
"""
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
Assumes the `channels_last` format.

View file

@ -76,6 +76,8 @@ def configuration():
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
#tf.keras.mixed_precision.set_global_policy('mixed_float16')
#tf.keras.backend.set_epsilon(1e-4) # avoid NaN from smaller defaults
except:
print("no GPU device available", file=sys.stderr)
@ -614,6 +616,7 @@ def run(_config,
else: # task == "enhancement"
loss = 'mean_squared_error'
model.compile(loss=loss,
#jit_compile=True,
optimizer=Adam(learning_rate=learning_rate),
metrics=metrics)