mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-30 22:20:02 +02:00
code formatting with black; typos
This commit is contained in:
parent
5f84938839
commit
02b1436f39
8 changed files with 741 additions and 768 deletions
|
@ -63,7 +63,7 @@ The output folder should be an empty folder where the output model will be writt
|
|||
* flip_aug: If ``true``, different types of filp will be applied on image. Type of flips is given with "flip_index" in train.py file.
|
||||
* blur_aug: If ``true``, different types of blurring will be applied on image. Type of blurrings is given with "blur_k" in train.py file.
|
||||
* scaling: If ``true``, scaling will be applied on image. Scale of scaling is given with "scales" in train.py file.
|
||||
* rotation_not_90: If ``true``, rotation (not 90 degree) will be applied on image. Rothation angles are given with "thetha" in train.py file.
|
||||
* rotation_not_90: If ``true``, rotation (not 90 degree) will be applied on image. Rotation angles are given with "thetha" in train.py file.
|
||||
* rotation: If ``true``, 90 degree rotation will be applied on image.
|
||||
* binarization: If ``true``,Otsu thresholding will be applied to augment the input data with binarized images.
|
||||
* scaling_bluring: If ``true``, combination of scaling and blurring will be applied on image.
|
||||
|
@ -73,5 +73,3 @@ The output folder should be an empty folder where the output model will be writt
|
|||
* weighted_loss: If ``true``, this means that you want to apply weighted categorical_crossentropy as loss fucntion. Be carefull if you set to ``true``the parameter "is_loss_soft_dice" should be ``false``
|
||||
* data_is_provided: If you have already provided the input data you can set this to ``true``. Be sure that the train and eval data are in "dir_output". Since when once we provide training data we resize and augment them and then we write them in sub-directories train and eval in "dir_output".
|
||||
* dir_train: This is the directory of "images" and "labels" (dir_train should include two subdirectories with names of images and labels ) for raw images and labels. Namely they are not prepared (not resized and not augmented) yet for training the model. When we run this tool these raw data will be transformed to suitable size needed for the model and they will be written in "dir_output" in train and eval directories. Each of train and eval include "images" and "labels" sub-directories.
|
||||
|
||||
|
||||
|
|
|
@ -9,8 +9,6 @@ from utils import *
|
|||
from metrics import *
|
||||
|
||||
|
||||
|
||||
|
||||
def configuration():
|
||||
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
|
||||
session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
|
||||
|
@ -29,5 +27,3 @@ if __name__=='__main__':
|
|||
model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
|
||||
model.load_weights(dir_of_weights)
|
||||
model.save('./name_in_another_python_version.h5')
|
||||
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
"weighted_loss": false,
|
||||
"is_loss_soft_dice": false,
|
||||
"data_is_provided": false,
|
||||
"dir_train": "/home/vahid/Documents/handwrittens_train/train",
|
||||
"dir_eval": "/home/vahid/Documents/handwrittens_train/eval",
|
||||
"dir_output": "/home/vahid/Documents/handwrittens_train/output"
|
||||
"dir_train": "/train",
|
||||
"dir_eval": "/eval",
|
||||
"dir_output": "/output"
|
||||
}
|
||||
|
|
43
metrics.py
43
metrics.py
|
@ -2,8 +2,8 @@ from tensorflow.keras import backend as K
|
|||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
def focal_loss(gamma=2., alpha=4.):
|
||||
|
||||
def focal_loss(gamma=2., alpha=4.):
|
||||
gamma = float(gamma)
|
||||
alpha = float(alpha)
|
||||
|
||||
|
@ -37,8 +37,10 @@ def focal_loss(gamma=2., alpha=4.):
|
|||
fl = tf.multiply(alpha, tf.multiply(weight, ce))
|
||||
reduced_fl = tf.reduce_max(fl, axis=1)
|
||||
return tf.reduce_mean(reduced_fl)
|
||||
|
||||
return focal_loss_fixed
|
||||
|
||||
|
||||
def weighted_categorical_crossentropy(weights=None):
|
||||
""" weighted_categorical_crossentropy
|
||||
|
||||
|
@ -58,7 +60,10 @@ def weighted_categorical_crossentropy(weights=None):
|
|||
* labels_floats, axis=-1), 1.0)
|
||||
per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
|
||||
return tf.reduce_mean(per_pixel_loss)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def image_categorical_cross_entropy(y_true, y_pred, weights=None):
|
||||
"""
|
||||
:param y_true: tensor of shape (batch_size, height, width) representing the ground truth.
|
||||
|
@ -77,6 +82,8 @@ def image_categorical_cross_entropy(y_true, y_pred, weights=None):
|
|||
per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
|
||||
|
||||
return tf.reduce_mean(per_pixel_loss)
|
||||
|
||||
|
||||
def class_tversky(y_true, y_pred):
|
||||
smooth = 1.0 # 1.00
|
||||
|
||||
|
@ -90,20 +97,22 @@ def class_tversky(y_true, y_pred):
|
|||
false_pos = K.sum((1 - y_true_pos) * y_pred_pos, 1)
|
||||
alpha = 0.2 # 0.5
|
||||
beta = 0.8
|
||||
return (true_pos + smooth)/(true_pos + alpha*false_neg + (beta)*false_pos + smooth)
|
||||
return (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
|
||||
|
||||
|
||||
def focal_tversky_loss(y_true, y_pred):
|
||||
pt_1 = class_tversky(y_true, y_pred)
|
||||
gamma = 1.3 # 4./3.0#1.3#4.0/3.00# 0.75
|
||||
return K.sum(K.pow((1 - pt_1), gamma))
|
||||
|
||||
|
||||
def generalized_dice_coeff2(y_true, y_pred):
|
||||
n_el = 1
|
||||
for dim in y_true.shape:
|
||||
n_el *= int(dim)
|
||||
n_cl = y_true.shape[-1]
|
||||
w = K.zeros(shape=(n_cl,))
|
||||
w = (K.sum(y_true, axis=(0,1,2)))/(n_el)
|
||||
w = (K.sum(y_true, axis=(0, 1, 2))) / n_el
|
||||
w = 1 / (w ** 2 + 0.000001)
|
||||
numerator = y_true * y_pred
|
||||
numerator = w * K.sum(numerator, (0, 1, 2))
|
||||
|
@ -112,6 +121,8 @@ def generalized_dice_coeff2(y_true, y_pred):
|
|||
denominator = w * K.sum(denominator, (0, 1, 2))
|
||||
denominator = K.sum(denominator)
|
||||
return 2 * numerator / denominator
|
||||
|
||||
|
||||
def generalized_dice_coeff(y_true, y_pred):
|
||||
axes = tuple(range(1, len(y_pred.shape) - 1))
|
||||
Ncl = y_pred.shape[-1]
|
||||
|
@ -131,10 +142,13 @@ def generalized_dice_coeff(y_true, y_pred):
|
|||
|
||||
return gen_dice_coef
|
||||
|
||||
|
||||
def generalized_dice_loss(y_true, y_pred):
|
||||
return 1 - generalized_dice_coeff2(y_true, y_pred)
|
||||
|
||||
|
||||
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
||||
'''
|
||||
"""
|
||||
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
|
||||
Assumes the `channels_last` format.
|
||||
|
||||
|
@ -150,7 +164,7 @@ def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
|||
https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)
|
||||
|
||||
Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
|
||||
'''
|
||||
"""
|
||||
|
||||
# skip the batch and class axis for calculating Dice score
|
||||
axes = tuple(range(1, len(y_pred.shape) - 1))
|
||||
|
@ -160,7 +174,9 @@ def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
|||
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
|
||||
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
|
||||
|
||||
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last = True, mean_per_class=False, verbose=False):
|
||||
|
||||
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False,
|
||||
verbose=False):
|
||||
"""
|
||||
Compute mean metrics of two segmentation masks, via Keras.
|
||||
|
||||
|
@ -250,6 +266,7 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last =
|
|||
|
||||
return K.mean(non_zero_sum / non_zero_count)
|
||||
|
||||
|
||||
def mean_iou(y_true, y_pred, **kwargs):
|
||||
"""
|
||||
Compute mean Intersection over Union of two segmentation masks, via Keras.
|
||||
|
@ -257,6 +274,8 @@ def mean_iou(y_true, y_pred, **kwargs):
|
|||
Calls metrics_k(y_true, y_pred, metric_name='iou'), see there for allowed kwargs.
|
||||
"""
|
||||
return seg_metrics(y_true, y_pred, metric_name='iou', **kwargs)
|
||||
|
||||
|
||||
def Mean_IOU(y_true, y_pred):
|
||||
nb_classes = K.int_shape(y_pred)[-1]
|
||||
iou = []
|
||||
|
@ -276,6 +295,7 @@ def Mean_IOU(y_true, y_pred):
|
|||
iou = tf.gather(iou, indices=tf.where(legal_labels))
|
||||
return K.mean(iou)
|
||||
|
||||
|
||||
def iou_vahid(y_true, y_pred):
|
||||
nb_classes = tf.shape(y_true)[-1] + tf.to_int32(1)
|
||||
true_pixels = K.argmax(y_true, axis=-1)
|
||||
|
@ -292,8 +312,8 @@ def iou_vahid(y_true, y_pred):
|
|||
|
||||
|
||||
def IoU_metric(Yi, y_predi):
|
||||
## mean Intersection over Union
|
||||
## Mean IoU = TP/(FN + TP + FP)
|
||||
# mean Intersection over Union
|
||||
# Mean IoU = TP/(FN + TP + FP)
|
||||
y_predi = np.argmax(y_predi, axis=3)
|
||||
y_testi = np.argmax(Yi, axis=3)
|
||||
IoUs = []
|
||||
|
@ -308,14 +328,15 @@ def IoU_metric(Yi,y_predi):
|
|||
|
||||
|
||||
def IoU_metric_keras(y_true, y_pred):
|
||||
## mean Intersection over Union
|
||||
## Mean IoU = TP/(FN + TP + FP)
|
||||
# mean Intersection over Union
|
||||
# Mean IoU = TP/(FN + TP + FP)
|
||||
init = tf.global_variables_initializer()
|
||||
sess = tf.Session()
|
||||
sess.run(init)
|
||||
|
||||
return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))
|
||||
|
||||
|
||||
def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
||||
"""
|
||||
Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
|
||||
|
@ -334,5 +355,3 @@ def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
|||
sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
|
||||
jac = (intersection + smooth) / (sum_ - intersection + smooth)
|
||||
return (1 - jac) * smooth
|
||||
|
||||
|
||||
|
|
39
models.py
39
models.py
|
@ -16,6 +16,7 @@ def one_side_pad( x ):
|
|||
x = Lambda(lambda x: x[:, :-1, :-1, :])(x)
|
||||
return x
|
||||
|
||||
|
||||
def identity_block(input_tensor, kernel_size, filters, stage, block):
|
||||
"""The identity block is the block that has no conv layer at shortcut.
|
||||
# Arguments
|
||||
|
@ -103,7 +104,6 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
assert input_height % 32 == 0
|
||||
assert input_width % 32 == 0
|
||||
|
||||
|
||||
img_input = Input(shape=(input_height, input_width, 3))
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
|
@ -112,20 +112,19 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
bn_axis = 1
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
|
||||
name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
|
@ -145,22 +144,17 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
|
||||
if pretraining:
|
||||
model = Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
|
||||
v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
|
||||
v512_2048 = (BatchNormalization(axis=bn_axis))(v512_2048)
|
||||
v512_2048 = Activation('relu')(v512_2048)
|
||||
|
||||
|
||||
|
||||
v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4)
|
||||
v512_1024 = (BatchNormalization(axis=bn_axis))(v512_1024)
|
||||
v512_1024 = Activation('relu')(v512_1024)
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v512_2048)
|
||||
o = (concatenate([o, v512_1024], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -168,7 +162,6 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -176,7 +169,6 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -184,8 +176,6 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -193,7 +183,6 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, img_input], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -201,21 +190,18 @@ def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
|
||||
|
||||
model = Model(img_input, o)
|
||||
return model
|
||||
|
||||
|
||||
def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
|
||||
assert input_height % 32 == 0
|
||||
assert input_width % 32 == 0
|
||||
|
||||
|
||||
img_input = Input(shape=(input_height, input_width, 3))
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
|
@ -224,20 +210,19 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p
|
|||
bn_axis = 1
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay),
|
||||
name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
|
@ -260,11 +245,11 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p
|
|||
if pretraining:
|
||||
Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
v1024_2048 = Conv2D( 1024 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 )
|
||||
v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(
|
||||
f5)
|
||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||
v1024_2048 = Activation('relu')(v1024_2048)
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||
o = (concatenate([o, f4], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -272,7 +257,6 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -280,7 +264,6 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -288,7 +271,6 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -296,7 +278,6 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, img_input], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
|
@ -304,14 +285,10 @@ def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,p
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
|
||||
model = Model(img_input, o)
|
||||
|
||||
|
||||
|
||||
|
||||
return model
|
||||
|
|
|
@ -4,3 +4,5 @@ opencv-python-headless
|
|||
seaborn
|
||||
tqdm
|
||||
imutils
|
||||
numpy
|
||||
scipy
|
||||
|
|
34
train.py
34
train.py
|
@ -11,12 +11,14 @@ from metrics import *
|
|||
from tensorflow.keras.models import load_model
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def configuration():
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
session = tf.compat.v1.Session(config=config)
|
||||
set_session(session)
|
||||
|
||||
|
||||
def get_dirs_or_files(input_data):
|
||||
if os.path.isdir(input_data):
|
||||
image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
|
||||
|
@ -25,8 +27,10 @@ def get_dirs_or_files(input_data):
|
|||
assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input)
|
||||
return image_input, labels_input
|
||||
|
||||
|
||||
ex = Experiment()
|
||||
|
||||
|
||||
@ex.config
|
||||
def config_params():
|
||||
n_classes = None # Number of classes. In the case of binary classification this should be 2.
|
||||
|
@ -60,18 +64,17 @@ def config_params():
|
|||
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
|
||||
data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".
|
||||
|
||||
|
||||
@ex.automain
|
||||
def run(n_classes, n_epochs, input_height,
|
||||
input_width, weight_decay, weighted_loss,
|
||||
index_start, dir_of_start_model, is_loss_soft_dice,
|
||||
n_batch,patches,augmentation,flip_aug
|
||||
,blur_aug,scaling, binarization,
|
||||
n_batch, patches, augmentation, flip_aug,
|
||||
blur_aug, scaling, binarization,
|
||||
blur_k, scales, dir_train, data_is_provided,
|
||||
scaling_bluring, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip, continue_training,
|
||||
flip_index, dir_eval, dir_output, pretraining, learning_rate):
|
||||
|
||||
|
||||
if data_is_provided:
|
||||
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||
|
@ -110,18 +113,15 @@ def run(n_classes,n_epochs,input_height,
|
|||
else:
|
||||
os.makedirs(dir_eval_flowing)
|
||||
|
||||
|
||||
os.mkdir(dir_flow_train_imgs)
|
||||
os.mkdir(dir_flow_train_labels)
|
||||
|
||||
os.mkdir(dir_flow_eval_imgs)
|
||||
os.mkdir(dir_flow_eval_labels)
|
||||
|
||||
|
||||
# set the gpu configuration
|
||||
configuration()
|
||||
|
||||
|
||||
# writing patches into a sub-folder in order to be flowed from directory.
|
||||
provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
|
||||
dir_flow_train_labels,
|
||||
|
@ -139,8 +139,6 @@ def run(n_classes,n_epochs,input_height,
|
|||
rotation_not_90, thetha, scaling_flip,
|
||||
augmentation=False, patches=patches)
|
||||
|
||||
|
||||
|
||||
if weighted_loss:
|
||||
weights = np.zeros(n_classes)
|
||||
if data_is_provided:
|
||||
|
@ -161,20 +159,18 @@ def run(n_classes,n_epochs,input_height,
|
|||
except:
|
||||
pass
|
||||
|
||||
|
||||
weights = 1.00 / weights
|
||||
|
||||
weights = weights / float(np.sum(weights))
|
||||
weights = weights / float(np.min(weights))
|
||||
weights = weights / float(np.sum(weights))
|
||||
|
||||
|
||||
|
||||
if continue_training:
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model (dir_of_start_model, compile = True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
model = load_model(dir_of_start_model, compile=True,
|
||||
custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True)
|
||||
else:
|
||||
|
@ -185,7 +181,6 @@ def run(n_classes,n_epochs,input_height,
|
|||
# if you want to see the model structure just uncomment model summary.
|
||||
# model.summary()
|
||||
|
||||
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
|
@ -212,18 +207,7 @@ def run(n_classes,n_epochs,input_height,
|
|||
epochs=1)
|
||||
model.save(dir_output + '/' + 'model_' + str(i))
|
||||
|
||||
|
||||
# os.system('rm -rf '+dir_train_flowing)
|
||||
# os.system('rm -rf '+dir_eval_flowing)
|
||||
|
||||
# model.save(dir_output+'/'+'model'+'.h5')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
91
utils.py
91
utils.py
|
@ -10,9 +10,8 @@ import imutils
|
|||
import math
|
||||
|
||||
|
||||
|
||||
def bluring(img_in, kind):
|
||||
if kind=='guass':
|
||||
if kind == 'gauss':
|
||||
img_blur = cv2.GaussianBlur(img_in, (5, 5), 0)
|
||||
elif kind == "median":
|
||||
img_blur = cv2.medianBlur(img_in, 5)
|
||||
|
@ -20,8 +19,8 @@ def bluring(img_in,kind):
|
|||
img_blur = cv2.blur(img_in, (5, 5))
|
||||
return img_blur
|
||||
|
||||
def elastic_transform(image, alpha, sigma,seedj, random_state=None):
|
||||
|
||||
def elastic_transform(image, alpha, sigma, seedj, random_state=None):
|
||||
"""Elastic deformation of images as described in [Simard2003]_.
|
||||
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
|
||||
Convolutional Neural Networks applied to Visual Document Analysis", in
|
||||
|
@ -42,6 +41,7 @@ def elastic_transform(image, alpha, sigma,seedj, random_state=None):
|
|||
distored_image = map_coordinates(image, indices, order=1, mode='reflect')
|
||||
return distored_image.reshape(image.shape)
|
||||
|
||||
|
||||
def rotation_90(img):
|
||||
img_rot = np.zeros((img.shape[1], img.shape[0], img.shape[2]))
|
||||
img_rot[:, :, 0] = img[:, :, 0].T
|
||||
|
@ -49,6 +49,7 @@ def rotation_90(img):
|
|||
img_rot[:, :, 2] = img[:, :, 2].T
|
||||
return img_rot
|
||||
|
||||
|
||||
def rotatedRectWithMaxArea(w, h, angle):
|
||||
"""
|
||||
Given a rectangle of size wxh that has been rotated by 'angle' (in
|
||||
|
@ -76,6 +77,7 @@ def rotatedRectWithMaxArea(w, h, angle):
|
|||
|
||||
return wr, hr
|
||||
|
||||
|
||||
def rotate_max_area(image, rotated, rotated_label, angle):
|
||||
""" image: cv2 image matrix object
|
||||
angle: in degree
|
||||
|
@ -88,11 +90,14 @@ def rotate_max_area(image,rotated, rotated_label,angle):
|
|||
x1 = w // 2 - int(wr / 2)
|
||||
x2 = x1 + int(wr)
|
||||
return rotated[y1:y2, x1:x2], rotated_label[y1:y2, x1:x2]
|
||||
|
||||
|
||||
def rotation_not_90_func(img, label, thetha):
|
||||
rotated = imutils.rotate(img, thetha)
|
||||
rotated_label = imutils.rotate(label, thetha)
|
||||
return rotate_max_area(img, rotated, rotated_label, thetha)
|
||||
|
||||
|
||||
def color_images(seg, n_classes):
|
||||
ann_u = range(n_classes)
|
||||
if len(np.shape(seg)) == 3:
|
||||
|
@ -112,6 +117,8 @@ def color_images(seg, n_classes):
|
|||
|
||||
def resize_image(seg_in, input_height, input_width):
|
||||
return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
|
||||
def get_one_hot(seg, input_height, input_width, n_classes):
|
||||
seg = seg[:, :, 0]
|
||||
seg_f = np.zeros((input_height, input_width, n_classes))
|
||||
|
@ -137,6 +144,8 @@ def IoU(Yi,y_predi):
|
|||
print("_________________")
|
||||
print("Mean IoU: {:4.3f}".format(mIoU))
|
||||
return mIoU
|
||||
|
||||
|
||||
def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes):
|
||||
c = 0
|
||||
n = [f for f in os.listdir(img_folder) if not f.startswith('.')] # os.listdir(img_folder) #List of training images
|
||||
|
@ -152,13 +161,15 @@ def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_cla
|
|||
filename = n[i].split('.')[0]
|
||||
|
||||
train_img = cv2.imread(img_folder + '/' + n[i]) / 255.
|
||||
train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize
|
||||
train_img = cv2.resize(train_img, (input_width, input_height),
|
||||
interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize
|
||||
|
||||
img[i - c] = train_img # add to array - img[0], img[1], and so on.
|
||||
train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
|
||||
# print(mask_folder+'/'+filename+'.png')
|
||||
# print(train_mask.shape)
|
||||
train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes)
|
||||
train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width,
|
||||
n_classes)
|
||||
# train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3]
|
||||
|
||||
mask[i - c] = train_mask
|
||||
|
@ -166,14 +177,13 @@ def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_cla
|
|||
img[i - c] = np.ones((input_height, input_width, 3)).astype('float')
|
||||
mask[i - c] = np.zeros((input_height, input_width, n_classes)).astype('float')
|
||||
|
||||
|
||||
|
||||
c += batch_size
|
||||
if(c+batch_size>=len(os.listdir(img_folder))):
|
||||
if c + batch_size >= len(os.listdir(img_folder)):
|
||||
c = 0
|
||||
random.shuffle(n)
|
||||
yield img, mask
|
||||
|
||||
|
||||
def otsu_copy(img):
|
||||
img_r = np.zeros(img.shape)
|
||||
img1 = img[:, :, 0]
|
||||
|
@ -186,8 +196,9 @@ def otsu_copy(img):
|
|||
img_r[:, :, 1] = threshold1
|
||||
img_r[:, :, 2] = threshold1
|
||||
return img_r
|
||||
def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer):
|
||||
|
||||
|
||||
def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer):
|
||||
if img.shape[0] < height or img.shape[1] < width:
|
||||
img, label = do_padding(img, label, height, width)
|
||||
|
||||
|
@ -220,7 +231,6 @@ def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer):
|
|||
index_y_u = img_h
|
||||
index_y_d = img_h - height
|
||||
|
||||
|
||||
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
|
||||
|
@ -230,8 +240,8 @@ def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer):
|
|||
|
||||
return indexer
|
||||
|
||||
def do_padding(img,label,height,width):
|
||||
|
||||
def do_padding(img, label, height, width):
|
||||
height_new = img.shape[0]
|
||||
width_new = img.shape[1]
|
||||
|
||||
|
@ -256,8 +266,6 @@ def do_padding(img,label,height,width):
|
|||
|
||||
|
||||
def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler):
|
||||
|
||||
|
||||
if img.shape[0] < height or img.shape[1] < width:
|
||||
img, label = do_padding(img, label, height, width)
|
||||
|
||||
|
@ -267,7 +275,6 @@ def get_patches_num_scale(dir_img_f,dir_seg_f,img,label,height,width,indexer,n_p
|
|||
height_scale = int(height * scaler)
|
||||
width_scale = int(width * scaler)
|
||||
|
||||
|
||||
nxf = img_w / float(width_scale)
|
||||
nyf = img_h / float(height_scale)
|
||||
|
||||
|
@ -294,7 +301,6 @@ def get_patches_num_scale(dir_img_f,dir_seg_f,img,label,height,width,indexer,n_p
|
|||
index_y_u = img_h
|
||||
index_y_d = img_h - height_scale
|
||||
|
||||
|
||||
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
|
||||
|
@ -307,6 +313,7 @@ def get_patches_num_scale(dir_img_f,dir_seg_f,img,label,height,width,indexer,n_p
|
|||
|
||||
return indexer
|
||||
|
||||
|
||||
def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler):
|
||||
img = resize_image(img, int(img.shape[0] * scaler), int(img.shape[1] * scaler))
|
||||
label = resize_image(label, int(label.shape[0] * scaler), int(label.shape[1] * scaler))
|
||||
|
@ -320,7 +327,6 @@ def get_patches_num_scale_new(dir_img_f,dir_seg_f,img,label,height,width,indexer
|
|||
height_scale = int(height * 1)
|
||||
width_scale = int(width * 1)
|
||||
|
||||
|
||||
nxf = img_w / float(width_scale)
|
||||
nyf = img_h / float(height_scale)
|
||||
|
||||
|
@ -347,7 +353,6 @@ def get_patches_num_scale_new(dir_img_f,dir_seg_f,img,label,height,width,indexer
|
|||
index_y_u = img_h
|
||||
index_y_d = img_h - height_scale
|
||||
|
||||
|
||||
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
|
||||
|
@ -368,7 +373,6 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
|
|||
scaling_bluring, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip,
|
||||
augmentation=False, patches=False):
|
||||
|
||||
imgs_cv_train = np.array(os.listdir(dir_img))
|
||||
segs_cv_train = np.array(os.listdir(dir_seg))
|
||||
|
||||
|
@ -376,30 +380,35 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
|
|||
for im, seg_i in tqdm(zip(imgs_cv_train, segs_cv_train)):
|
||||
img_name = im.split('.')[0]
|
||||
if not patches:
|
||||
cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) )
|
||||
cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) )
|
||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width))
|
||||
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
|
||||
indexer += 1
|
||||
|
||||
if augmentation:
|
||||
if flip_aug:
|
||||
for f_i in flip_index:
|
||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) )
|
||||
resize_image(cv2.flip(cv2.imread(dir_img + '/' + im), f_i), input_height,
|
||||
input_width))
|
||||
|
||||
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.flip(cv2.imread(dir_seg+'/'+img_name+'.png'),f_i),input_height,input_width) )
|
||||
resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
|
||||
input_height, input_width))
|
||||
indexer += 1
|
||||
|
||||
if blur_aug:
|
||||
for blur_i in blur_k:
|
||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||
(resize_image(bluring(cv2.imread(dir_img+'/'+im),blur_i),input_height,input_width) ) )
|
||||
(resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height,
|
||||
input_width)))
|
||||
|
||||
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width) )
|
||||
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height,
|
||||
input_width))
|
||||
indexer += 1
|
||||
|
||||
|
||||
if binarization:
|
||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||
resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width))
|
||||
|
@ -408,11 +417,6 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
|
|||
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
|
||||
indexer += 1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if patches:
|
||||
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
|
@ -422,8 +426,6 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
|
|||
if augmentation:
|
||||
|
||||
if rotation:
|
||||
|
||||
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
rotation_90(cv2.imread(dir_img + '/' + im)),
|
||||
rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')),
|
||||
|
@ -432,7 +434,10 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
|
|||
if rotation_not_90:
|
||||
|
||||
for thetha_i in thetha:
|
||||
img_max_rotated,label_max_rotated=rotation_not_90_func(cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'),thetha_i)
|
||||
img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/' + im),
|
||||
cv2.imread(
|
||||
dir_seg + '/' + img_name + '.png'),
|
||||
thetha_i)
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
img_max_rotated,
|
||||
label_max_rotated,
|
||||
|
@ -445,13 +450,11 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
|
|||
input_height, input_width, indexer=indexer)
|
||||
if blur_aug:
|
||||
for blur_i in blur_k:
|
||||
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
bluring(cv2.imread(dir_img + '/' + im), blur_i),
|
||||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
|
||||
if scaling:
|
||||
for sc_ind in scales:
|
||||
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
|
@ -464,15 +467,14 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
|
|||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
|
||||
|
||||
if scaling_bluring:
|
||||
for sc_ind in scales:
|
||||
for blur_i in blur_k:
|
||||
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
bluring(cv2.imread(dir_img + '/' + im), blur_i),
|
||||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height,input_width,indexer=indexer,scaler=sc_ind)
|
||||
input_height, input_width, indexer=indexer,
|
||||
scaler=sc_ind)
|
||||
|
||||
if scaling_binarization:
|
||||
for sc_ind in scales:
|
||||
|
@ -486,12 +488,7 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
|
|||
for f_i in flip_index:
|
||||
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
|
||||
cv2.flip(cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i) ,
|
||||
input_height,input_width,indexer=indexer,scaler=sc_ind)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
f_i),
|
||||
input_height, input_width, indexer=indexer,
|
||||
scaler=sc_ind)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue