mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
training: use mixed precision and XLA (commented; does not work, yet)
This commit is contained in:
parent
c1d8a72edc
commit
c6d9dd7945
2 changed files with 7 additions and 3 deletions
|
|
@ -8,7 +8,9 @@ from tensorflow.keras.initializers import Zeros
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def focal_loss(gamma=2., alpha=4.):
|
EPS = K.epsilon()
|
||||||
|
|
||||||
|
def focal_loss(gamma=2., alpha=4., epsilon=EPS):
|
||||||
gamma = float(gamma)
|
gamma = float(gamma)
|
||||||
alpha = float(alpha)
|
alpha = float(alpha)
|
||||||
|
|
||||||
|
|
@ -32,7 +34,6 @@ def focal_loss(gamma=2., alpha=4.):
|
||||||
Returns:
|
Returns:
|
||||||
[tensor] -- loss.
|
[tensor] -- loss.
|
||||||
"""
|
"""
|
||||||
epsilon = 1.e-9
|
|
||||||
y_true = tf.convert_to_tensor(y_true, tf.float32)
|
y_true = tf.convert_to_tensor(y_true, tf.float32)
|
||||||
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
|
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
|
||||||
|
|
||||||
|
|
@ -153,7 +154,7 @@ def generalized_dice_loss(y_true, y_pred):
|
||||||
|
|
||||||
|
|
||||||
# TODO: document where this is from
|
# TODO: document where this is from
|
||||||
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
def soft_dice_loss(y_true, y_pred, epsilon=EPS):
|
||||||
"""
|
"""
|
||||||
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
|
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
|
||||||
Assumes the `channels_last` format.
|
Assumes the `channels_last` format.
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,8 @@ def configuration():
|
||||||
try:
|
try:
|
||||||
for device in tf.config.list_physical_devices('GPU'):
|
for device in tf.config.list_physical_devices('GPU'):
|
||||||
tf.config.experimental.set_memory_growth(device, True)
|
tf.config.experimental.set_memory_growth(device, True)
|
||||||
|
#tf.keras.mixed_precision.set_global_policy('mixed_float16')
|
||||||
|
#tf.keras.backend.set_epsilon(1e-4) # avoid NaN from smaller defaults
|
||||||
except:
|
except:
|
||||||
print("no GPU device available", file=sys.stderr)
|
print("no GPU device available", file=sys.stderr)
|
||||||
|
|
||||||
|
|
@ -614,6 +616,7 @@ def run(_config,
|
||||||
else: # task == "enhancement"
|
else: # task == "enhancement"
|
||||||
loss = 'mean_squared_error'
|
loss = 'mean_squared_error'
|
||||||
model.compile(loss=loss,
|
model.compile(loss=loss,
|
||||||
|
#jit_compile=True,
|
||||||
optimizer=Adam(learning_rate=learning_rate),
|
optimizer=Adam(learning_rate=learning_rate),
|
||||||
metrics=metrics)
|
metrics=metrics)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue