mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
training: use mixed precision and XLA (commented; does not work, yet)
This commit is contained in:
parent
c1d8a72edc
commit
c6d9dd7945
2 changed files with 7 additions and 3 deletions
|
|
@ -8,7 +8,9 @@ from tensorflow.keras.initializers import Zeros
|
|||
import numpy as np
|
||||
|
||||
|
||||
def focal_loss(gamma=2., alpha=4.):
|
||||
EPS = K.epsilon()
|
||||
|
||||
def focal_loss(gamma=2., alpha=4., epsilon=EPS):
|
||||
gamma = float(gamma)
|
||||
alpha = float(alpha)
|
||||
|
||||
|
|
@ -32,7 +34,6 @@ def focal_loss(gamma=2., alpha=4.):
|
|||
Returns:
|
||||
[tensor] -- loss.
|
||||
"""
|
||||
epsilon = 1.e-9
|
||||
y_true = tf.convert_to_tensor(y_true, tf.float32)
|
||||
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
|
||||
|
||||
|
|
@ -153,7 +154,7 @@ def generalized_dice_loss(y_true, y_pred):
|
|||
|
||||
|
||||
# TODO: document where this is from
|
||||
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
||||
def soft_dice_loss(y_true, y_pred, epsilon=EPS):
|
||||
"""
|
||||
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
|
||||
Assumes the `channels_last` format.
|
||||
|
|
|
|||
|
|
@ -76,6 +76,8 @@ def configuration():
|
|||
try:
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
tf.config.experimental.set_memory_growth(device, True)
|
||||
#tf.keras.mixed_precision.set_global_policy('mixed_float16')
|
||||
#tf.keras.backend.set_epsilon(1e-4) # avoid NaN from smaller defaults
|
||||
except:
|
||||
print("no GPU device available", file=sys.stderr)
|
||||
|
||||
|
|
@ -614,6 +616,7 @@ def run(_config,
|
|||
else: # task == "enhancement"
|
||||
loss = 'mean_squared_error'
|
||||
model.compile(loss=loss,
|
||||
#jit_compile=True,
|
||||
optimizer=Adam(learning_rate=learning_rate),
|
||||
metrics=metrics)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue