mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training: add IoU metric
This commit is contained in:
parent
d1e8a02fd4
commit
25153ad307
1 changed files with 6 additions and 1 deletions
|
|
@ -34,6 +34,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.optimizers import SGD, Adam
|
from tensorflow.keras.optimizers import SGD, Adam
|
||||||
|
from tensorflow.keras.metrics import MeanIoU
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
||||||
from sacred import Experiment
|
from sacred import Experiment
|
||||||
|
|
@ -374,7 +375,11 @@ def run(_config,
|
||||||
loss = 'mean_squared_error'
|
loss = 'mean_squared_error'
|
||||||
model.compile(loss=loss,
|
model.compile(loss=loss,
|
||||||
optimizer=Adam(learning_rate=learning_rate),
|
optimizer=Adam(learning_rate=learning_rate),
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy', MeanIoU(n_classes,
|
||||||
|
name='iou',
|
||||||
|
ignore_class=0,
|
||||||
|
sparse_y_true=False,
|
||||||
|
sparse_y_pred=False)])
|
||||||
|
|
||||||
# generating train and evaluation data
|
# generating train and evaluation data
|
||||||
gen_kwargs = dict(batch_size=n_batch,
|
gen_kwargs = dict(batch_size=n_batch,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue