training: use mixed precision and XLA (commented; does not work, yet)

This commit is contained in:
Robert Sachunsky 2026-02-27 12:57:47 +01:00
parent c1d8a72edc
commit c6d9dd7945
2 changed files with 7 additions and 3 deletions

View file

@@ -8,7 +8,9 @@ from tensorflow.keras.initializers import Zeros
 import numpy as np

+EPS = K.epsilon()
+
-def focal_loss(gamma=2., alpha=4.):
+def focal_loss(gamma=2., alpha=4., epsilon=EPS):
     gamma = float(gamma)
     alpha = float(alpha)
@@ -32,7 +34,6 @@ def focal_loss(gamma=2., alpha=4.):
     Returns:
         [tensor] -- loss.
     """
-    epsilon = 1.e-9
     y_true = tf.convert_to_tensor(y_true, tf.float32)
     y_pred = tf.convert_to_tensor(y_pred, tf.float32)
@@ -153,7 +154,7 @@ def generalized_dice_loss(y_true, y_pred):
 # TODO: document where this is from
-def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
+def soft_dice_loss(y_true, y_pred, epsilon=EPS):
     """
     Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
     Assumes the `channels_last` format.

View file

@@ -76,6 +76,8 @@ def configuration():
     try:
         for device in tf.config.list_physical_devices('GPU'):
             tf.config.experimental.set_memory_growth(device, True)
+        #tf.keras.mixed_precision.set_global_policy('mixed_float16')
+        #tf.keras.backend.set_epsilon(1e-4) # avoid NaN from smaller defaults
     except:
         print("no GPU device available", file=sys.stderr)
@@ -614,6 +616,7 @@ def run(_config,
     else: # task == "enhancement"
         loss = 'mean_squared_error'
     model.compile(loss=loss,
+                  #jit_compile=True,
                   optimizer=Adam(learning_rate=learning_rate),
                   metrics=metrics)