Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git, synced 2025-10-08 23:30:01 +02:00
remove all files except README
This commit is contained in:
parent
16bd5f7691
commit
7bb3ca85a6
10 changed files with 2 additions and 1626 deletions
LICENSE (201 lines deleted)
@@ -1,201 +0,0 @@

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README
@@ -1,3 +1,5 @@
 # sbb_pixelwise_segmentation

 This repo has been merged into [eynollah](https://github.com/qurator-spk/eynollah).
+
+For the training tools, see the [`train` folder in eynollah](https://github.com/qurator-spk/eynollah/tree/main/train).
@@ -1,29 +0,0 @@
import os
import sys
import tensorflow as tf
import warnings
from tensorflow.keras.optimizers import *
from sacred import Experiment
from models import *
from utils import *
from metrics import *


def configuration():
    gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
    session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))


if __name__ == '__main__':
    n_classes = 2
    input_height = 224
    input_width = 448
    weight_decay = 1e-6
    pretraining = False
    dir_of_weights = 'model_bin_sbb_ens.h5'

    # configuration()

    model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
    model.load_weights(dir_of_weights)
    model.save('./name_in_another_python_version.h5')
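
A hedged reading (not part of the diff): the script above rebuilds the resnet50_unet architecture from models.py, loads the weights stored in model_bin_sbb_ens.h5, and re-saves the complete model, which is a common way to re-serialize an .h5 checkpoint so it can be opened under a different Python/Keras version, as the output file name suggests. The original name of this script is not shown in the mirror view.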
@@ -1,30 +0,0 @@
{
    "n_classes": 3,
    "n_epochs": 2,
    "input_height": 448,
    "input_width": 672,
    "weight_decay": 1e-6,
    "n_batch": 2,
    "learning_rate": 1e-4,
    "patches": true,
    "pretraining": true,
    "augmentation": false,
    "flip_aug": false,
    "blur_aug": false,
    "scaling": true,
    "binarization": false,
    "scaling_bluring": false,
    "scaling_binarization": false,
    "scaling_flip": false,
    "rotation": false,
    "rotation_not_90": false,
    "continue_training": false,
    "index_start": 0,
    "dir_of_start_model": " ",
    "weighted_loss": false,
    "is_loss_soft_dice": false,
    "data_is_provided": false,
    "dir_train": "/train",
    "dir_eval": "/eval",
    "dir_output": "/output"
}
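
A hedged usage note (not part of the diff): this JSON mirrors the Sacred configuration defined in train.py below, and a Sacred experiment accepts such a file on the command line as a config update, e.g. python train.py with config_params.json. The file name config_params.json is an assumption; the mirror view does not show it.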
metrics.py (357 lines deleted)
@@ -1,357 +0,0 @@
from tensorflow.keras import backend as K
import tensorflow as tf
import numpy as np


def focal_loss(gamma=2., alpha=4.):
    gamma = float(gamma)
    alpha = float(alpha)

    def focal_loss_fixed(y_true, y_pred):
        """Focal loss for multi-classification
        FL(p_t) = -alpha * (1 - p_t)^gamma * ln(p_t)
        Notice: y_pred is the probability after softmax.
        The gradient is d(FL)/d(p_t), not d(FL)/d(x) as described in the paper:
        d(FL)/d(p_t) * [p_t * (1 - p_t)] = d(FL)/d(x)
        Focal Loss for Dense Object Detection
        https://arxiv.org/abs/1708.02002

        Arguments:
            y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls]
            y_pred {tensor} -- model's output, shape of [batch_size, num_cls]

        Keyword Arguments:
            gamma {float} -- (default: {2.0})
            alpha {float} -- (default: {4.0})

        Returns:
            [tensor] -- loss.
        """
        epsilon = 1.e-9
        y_true = tf.convert_to_tensor(y_true, tf.float32)
        y_pred = tf.convert_to_tensor(y_pred, tf.float32)

        model_out = tf.add(y_pred, epsilon)
        # tf.log was removed in TF 2.x; tf.math.log is the equivalent call.
        ce = tf.multiply(y_true, -tf.math.log(model_out))
        weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
        fl = tf.multiply(alpha, tf.multiply(weight, ce))
        reduced_fl = tf.reduce_max(fl, axis=1)
        return tf.reduce_mean(reduced_fl)

    return focal_loss_fixed


def weighted_categorical_crossentropy(weights=None):
    """weighted_categorical_crossentropy

    Args:
        * weights<ktensor|nparray|list>: crossentropy weights
    Returns:
        * weighted categorical crossentropy function
    """

    def loss(y_true, y_pred):
        labels_floats = tf.cast(y_true, tf.float32)
        per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred)

        if weights is not None:
            weight_mask = tf.maximum(tf.reduce_max(tf.constant(
                np.array(weights, dtype=np.float32)[None, None, None])
                * labels_floats, axis=-1), 1.0)
            per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
        return tf.reduce_mean(per_pixel_loss)

    return loss


def image_categorical_cross_entropy(y_true, y_pred, weights=None):
    """
    :param y_true: tensor of shape (batch_size, height, width) representing the ground truth.
    :param y_pred: tensor of shape (batch_size, height, width) representing the prediction.
    :return: The mean cross-entropy on softmaxed tensors.
    """
    labels_floats = tf.cast(y_true, tf.float32)
    per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred)

    if weights is not None:
        weight_mask = tf.maximum(
            tf.reduce_max(tf.constant(
                np.array(weights, dtype=np.float32)[None, None, None])
                * labels_floats, axis=-1), 1.0)
        per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]

    return tf.reduce_mean(per_pixel_loss)


def class_tversky(y_true, y_pred):
    smooth = 1.0  # 1.00

    y_true = K.permute_dimensions(y_true, (3, 1, 2, 0))
    y_pred = K.permute_dimensions(y_pred, (3, 1, 2, 0))

    y_true_pos = K.batch_flatten(y_true)
    y_pred_pos = K.batch_flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos, 1)
    false_neg = K.sum(y_true_pos * (1 - y_pred_pos), 1)
    false_pos = K.sum((1 - y_true_pos) * y_pred_pos, 1)
    alpha = 0.2  # 0.5
    beta = 0.8
    return (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)


def focal_tversky_loss(y_true, y_pred):
    pt_1 = class_tversky(y_true, y_pred)
    gamma = 1.3  # 4./3.0 # 0.75
    return K.sum(K.pow((1 - pt_1), gamma))


def generalized_dice_coeff2(y_true, y_pred):
    n_el = 1
    for dim in y_true.shape:
        n_el *= int(dim)
    n_cl = y_true.shape[-1]
    w = K.zeros(shape=(n_cl,))
    w = (K.sum(y_true, axis=(0, 1, 2))) / n_el
    w = 1 / (w ** 2 + 0.000001)
    numerator = y_true * y_pred
    numerator = w * K.sum(numerator, (0, 1, 2))
    numerator = K.sum(numerator)
    denominator = y_true + y_pred
    denominator = w * K.sum(denominator, (0, 1, 2))
    denominator = K.sum(denominator)
    return 2 * numerator / denominator


def generalized_dice_coeff(y_true, y_pred):
    axes = tuple(range(1, len(y_pred.shape) - 1))
    Ncl = y_pred.shape[-1]
    w = K.zeros(shape=(Ncl,))
    w = K.sum(y_true, axis=axes)
    w = 1 / (w ** 2 + 0.000001)
    # Compute gen dice coef:
    numerator = y_true * y_pred
    numerator = w * K.sum(numerator, axes)
    numerator = K.sum(numerator)

    denominator = y_true + y_pred
    denominator = w * K.sum(denominator, axes)
    denominator = K.sum(denominator)

    gen_dice_coef = 2 * numerator / denominator

    return gen_dice_coef


def generalized_dice_loss(y_true, y_pred):
    return 1 - generalized_dice_coeff2(y_true, y_pred)


def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    """
    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
    Assumes the `channels_last` format.

    # Arguments
        y_true: b x X x Y( x Z...) x c one-hot encoding of ground truth
        y_pred: b x X x Y( x Z...) x c network output, must sum to 1 over the c channel (such as after softmax)
        epsilon: used for numerical stability to avoid divide-by-zero errors

    # References
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
        https://arxiv.org/abs/1606.04797
        More details on Dice loss formulation
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)

        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    """
    # skip the batch and class axis for calculating the Dice score
    axes = tuple(range(1, len(y_pred.shape) - 1))

    numerator = 2. * K.sum(y_pred * y_true, axes)
    denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
    return 1.00 - K.mean(numerator / (denominator + epsilon))  # average over classes and batch


def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False,
                verbose=False):
    """
    Compute mean metrics of two segmentation masks, via Keras.

    IoU(A,B) = |A & B| / |A U B|
    Dice(A,B) = 2 * |A & B| / (|A| + |B|)

    Args:
        y_true: true masks, one-hot encoded.
        y_pred: predicted masks, either softmax outputs, or one-hot encoded.
        metric_name: metric to be computed, either 'iou' or 'dice'.
        metric_type: one of 'standard' (default), 'soft', 'naive'.
            In the standard version, y_pred is one-hot encoded and the mean
            is taken only over classes that are present (in y_true or y_pred).
            The 'soft' version of the metrics is computed without one-hot
            encoding y_pred.
            The 'naive' version returns mean metrics where absent classes contribute
            to the class mean as 1.0 (instead of being dropped from the mean).
        drop_last = True: boolean flag to drop the last class (usually reserved
            for the background class in semantic segmentation).
        mean_per_class = False: return the mean along the batch axis for each class.
        verbose = False: print intermediate results such as intersection and union
            (as numbers of pixels).
    Returns:
        IoU/Dice of y_true and y_pred, as a float, unless mean_per_class == True,
        in which case it returns the per-class metric, averaged over the batch.

    Inputs are B*W*H*N tensors, with
        B = batch size,
        W = width,
        H = height,
        N = number of classes
    """
    flag_soft = (metric_type == 'soft')
    flag_naive_mean = (metric_type == 'naive')

    # always assume one or more classes
    num_classes = K.shape(y_true)[-1]

    if not flag_soft:
        # get one-hot encoded masks from y_pred (true masks should already be one-hot)
        y_pred = K.one_hot(K.argmax(y_pred), num_classes)
        y_true = K.one_hot(K.argmax(y_true), num_classes)

    # if already one-hot, the command above could have been skipped
    # Keras uses float32 instead of float64; float64 inputs (e.g. from numpy arrays
    # or keras.to_categorical) would give an error further down, hence the cast.
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')

    # intersection and union shapes are batch_size * n_classes (values = area in pixels)
    axes = (1, 2)  # W,H axes of each image
    intersection = K.sum(K.abs(y_true * y_pred), axis=axes)
    mask_sum = K.sum(K.abs(y_true), axis=axes) + K.sum(K.abs(y_pred), axis=axes)
    union = mask_sum - intersection  # or, np.logical_or(y_pred, y_true) for one-hot

    smooth = .001
    iou = (intersection + smooth) / (union + smooth)
    dice = 2 * (intersection + smooth) / (mask_sum + smooth)

    metric = {'iou': iou, 'dice': dice}[metric_name]

    # define mask to be 0 when no pixels are present in either y_true or y_pred, 1 otherwise
    mask = K.cast(K.not_equal(union, 0), 'float32')

    if drop_last:
        metric = metric[:, :-1]
        mask = mask[:, :-1]

    if verbose:
        print('intersection, union')
        print(K.eval(intersection), K.eval(union))
        print(K.eval(intersection / union))

    # return mean metrics: remaining axes are (batch, classes)
    if flag_naive_mean:
        return K.mean(metric)

    # take mean only over non-absent classes
    class_count = K.sum(mask, axis=0)
    non_zero = tf.greater(class_count, 0)
    non_zero_sum = tf.boolean_mask(K.sum(metric * mask, axis=0), non_zero)
    non_zero_count = tf.boolean_mask(class_count, non_zero)

    if verbose:
        print('Counts of inputs with class present, metrics for non-absent classes')
        print(K.eval(class_count), K.eval(non_zero_sum / non_zero_count))

    return K.mean(non_zero_sum / non_zero_count)


def mean_iou(y_true, y_pred, **kwargs):
    """
    Compute mean Intersection over Union of two segmentation masks, via Keras.

    Calls seg_metrics(y_true, y_pred, metric_name='iou'); see there for allowed kwargs.
    """
    return seg_metrics(y_true, y_pred, metric_name='iou', **kwargs)


def Mean_IOU(y_true, y_pred):
    nb_classes = K.int_shape(y_pred)[-1]
    iou = []
    true_pixels = K.argmax(y_true, axis=-1)
    pred_pixels = K.argmax(y_pred, axis=-1)
    void_labels = K.equal(K.sum(y_true, axis=-1), 0)
    for i in range(0, nb_classes):  # exclude first label (background) and last label (void)
        true_labels = K.equal(true_pixels, i)  # & ~void_labels
        pred_labels = K.equal(pred_pixels, i)  # & ~void_labels
        # tf.to_int32 was removed in TF 2.x; tf.cast is the equivalent call.
        inter = tf.cast(true_labels & pred_labels, tf.int32)
        union = tf.cast(true_labels | pred_labels, tf.int32)
        legal_batches = K.sum(tf.cast(true_labels, tf.int32), axis=1) > 0
        ious = K.sum(inter, axis=1) / K.sum(union, axis=1)
        iou.append(K.mean(tf.gather(ious, indices=tf.where(legal_batches))))  # returns average IoU of the same objects
    iou = tf.stack(iou)
    legal_labels = ~tf.debugging.is_nan(iou)
    iou = tf.gather(iou, indices=tf.where(legal_labels))
    return K.mean(iou)


def iou_vahid(y_true, y_pred):
    nb_classes = tf.shape(y_true)[-1] + 1
    true_pixels = K.argmax(y_true, axis=-1)
    pred_pixels = K.argmax(y_pred, axis=-1)
    iou = []

    for i in tf.range(nb_classes):
        tp = K.sum(tf.cast(K.equal(true_pixels, i) & K.equal(pred_pixels, i), tf.int32))
        fp = K.sum(tf.cast(K.not_equal(true_pixels, i) & K.equal(pred_pixels, i), tf.int32))
        fn = K.sum(tf.cast(K.equal(true_pixels, i) & K.not_equal(pred_pixels, i), tf.int32))
        iouh = tp / (tp + fp + fn)
        iou.append(iouh)
    return K.mean(iou)


def IoU_metric(Yi, y_predi):
    # mean Intersection over Union
    # Mean IoU = TP / (FN + TP + FP)
    y_predi = np.argmax(y_predi, axis=3)
    y_testi = np.argmax(Yi, axis=3)
    IoUs = []
    Nclass = int(np.max(Yi)) + 1
    for c in range(Nclass):
        TP = np.sum((Yi == c) & (y_predi == c))
        FP = np.sum((Yi != c) & (y_predi == c))
        FN = np.sum((Yi == c) & (y_predi != c))
        IoU = TP / float(TP + FP + FN)
        IoUs.append(IoU)
    return K.cast(np.mean(IoUs), dtype='float32')


def IoU_metric_keras(y_true, y_pred):
    # mean Intersection over Union
    # Mean IoU = TP / (FN + TP + FP)
    # tf.global_variables_initializer and tf.Session moved under tf.compat.v1 in TF 2.x.
    init = tf.compat.v1.global_variables_initializer()
    sess = tf.compat.v1.Session()
    sess.run(init)

    return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))


def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """
    Jaccard = (|X & Y|) / (|X| + |Y| - |X & Y|)
            = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))

    The Jaccard distance loss is useful for unbalanced datasets. This has been
    shifted so it converges on 0 and is smoothed to avoid exploding or
    disappearing gradients.

    Ref: https://en.wikipedia.org/wiki/Jaccard_index

    @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
    @author: wassname
    """
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return (1 - jac) * smooth
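
A quick smoke test for the metrics above (not from the repository; a minimal sketch assuming TensorFlow 2.x eager execution and random toy tensors):

import numpy as np
import tensorflow as tf
from metrics import soft_dice_loss, mean_iou

# Hypothetical two-class toy batch: one-hot ground truth, softmax-like prediction.
y_true = tf.one_hot(np.random.randint(0, 2, size=(1, 32, 32)), depth=2)
y_pred = tf.nn.softmax(tf.random.normal((1, 32, 32, 2)), axis=-1)

print(float(soft_dice_loss(y_true, y_pred)))                 # scalar loss in [0, 1]
print(float(mean_iou(y_true, y_pred, metric_type='naive')))  # naive mean IoU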
models.py (294 lines deleted)
@@ -1,294 +0,0 @@
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2

resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1


def one_side_pad(x):
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    if IMAGE_ORDERING == 'channels_first':
        x = Lambda(lambda x: x[:, :, :-1, :-1])(x)
    elif IMAGE_ORDERING == 'channels_last':
        x = Lambda(lambda x: x[:, :-1, :-1, :])(x)
    return x


def identity_block(input_tensor, kernel_size, filters, stage, block):
    """The identity block is the block that has no conv layer at the shortcut.
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of the middle conv layer of the main path
        filters: list of integers, the filters of the 3 conv layers of the main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
    # Returns
        Output tensor for the block.
    """
    filters1, filters2, filters3 = filters

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING,
               padding='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    """conv_block is the block that has a conv layer at the shortcut.
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of the middle conv layer of the main path
        filters: list of integers, the filters of the 3 conv layers of the main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
    # Returns
        Output tensor for the block.
    Note that from stage 3, the first conv layer of the main path has strides=(2, 2),
    and the shortcut should have strides=(2, 2) as well.
    """
    filters1, filters2, filters3 = filters

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)

    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x


def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
    assert input_height % 32 == 0
    assert input_width % 32 == 0

    img_input = Input(shape=(input_height, input_width, 3))

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
               name='conv1')(x)
    f1 = x

    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
    f2 = one_side_pad(x)

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    f3 = x

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
    f4 = x

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
    f5 = x

    if pretraining:
        # load_weights returns None, so its result must not be assigned;
        # the temporary Model only serves to pull in the pretrained encoder weights.
        Model(img_input, x).load_weights(resnet50_Weights_path)

    v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
    v512_2048 = BatchNormalization(axis=bn_axis)(v512_2048)
    v512_2048 = Activation('relu')(v512_2048)

    v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4)
    v512_1024 = BatchNormalization(axis=bn_axis)(v512_1024)
    v512_1024 = Activation('relu')(v512_1024)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(v512_2048)
    o = concatenate([o, v512_1024], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f3], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f2], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f1], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, img_input], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('softmax')(o)

    model = Model(img_input, o)
    return model


def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
    assert input_height % 32 == 0
    assert input_width % 32 == 0

    img_input = Input(shape=(input_height, input_width, 3))

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
               name='conv1')(x)
    f1 = x

    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
    f2 = one_side_pad(x)

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    f3 = x

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
    f4 = x

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
    f5 = x

    if pretraining:
        Model(img_input, x).load_weights(resnet50_Weights_path)

    v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
        f5)
    v1024_2048 = BatchNormalization(axis=bn_axis)(v1024_2048)
    v1024_2048 = Activation('relu')(v1024_2048)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(v1024_2048)
    o = concatenate([o, f4], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f3], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f2], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f1], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, img_input], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('softmax')(o)

    model = Model(img_input, o)

    return model
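
A minimal sketch of instantiating the full model above (not from the repository; the dimensions are those from the JSON config earlier, and both must be divisible by 32 to satisfy the asserts):

from models import resnet50_unet

model = resnet50_unet(n_classes=3, input_height=448, input_width=672,
                      weight_decay=1e-6, pretraining=False)
model.summary()  # ResNet50 encoder, U-Net-style decoder, per-pixel softmax over 3 classes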
@@ -1,8 +0,0 @@
tensorflow==2.12.1
sacred
opencv-python-headless
seaborn
tqdm
imutils
numpy
scipy
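
Assuming the deleted file above is the project's pip requirements list (the mirror view does not show its name), the standard way to recreate the environment is pip install -r requirements.txt.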
train.py (213 lines deleted)
@@ -1,213 +0,0 @@
import os
import sys
import tensorflow as tf
from tensorflow.compat.v1.keras.backend import set_session
import warnings
from tensorflow.keras.optimizers import *
from sacred import Experiment
from models import *
from utils import *
from metrics import *
from tensorflow.keras.models import load_model
from tqdm import tqdm


def configuration():
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.compat.v1.Session(config=config)
    set_session(session)


def get_dirs_or_files(input_data):
    if os.path.isdir(input_data):
        image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
        # Check if training dir exists
        assert os.path.isdir(image_input), "{} is not a directory".format(image_input)
        assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input)
    return image_input, labels_input


ex = Experiment()


@ex.config
def config_params():
    n_classes = None  # Number of classes. In the case of binary classification this should be 2.
    n_epochs = 1  # Number of epochs.
    input_height = 224 * 1  # Height of the model's input in pixels.
    input_width = 224 * 1  # Width of the model's input in pixels.
    weight_decay = 1e-6  # Weight decay of l2 regularization of model layers.
    n_batch = 1  # Number of batches at each iteration.
    learning_rate = 1e-4  # Set the learning rate.
    patches = False  # Divides the input image into smaller patches (the input size of the model) when set to true. For the model to see the full image, as in page extraction, set this to false.
    augmentation = False  # To apply any kind of augmentation, this parameter must be set to true.
    flip_aug = False  # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in train.py.
    blur_aug = False  # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in train.py.
    scaling = False  # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in train.py.
    binarization = False  # If true, Otsu thresholding will be applied to augment the input with binarized images.
    rotation = False  # If true, 90-degree rotation will be applied to the image. (Required by run() below; default taken from the JSON config above.)
    rotation_not_90 = False  # If true, rotation by the angles in "thetha" will be applied to the image. (Required by run() below; default taken from the JSON config above.)
    dir_train = None  # Directory of the training dataset with subdirectories named "images" and "labels".
    dir_eval = None  # Directory of the validation dataset with subdirectories named "images" and "labels".
    dir_output = None  # Directory where the output model will be saved.
    pretraining = False  # Set to true to load pretrained weights of the ResNet50 encoder.
    scaling_bluring = False  # If true, a combination of scaling and blurring will be applied to the image.
    scaling_binarization = False  # If true, a combination of scaling and binarization will be applied to the image.
    scaling_flip = False  # If true, a combination of scaling and flipping will be applied to the image.
    thetha = [10, -10]  # Rotate image by these angles for augmentation.
    blur_k = ['blur', 'gauss', 'median']  # Blur image for augmentation.
    scales = [0.5, 2]  # Scale patches for augmentation.
    flip_index = [0, 1, -1]  # Flip image for augmentation.
    continue_training = False  # Set to true if you would like to continue training an already trained model.
    index_start = 0  # Index of the model to continue training from. E.g. if you trained for 3 epochs and the last index is 2, set "index_start" to 3 to start naming new models from index 3.
    dir_of_start_model = ''  # Path of the saved model to continue training from.
    is_loss_soft_dice = False  # Use soft dice as the loss function. When set to true, "weighted_loss" must be false.
    weighted_loss = False  # Use weighted categorical cross entropy as the loss function. When set to true, "is_loss_soft_dice" must be false.
    data_is_provided = False  # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".


@ex.automain
def run(n_classes, n_epochs, input_height,
        input_width, weight_decay, weighted_loss,
        index_start, dir_of_start_model, is_loss_soft_dice,
        n_batch, patches, augmentation, flip_aug,
        blur_aug, scaling, binarization,
        blur_k, scales, dir_train, data_is_provided,
        scaling_bluring, scaling_binarization, rotation,
        rotation_not_90, thetha, scaling_flip, continue_training,
        flip_index, dir_eval, dir_output, pretraining, learning_rate):
    if data_is_provided:
        dir_train_flowing = os.path.join(dir_output, 'train')
        dir_eval_flowing = os.path.join(dir_output, 'eval')

        dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images')
        dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels')

        dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images')
        dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels')

        configuration()

    else:
        dir_img, dir_seg = get_dirs_or_files(dir_train)
        dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)

        # first make a directory in the output for both training and evaluation, in order to flow data from these directories.
        dir_train_flowing = os.path.join(dir_output, 'train')
        dir_eval_flowing = os.path.join(dir_output, 'eval')

        dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images/')
        dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels/')

        dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images/')
        dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels/')

        if os.path.isdir(dir_train_flowing):
            os.system('rm -rf ' + dir_train_flowing)
        os.makedirs(dir_train_flowing)

        if os.path.isdir(dir_eval_flowing):
            os.system('rm -rf ' + dir_eval_flowing)
        os.makedirs(dir_eval_flowing)

        os.mkdir(dir_flow_train_imgs)
        os.mkdir(dir_flow_train_labels)

        os.mkdir(dir_flow_eval_imgs)
        os.mkdir(dir_flow_eval_labels)

        # set the gpu configuration
        configuration()

        # write patches into a sub-folder so they can be flowed from directory.
        provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
                        dir_flow_train_labels,
                        input_height, input_width, blur_k, blur_aug,
                        flip_aug, binarization, scaling, scales, flip_index,
                        scaling_bluring, scaling_binarization, rotation,
                        rotation_not_90, thetha, scaling_flip,
                        augmentation=augmentation, patches=patches)

        provide_patches(dir_img_val, dir_seg_val, dir_flow_eval_imgs,
                        dir_flow_eval_labels,
                        input_height, input_width, blur_k, blur_aug,
                        flip_aug, binarization, scaling, scales, flip_index,
                        scaling_bluring, scaling_binarization, rotation,
                        rotation_not_90, thetha, scaling_flip,
                        augmentation=False, patches=patches)

    if weighted_loss:
        weights = np.zeros(n_classes)
        if data_is_provided:
            for obj in os.listdir(dir_flow_train_labels):
                try:
                    label_obj = cv2.imread(dir_flow_train_labels + '/' + obj)
                    label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes)
                    weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0)
                except Exception:
                    pass
        else:
            for obj in os.listdir(dir_seg):
                try:
                    label_obj = cv2.imread(dir_seg + '/' + obj)
                    label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes)
                    weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0)
                except Exception:
                    pass

        weights = 1.00 / weights

        weights = weights / float(np.sum(weights))
        weights = weights / float(np.min(weights))
        weights = weights / float(np.sum(weights))

    if continue_training:
        if is_loss_soft_dice:
            model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
        if weighted_loss:
            model = load_model(dir_of_start_model, compile=True,
                               custom_objects={'loss': weighted_categorical_crossentropy(weights)})
        if not is_loss_soft_dice and not weighted_loss:
            model = load_model(dir_of_start_model, compile=True)
    else:
        # get our model.
        index_start = 0
        model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)

    # if you want to see the model structure, just uncomment model summary.
    # model.summary()

    # the old "lr" alias is deprecated in favor of "learning_rate" in recent Keras.
    if not is_loss_soft_dice and not weighted_loss:
        model.compile(loss='categorical_crossentropy',
                      optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
    if is_loss_soft_dice:
        model.compile(loss=soft_dice_loss,
                      optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])

    if weighted_loss:
        model.compile(loss=weighted_categorical_crossentropy(weights),
                      optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])

    # generating train and evaluation data
    train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
                         input_height=input_height, input_width=input_width, n_classes=n_classes)
    val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
                       input_height=input_height, input_width=input_width, n_classes=n_classes)

    for i in tqdm(range(index_start, n_epochs + index_start)):
        # Model.fit accepts generators directly; fit_generator is removed in recent Keras.
        model.fit(
            train_gen,
            steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
            validation_data=val_gen,
            validation_steps=1,
            epochs=1)
        model.save(dir_output + '/' + 'model_' + str(i))

    # os.system('rm -rf ' + dir_train_flowing)
    # os.system('rm -rf ' + dir_eval_flowing)

    # model.save(dir_output + '/' + 'model' + '.h5')
494 utils.py
@@ -1,494 +0,0 @@
import os
import math
import random

import cv2
import numpy as np
import seaborn as sns
import imutils
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
from tqdm import tqdm


def bluring(img_in, kind):
    if kind == 'gauss':
        img_blur = cv2.GaussianBlur(img_in, (5, 5), 0)
    elif kind == 'median':
        img_blur = cv2.medianBlur(img_in, 5)
    elif kind == 'blur':
        img_blur = cv2.blur(img_in, (5, 5))
    else:
        # avoid an UnboundLocalError for unknown kinds
        raise ValueError('unknown blur kind: ' + str(kind))
    return img_blur


def elastic_transform(image, alpha, sigma, seedj, random_state=None):
    """Elastic deformation of images as described in [Simard2003]_.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.
    """
    if random_state is None:
        random_state = np.random.RandomState(seedj)

    shape = image.shape
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    dz = np.zeros_like(dx)

    x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
    indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1))

    distorted_image = map_coordinates(image, indices, order=1, mode='reflect')
    return distorted_image.reshape(image.shape)


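# Typical usage (illustrative parameter choice, not prescribed by this
# module): alpha controls the displacement magnitude and sigma the
# smoothness of the displacement field, e.g. for an image `img`:
#   distorted = elastic_transform(img, alpha=img.shape[1] * 2,
#                                 sigma=img.shape[1] * 0.08, seedj=0)

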
def rotation_90(img):
    # note: a per-channel transpose mirrors the image about its main
    # diagonal (a 90-degree rotation combined with a flip); image and label
    # are transformed identically, so the pair stays consistent
    img_rot = np.zeros((img.shape[1], img.shape[0], img.shape[2]))
    img_rot[:, :, 0] = img[:, :, 0].T
    img_rot[:, :, 1] = img[:, :, 1].T
    img_rot[:, :, 2] = img[:, :, 2].T
    return img_rot


def rotatedRectWithMaxArea(w, h, angle):
    """
    Given a rectangle of size wxh that has been rotated by 'angle' (in
    radians), computes the width and height of the largest possible
    axis-aligned rectangle (maximal area) within the rotated rectangle.
    """
    if w <= 0 or h <= 0:
        return 0, 0

    width_is_longer = w >= h
    side_long, side_short = (w, h) if width_is_longer else (h, w)

    # since the solutions for angle, -angle and 180-angle are all the same,
    # it suffices to look at the first quadrant and the absolute values of sin, cos:
    sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
    if side_short <= 2. * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
        # half constrained case: two crop corners touch the longer side,
        # the other two corners are on the mid-line parallel to the longer line
        x = 0.5 * side_short
        wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
    else:
        # fully constrained case: crop touches all 4 sides
        cos_2a = cos_a * cos_a - sin_a * sin_a
        wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a

    return wr, hr


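# Worked example (illustrative): for w=4, h=3 and angle=math.radians(30),
# sin=0.5 and cos~0.866, so side_short (3) <= 2*sin*cos*side_long (~3.46)
# and the half-constrained branch applies: x = 1.5, giving
# (wr, hr) = (1.5/0.5, 1.5/0.866) ~ (3.0, 1.73).

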
def rotate_max_area(image, rotated, rotated_label, angle):
    """image: cv2 image matrix object
    angle: in degrees
    """
    wr, hr = rotatedRectWithMaxArea(image.shape[1], image.shape[0],
                                    math.radians(angle))
    h, w, _ = rotated.shape
    y1 = h // 2 - int(hr / 2)
    y2 = y1 + int(hr)
    x1 = w // 2 - int(wr / 2)
    x2 = x1 + int(wr)
    return rotated[y1:y2, x1:x2], rotated_label[y1:y2, x1:x2]


def rotation_not_90_func(img, label, thetha):
    rotated = imutils.rotate(img, thetha)
    rotated_label = imutils.rotate(label, thetha)
    return rotate_max_area(img, rotated, rotated_label, thetha)


def color_images(seg, n_classes):
    ann_u = range(n_classes)
    if len(np.shape(seg)) == 3:
        seg = seg[:, :, 0]

    seg_img = np.zeros((np.shape(seg)[0], np.shape(seg)[1], 3)).astype(float)
    colors = sns.color_palette("hls", n_classes)

    for c in ann_u:
        c = int(c)
        segl = (seg == c)
        seg_img[:, :, 0] += segl * (colors[c][0])
        seg_img[:, :, 1] += segl * (colors[c][1])
        seg_img[:, :, 2] += segl * (colors[c][2])
    return seg_img


def resize_image(seg_in, input_height, input_width):
    return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)


def get_one_hot(seg, input_height, input_width, n_classes):
    seg = seg[:, :, 0]
    seg_f = np.zeros((input_height, input_width, n_classes))
    for j in range(n_classes):
        seg_f[:, :, j] = (seg == j).astype(int)
    return seg_f


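# Quick sanity check (hypothetical 2x2 input, illustrative only):
#   seg = np.zeros((2, 2, 3), dtype=np.uint8); seg[0, 0, 0] = 1
#   get_one_hot(seg, 2, 2, n_classes=2)[0, 0]  ->  array([0., 1.])
# every pixel ends up with exactly one channel set to 1.

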
def IoU(Yi, y_predi):
    # per-class IoU = TP / (TP + FP + FN); the mean IoU is their average

    IoUs = []
    classes_true = np.unique(Yi)
    for c in classes_true:
        TP = np.sum((Yi == c) & (y_predi == c))
        FP = np.sum((Yi != c) & (y_predi == c))
        FN = np.sum((Yi == c) & (y_predi != c))
        IoU = TP / float(TP + FP + FN)
        print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU))
        IoUs.append(IoU)
    mIoU = np.mean(IoUs)
    print("_________________")
    print("Mean IoU: {:4.3f}".format(mIoU))
    return mIoU


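# Worked example (illustrative): with Yi = [[0, 1], [1, 1]] and
# y_predi = [[0, 1], [0, 1]], class 0 has TP=1, FP=1, FN=0 -> IoU = 0.5,
# and class 1 has TP=2, FP=0, FN=1 -> IoU = 2/3, so the mean IoU is ~0.583.

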
def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes):
    c = 0
    n = [f for f in os.listdir(img_folder) if not f.startswith('.')]  # list of training images
    random.shuffle(n)
    while True:
        img = np.zeros((batch_size, input_height, input_width, 3)).astype('float')
        mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float')

        for i in range(c, c + batch_size):  # initially from 0 to batch_size, c = 0
            try:
                filename = n[i].split('.')[0]

                train_img = cv2.imread(img_folder + '/' + n[i]) / 255.
                train_img = cv2.resize(train_img, (input_width, input_height),
                                       interpolation=cv2.INTER_NEAREST)  # read an image and resize

                img[i - c] = train_img  # add to array - img[0], img[1], and so on
                train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
                train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height,
                                         input_width, n_classes)

                mask[i - c] = train_mask
            except Exception:
                # fall back to a blank white image and an empty mask if a
                # file cannot be read
                img[i - c] = np.ones((input_height, input_width, 3)).astype('float')
                mask[i - c] = np.zeros((input_height, input_width, n_classes)).astype('float')

        c += batch_size
        if c + batch_size >= len(n):  # wrap around and reshuffle at the end of an epoch
            c = 0
            random.shuffle(n)
        yield img, mask


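# Typical usage (a minimal sketch with placeholder paths, mirroring the
# call in train.py):
#   gen = data_gen('/path/imgs', '/path/labels', batch_size=4,
#                  input_height=448, input_width=448, n_classes=2)
#   imgs, masks = next(gen)  # imgs: (4, 448, 448, 3), masks: (4, 448, 448, 2)

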
def otsu_copy(img):
    img_r = np.zeros(img.shape)
    img1 = img[:, :, 0]
    img2 = img[:, :, 1]
    img3 = img[:, :, 2]
    _, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    _, threshold2 = cv2.threshold(img2, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    _, threshold3 = cv2.threshold(img3, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # use each channel's own Otsu result
    img_r[:, :, 0] = threshold1
    img_r[:, :, 1] = threshold2
    img_r[:, :, 2] = threshold3
    return img_r


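# Note: cv2.threshold with THRESH_OTSU expects an 8-bit single-channel
# array, hence the per-channel calls; Otsu picks the cut that maximises
# the between-class variance of each channel's histogram.

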
def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer):
    if img.shape[0] < height or img.shape[1] < width:
        img, label = do_padding(img, label, height, width)

    img_h = img.shape[0]
    img_w = img.shape[1]

    nxf = img_w / float(width)
    nyf = img_h / float(height)

    if nxf > int(nxf):
        nxf = int(nxf) + 1
    if nyf > int(nyf):
        nyf = int(nyf) + 1

    nxf = int(nxf)
    nyf = int(nyf)

    for i in range(nxf):
        for j in range(nyf):
            index_x_d = i * width
            index_x_u = (i + 1) * width

            index_y_d = j * height
            index_y_u = (j + 1) * height

            if index_x_u > img_w:
                index_x_u = img_w
                index_x_d = img_w - width
            if index_y_u > img_h:
                index_y_u = img_h
                index_y_d = img_h - height

            img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
            label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]

            cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
            cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
            indexer += 1

    return indexer


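# Tiling arithmetic (illustrative): for a 1000x1500 image and 448x448
# patches, nyf = ceil(1000/448) = 3 rows and nxf = ceil(1500/448) = 4
# columns; tiles that would overrun an edge are shifted back inside the
# image, so the last row/column overlaps its neighbour instead of going
# out of bounds.

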
def do_padding(img, label, height, width):
    height_new = img.shape[0]
    width_new = img.shape[1]

    h_start = 0
    w_start = 0

    if img.shape[0] < height:
        h_start = int(abs(height - img.shape[0]) / 2.)
        height_new = height

    if img.shape[1] < width:
        w_start = int(abs(width - img.shape[1]) / 2.)
        width_new = width

    img_new = np.ones((height_new, width_new, img.shape[2])).astype(float) * 255
    label_new = np.zeros((height_new, width_new, label.shape[2])).astype(float)

    img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :])
    label_new[h_start:h_start + label.shape[0], w_start:w_start + label.shape[1], :] = np.copy(label[:, :, :])

    return img_new, label_new


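# Example (illustrative): padding a 300x500 image for 448x448 patches only
# pads the height: h_start = int((448 - 300) / 2) = 74, so the image is
# centred vertically on a 448x500 white canvas while the label is centred
# on a zero (background) canvas of the same size.

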
def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler):
    # note: the n_patches argument is accepted but never used
    if img.shape[0] < height or img.shape[1] < width:
        img, label = do_padding(img, label, height, width)

    img_h = img.shape[0]
    img_w = img.shape[1]

    height_scale = int(height * scaler)
    width_scale = int(width * scaler)

    nxf = img_w / float(width_scale)
    nyf = img_h / float(height_scale)

    if nxf > int(nxf):
        nxf = int(nxf) + 1
    if nyf > int(nyf):
        nyf = int(nyf) + 1

    nxf = int(nxf)
    nyf = int(nyf)

    for i in range(nxf):
        for j in range(nyf):
            index_x_d = i * width_scale
            index_x_u = (i + 1) * width_scale

            index_y_d = j * height_scale
            index_y_u = (j + 1) * height_scale

            if index_x_u > img_w:
                index_x_u = img_w
                index_x_d = img_w - width_scale
            if index_y_u > img_h:
                index_y_u = img_h
                index_y_d = img_h - height_scale

            img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
            label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]

            img_patch = resize_image(img_patch, height, width)
            label_patch = resize_image(label_patch, height, width)

            cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
            cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
            indexer += 1

    return indexer


def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler):
    # rescale the whole image (and label) first, then tile at the native
    # patch size
    img = resize_image(img, int(img.shape[0] * scaler), int(img.shape[1] * scaler))
    label = resize_image(label, int(label.shape[0] * scaler), int(label.shape[1] * scaler))

    if img.shape[0] < height or img.shape[1] < width:
        img, label = do_padding(img, label, height, width)

    img_h = img.shape[0]
    img_w = img.shape[1]

    height_scale = height
    width_scale = width

    nxf = img_w / float(width_scale)
    nyf = img_h / float(height_scale)

    if nxf > int(nxf):
        nxf = int(nxf) + 1
    if nyf > int(nyf):
        nyf = int(nyf) + 1

    nxf = int(nxf)
    nyf = int(nyf)

    for i in range(nxf):
        for j in range(nyf):
            index_x_d = i * width_scale
            index_x_u = (i + 1) * width_scale

            index_y_d = j * height_scale
            index_y_u = (j + 1) * height_scale

            if index_x_u > img_w:
                index_x_u = img_w
                index_x_d = img_w - width_scale
            if index_y_u > img_h:
                index_y_u = img_h
                index_y_d = img_h - height_scale

            img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
            label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]

            cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
            cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
            indexer += 1

    return indexer


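# The two scale variants differ in where the rescaling happens:
# get_patches_num_scale crops scaled-size tiles and resizes each tile back
# to (height, width), while get_patches_num_scale_new rescales the whole
# image once and then tiles it at the native patch size. Only the latter
# is called from provide_patches below.

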
def provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
                    dir_flow_train_labels,
                    input_height, input_width, blur_k, blur_aug,
                    flip_aug, binarization, scaling, scales, flip_index,
                    scaling_bluring, scaling_binarization, rotation,
                    rotation_not_90, thetha, scaling_flip,
                    augmentation=False, patches=False):
    imgs_cv_train = np.array(os.listdir(dir_img))
    segs_cv_train = np.array(os.listdir(dir_seg))

    indexer = 0
    for im, seg_i in tqdm(zip(imgs_cv_train, segs_cv_train)):
        img_name = im.split('.')[0]
        if not patches:
            cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                        resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width))
            cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                        resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
            indexer += 1

            if augmentation:
                if flip_aug:
                    for f_i in flip_index:
                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                    resize_image(cv2.flip(cv2.imread(dir_img + '/' + im), f_i), input_height,
                                                 input_width))
                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                    resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
                                                 input_height, input_width))
                        indexer += 1

                if blur_aug:
                    for blur_i in blur_k:
                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                    resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height,
                                                 input_width))
                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                    resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height,
                                                 input_width))
                        indexer += 1

                if binarization:
                    cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width))
                    cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
                    indexer += 1

        if patches:
            indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                  cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'),
                                  input_height, input_width, indexer=indexer)

            if augmentation:
                if rotation:
                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                          rotation_90(cv2.imread(dir_img + '/' + im)),
                                          rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')),
                                          input_height, input_width, indexer=indexer)

                if rotation_not_90:
                    for thetha_i in thetha:
                        img_max_rotated, label_max_rotated = rotation_not_90_func(
                            cv2.imread(dir_img + '/' + im),
                            cv2.imread(dir_seg + '/' + img_name + '.png'),
                            thetha_i)
                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                              img_max_rotated,
                                              label_max_rotated,
                                              input_height, input_width, indexer=indexer)

                if flip_aug:
                    for f_i in flip_index:
                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                              cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
                                              cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
                                              input_height, input_width, indexer=indexer)

                if blur_aug:
                    for blur_i in blur_k:
                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                              bluring(cv2.imread(dir_img + '/' + im), blur_i),
                                              cv2.imread(dir_seg + '/' + img_name + '.png'),
                                              input_height, input_width, indexer=indexer)

                if scaling:
                    for sc_ind in scales:
                        indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                            cv2.imread(dir_img + '/' + im),
                                                            cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                            input_height, input_width, indexer=indexer,
                                                            scaler=sc_ind)

                if binarization:
                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                          otsu_copy(cv2.imread(dir_img + '/' + im)),
                                          cv2.imread(dir_seg + '/' + img_name + '.png'),
                                          input_height, input_width, indexer=indexer)

                if scaling_bluring:
                    for sc_ind in scales:
                        for blur_i in blur_k:
                            indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                                bluring(cv2.imread(dir_img + '/' + im), blur_i),
                                                                cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                                input_height, input_width, indexer=indexer,
                                                                scaler=sc_ind)

                if scaling_binarization:
                    for sc_ind in scales:
                        indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                            otsu_copy(cv2.imread(dir_img + '/' + im)),
                                                            cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                            input_height, input_width, indexer=indexer,
                                                            scaler=sc_ind)

                if scaling_flip:
                    for sc_ind in scales:
                        for f_i in flip_index:
                            indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                                cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
                                                                cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                                         f_i),
                                                                input_height, input_width, indexer=indexer,
                                                                scaler=sc_ind)