📝 howto: Be more verbose with the subtree pull
commit
4897fd3dd7
@@ -0,0 +1,26 @@
# Train

Just run: python train.py with config_params.json
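train.py is a sacred experiment, so single parameters can also be overridden on the command line after the config file; a minimal sketch (the values here are only illustrative):

    python train.py with config_params.json n_epochs=5 patches=False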
# Ground truth format

The label of each pixel is identified by a number. So if you have a binary case, n_classes should be set to 2
and the labels should be 0 and 1 for each pixel of each class.
In the case of multiclass, just set n_classes to the number of classes you have and produce labels
with pixel values from 0, 1, 2, ..., n_classes-1.
The labels should be provided as PNG images.

For the binary case, an image label could look like this:

Label: [ [[1 0 0 1], [1 0 0 1], [1 0 0 1]], [[1 0 0 1], [1 0 0 1], [1 0 0 1]], [[1 0 0 1], [1 0 0 1], [1 0 0 1]] ]

This represents a 3*4 image with 3 channels, where pixel[0,0] belongs to class 1 and pixel[0,1] to class 0.
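As a minimal sketch (the file name and the 4x4 shape are only examples), such a binary label can be written with numpy and OpenCV; a single-channel PNG saved this way is read back by the training code as three identical channels:

    import numpy as np
    import cv2

    # binary ground truth for a 4x4 image: 0 = background, 1 = foreground
    label = np.zeros((4, 4), dtype=np.uint8)
    label[:, 0] = 1  # the first column belongs to class 1
    cv2.imwrite('labels/example.png', label)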
# Training, evaluation and output

The train and evaluation folders should each contain subfolders named images and labels.
The output folder should be an empty folder where the output model will be written.
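For example, a layout matching what train.py expects (only the images and labels names are fixed):

    train/
        images/
        labels/
    eval/
        images/
        labels/
    output/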
# Patches

If you want to train your model with patches, the height and width of the patches should be defined, along with
n_batch (how many patches the model should see in each iteration).
In cases where the model should see the whole image at once, like page extraction, patches should be set to false.
config_params.json
@@ -0,0 +1,24 @@
{
    "n_classes" : 2,
    "n_epochs" : 2,
    "input_height" : 448,
    "input_width" : 896,
    "weight_decay" : 1e-6,
    "n_batch" : 1,
    "learning_rate": 1e-4,
    "patches" : true,
    "pretraining" : true,
    "augmentation" : false,
    "flip_aug" : false,
    "elastic_aug" : false,
    "blur_aug" : false,
    "scaling" : false,
    "binarization" : false,
    "scaling_bluring" : false,
    "scaling_binarization" : false,
    "rotation": false,
    "weighted_loss": true,
    "dir_train": "../train",
    "dir_eval": "../eval",
    "dir_output": "../output"
}
metrics.py
@@ -0,0 +1,338 @@
from keras import backend as K
import tensorflow as tf
import numpy as np


def focal_loss(gamma=2., alpha=4.):

    gamma = float(gamma)
    alpha = float(alpha)

    def focal_loss_fixed(y_true, y_pred):
        """Focal loss for multi-classification
        FL(p_t) = -alpha * (1 - p_t)^{gamma} * ln(p_t)
        Notice: y_pred is probability after softmax.
        The gradient is d(FL)/d(p_t), not d(FL)/d(x) as described in the paper:
        d(FL)/d(p_t) * [p_t * (1 - p_t)] = d(FL)/d(x)
        Focal Loss for Dense Object Detection
        https://arxiv.org/abs/1708.02002

        Arguments:
            y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls]
            y_pred {tensor} -- model's output, shape of [batch_size, num_cls]

        Keyword Arguments:
            gamma {float} -- (default: {2.0})
            alpha {float} -- (default: {4.0})

        Returns:
            [tensor] -- loss.
        """
        epsilon = 1.e-9
        y_true = tf.convert_to_tensor(y_true, tf.float32)
        y_pred = tf.convert_to_tensor(y_pred, tf.float32)

        model_out = tf.add(y_pred, epsilon)
        ce = tf.multiply(y_true, -tf.log(model_out))
        weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
        fl = tf.multiply(alpha, tf.multiply(weight, ce))
        reduced_fl = tf.reduce_max(fl, axis=1)
        return tf.reduce_mean(reduced_fl)
    return focal_loss_fixed
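# Illustrative sanity check (not part of the original file): evaluating
# focal_loss on a toy batch under TF1-style sessions; the values are made up.
# loss_fn = focal_loss(gamma=2., alpha=4.)
# y_true = np.array([[0., 1.], [1., 0.]], dtype=np.float32)
# y_pred = np.array([[0.1, 0.9], [0.8, 0.2]], dtype=np.float32)
# with tf.Session() as sess:
#     print(sess.run(loss_fn(y_true, y_pred)))  # prints a small scalar loss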
def weighted_categorical_crossentropy(weights=None):
    """ weighted_categorical_crossentropy

    Args:
        * weights<ktensor|nparray|list>: crossentropy weights
    Returns:
        * weighted categorical crossentropy function
    """

    def loss(y_true, y_pred):
        labels_floats = tf.cast(y_true, tf.float32)
        per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred)

        if weights is not None:
            weight_mask = tf.maximum(tf.reduce_max(tf.constant(
                np.array(weights, dtype=np.float32)[None, None, None])
                * labels_floats, axis=-1), 1.0)
            per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
        return tf.reduce_mean(per_pixel_loss)
    return loss
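# Example (illustrative, mirrors the call in train.py; the weight values are
# placeholders, and `model` is assumed to be an already-built Keras model):
# weights = [1.0, 3.5]  # e.g. up-weight a rare foreground class
# model.compile(loss=weighted_categorical_crossentropy(weights),
#               optimizer=Adam(lr=1e-4), metrics=['accuracy'])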
def image_categorical_cross_entropy(y_true, y_pred, weights=None):
    """
    :param y_true: tensor of shape (batch_size, height, width) representing the ground truth.
    :param y_pred: tensor of shape (batch_size, height, width) representing the prediction.
    :return: The mean cross-entropy on softmaxed tensors.
    """

    labels_floats = tf.cast(y_true, tf.float32)
    per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred)

    if weights is not None:
        weight_mask = tf.maximum(
            tf.reduce_max(tf.constant(
                np.array(weights, dtype=np.float32)[None, None, None])
                * labels_floats, axis=-1), 1.0)
        per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]

    return tf.reduce_mean(per_pixel_loss)
def class_tversky(y_true, y_pred):
    smooth = 1.0  # 1.00

    # put the class axis first so each class is flattened separately
    y_true = K.permute_dimensions(y_true, (3, 1, 2, 0))
    y_pred = K.permute_dimensions(y_pred, (3, 1, 2, 0))

    y_true_pos = K.batch_flatten(y_true)
    y_pred_pos = K.batch_flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos, 1)
    false_neg = K.sum(y_true_pos * (1 - y_pred_pos), 1)
    false_pos = K.sum((1 - y_true_pos) * y_pred_pos, 1)
    alpha = 0.2  # weight of false negatives (0.5 also tried)
    beta = 0.8   # weight of false positives
    return (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)


def focal_tversky_loss(y_true, y_pred):
    pt_1 = class_tversky(y_true, y_pred)
    gamma = 1.3  # 4./3.0, 0.75 also tried
    return K.sum(K.pow((1 - pt_1), gamma))
def generalized_dice_coeff2(y_true, y_pred):
    n_el = 1
    for dim in y_true.shape:
        n_el *= int(dim)
    n_cl = y_true.shape[-1]
    w = K.zeros(shape=(n_cl,))
    w = (K.sum(y_true, axis=(0, 1, 2))) / n_el
    w = 1 / (w ** 2 + 0.000001)
    numerator = y_true * y_pred
    numerator = w * K.sum(numerator, (0, 1, 2))
    numerator = K.sum(numerator)
    denominator = y_true + y_pred
    denominator = w * K.sum(denominator, (0, 1, 2))
    denominator = K.sum(denominator)
    return 2 * numerator / denominator


def generalized_dice_coeff(y_true, y_pred):
    axes = tuple(range(1, len(y_pred.shape) - 1))
    Ncl = y_pred.shape[-1]
    w = K.zeros(shape=(Ncl,))
    w = K.sum(y_true, axis=axes)
    w = 1 / (w ** 2 + 0.000001)
    # Compute gen dice coef:
    numerator = y_true * y_pred
    numerator = w * K.sum(numerator, axes)
    numerator = K.sum(numerator)

    denominator = y_true + y_pred
    denominator = w * K.sum(denominator, axes)
    denominator = K.sum(denominator)

    gen_dice_coef = 2 * numerator / denominator

    return gen_dice_coef


def generalized_dice_loss(y_true, y_pred):
    return 1 - generalized_dice_coeff2(y_true, y_pred)
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    '''
    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
    Assumes the `channels_last` format.

    # Arguments
        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax)
        epsilon: Used for numerical stability to avoid divide by zero errors

    # References
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
        https://arxiv.org/abs/1606.04797
        More details on Dice loss formulation
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)

        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    '''

    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape) - 1))

    numerator = 2. * K.sum(y_pred * y_true, axes)

    denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
    return 1.00 - K.mean(numerator / (denominator + epsilon))  # average over classes and batch
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False, verbose=False):
    """
    Compute mean metrics of two segmentation masks, via Keras.

    IoU(A,B) = |A & B| / (|A U B|)
    Dice(A,B) = 2*|A & B| / (|A| + |B|)

    Args:
        y_true: true masks, one-hot encoded.
        y_pred: predicted masks, either softmax outputs, or one-hot encoded.
        metric_name: metric to be computed, either 'iou' or 'dice'.
        metric_type: one of 'standard' (default), 'soft', 'naive'.
            In the standard version, y_pred is one-hot encoded and the mean
            is taken only over classes that are present (in y_true or y_pred).
            The 'soft' version of the metrics are computed without one-hot
            encoding y_pred.
            The 'naive' version returns mean metrics where absent classes contribute
            to the class mean as 1.0 (instead of being dropped from the mean).
        drop_last = True: boolean flag to drop last class (usually reserved
            for background class in semantic segmentation)
        mean_per_class = False: return mean along batch axis for each class.
        verbose = False: print intermediate results such as intersection, union
            (as number of pixels).
    Returns:
        IoU/Dice of y_true and y_pred, as a float, unless mean_per_class == True
        in which case it returns the per-class metric, averaged over the batch.

    Inputs are B*W*H*N tensors, with
        B = batch size,
        W = width,
        H = height,
        N = number of classes
    """

    flag_soft = (metric_type == 'soft')
    flag_naive_mean = (metric_type == 'naive')

    # always assume one or more classes
    num_classes = K.shape(y_true)[-1]

    if not flag_soft:
        # get one-hot encoded masks from y_pred (true masks should already be one-hot)
        y_pred = K.one_hot(K.argmax(y_pred), num_classes)
        y_true = K.one_hot(K.argmax(y_true), num_classes)

    # if already one-hot, could have skipped above command
    # keras uses float32 instead of float64; numpy arrays or keras.to_categorical give float64, which would error below
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')

    # intersection and union shapes are batch_size * n_classes (values = area in pixels)
    axes = (1, 2)  # W,H axes of each image
    intersection = K.sum(K.abs(y_true * y_pred), axis=axes)
    mask_sum = K.sum(K.abs(y_true), axis=axes) + K.sum(K.abs(y_pred), axis=axes)
    union = mask_sum - intersection  # or, np.logical_or(y_pred, y_true) for one-hot

    smooth = .001
    iou = (intersection + smooth) / (union + smooth)
    dice = 2 * (intersection + smooth) / (mask_sum + smooth)

    metric = {'iou': iou, 'dice': dice}[metric_name]

    # define mask to be 0 when no pixels are present in either y_true or y_pred, 1 otherwise
    mask = K.cast(K.not_equal(union, 0), 'float32')

    if drop_last:
        metric = metric[:, :-1]
        mask = mask[:, :-1]

    if verbose:
        print('intersection, union')
        print(K.eval(intersection), K.eval(union))
        print(K.eval(intersection / union))

    # return mean metrics: remaining axes are (batch, classes)
    if flag_naive_mean:
        return K.mean(metric)

    # take mean only over non-absent classes
    class_count = K.sum(mask, axis=0)
    non_zero = tf.greater(class_count, 0)
    non_zero_sum = tf.boolean_mask(K.sum(metric * mask, axis=0), non_zero)
    non_zero_count = tf.boolean_mask(class_count, non_zero)

    if verbose:
        print('Counts of inputs with class present, metrics for non-absent classes')
        print(K.eval(class_count), K.eval(non_zero_sum / non_zero_count))

    return K.mean(non_zero_sum / non_zero_count)
def mean_iou(y_true, y_pred, **kwargs):
    """
    Compute mean Intersection over Union of two segmentation masks, via Keras.

    Calls seg_metrics(y_true, y_pred, metric_name='iou'), see there for allowed kwargs.
    """
    return seg_metrics(y_true, y_pred, metric_name='iou', **kwargs)
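# Example (illustrative): mean_iou can be passed to Keras directly as a metric,
# assuming an existing `model`:
# model.compile(loss='categorical_crossentropy', optimizer='adam',
#               metrics=['accuracy', mean_iou])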
def Mean_IOU(y_true, y_pred):
    nb_classes = K.int_shape(y_pred)[-1]
    iou = []
    true_pixels = K.argmax(y_true, axis=-1)
    pred_pixels = K.argmax(y_pred, axis=-1)
    void_labels = K.equal(K.sum(y_true, axis=-1), 0)
    for i in range(0, nb_classes):  # all classes; void masking is left commented out below
        true_labels = K.equal(true_pixels, i)  # & ~void_labels
        pred_labels = K.equal(pred_pixels, i)  # & ~void_labels
        inter = tf.to_int32(true_labels & pred_labels)
        union = tf.to_int32(true_labels | pred_labels)
        legal_batches = K.sum(tf.to_int32(true_labels), axis=1) > 0
        ious = K.sum(inter, axis=1) / K.sum(union, axis=1)
        # average IoU over the batches where this class is present
        iou.append(K.mean(tf.gather(ious, indices=tf.where(legal_batches))))
    iou = tf.stack(iou)
    legal_labels = ~tf.debugging.is_nan(iou)
    iou = tf.gather(iou, indices=tf.where(legal_labels))
    return K.mean(iou)
def iou_vahid(y_true, y_pred):
    nb_classes = tf.shape(y_true)[-1] + tf.to_int32(1)
    true_pixels = K.argmax(y_true, axis=-1)
    pred_pixels = K.argmax(y_pred, axis=-1)
    iou = []

    for i in tf.range(nb_classes):
        tp = K.sum(tf.to_int32(K.equal(true_pixels, i) & K.equal(pred_pixels, i)))
        fp = K.sum(tf.to_int32(K.not_equal(true_pixels, i) & K.equal(pred_pixels, i)))
        fn = K.sum(tf.to_int32(K.equal(true_pixels, i) & K.not_equal(pred_pixels, i)))
        iouh = tp / (tp + fp + fn)
        iou.append(iouh)
    return K.mean(iou)
def IoU_metric(Yi, y_predi):
    ## mean Intersection over Union
    ## Mean IoU = TP/(FN + TP + FP)
    y_predi = np.argmax(y_predi, axis=3)
    y_testi = np.argmax(Yi, axis=3)
    IoUs = []
    Nclass = int(np.max(y_testi)) + 1
    for c in range(Nclass):
        # compare class indices, not the one-hot tensor
        TP = np.sum((y_testi == c) & (y_predi == c))
        FP = np.sum((y_testi != c) & (y_predi == c))
        FN = np.sum((y_testi == c) & (y_predi != c))
        IoU = TP / float(TP + FP + FN)
        IoUs.append(IoU)
    return K.cast(np.mean(IoUs), dtype='float32')
def IoU_metric_keras(y_true, y_pred):
    ## mean Intersection over Union
    ## Mean IoU = TP/(FN + TP + FP)
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))
def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """
    Jaccard = (|X & Y|) / (|X| + |Y| - |X & Y|)
            = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))

    The Jaccard distance loss is useful for unbalanced datasets. This has been
    shifted so it converges on 0 and is smoothed to avoid exploding or vanishing
    gradients.

    Ref: https://en.wikipedia.org/wiki/Jaccard_index

    @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
    @author: wassname
    """
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return (1 - jac) * smooth
models.py
@@ -0,0 +1,317 @@
from keras.models import *
from keras.layers import *
from keras import layers
from keras.regularizers import l2

resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1


def one_side_pad(x):
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    if IMAGE_ORDERING == 'channels_first':
        x = Lambda(lambda x: x[:, :, :-1, :-1])(x)
    elif IMAGE_ORDERING == 'channels_last':
        x = Lambda(lambda x: x[:, :-1, :-1, :])(x)
    return x
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """The identity block is the block that has no conv layer at shortcut.
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at main path
        filters: list of integers, the filters of 3 conv layers at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
    # Returns
        Output tensor for the block.
    """
    filters1, filters2, filters3 = filters

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING,
               padding='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    """conv_block is the block that has a conv layer at shortcut
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at main path
        filters: list of integers, the filters of 3 conv layers at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
    # Returns
        Output tensor for the block.
    Note that from stage 3, the first conv layer at main path is with strides=(2,2),
    and the shortcut should have strides=(2,2) as well.
    """
    filters1, filters2, filters3 = filters

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)

    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x
def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
    assert input_height % 32 == 0
    assert input_width % 32 == 0

    img_input = Input(shape=(input_height, input_width, 3))

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), name='conv1')(x)
    f1 = x

    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
    f2 = one_side_pad(x)

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    f3 = x

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
    f4 = x

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
    f5 = x

    if pretraining:
        # load_weights returns None, so don't assign its result to `model`
        Model(img_input, x).load_weights(resnet50_Weights_path)

    v512_2048 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
    v512_2048 = BatchNormalization(axis=bn_axis)(v512_2048)
    v512_2048 = Activation('relu')(v512_2048)

    v512_1024 = Conv2D(512, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f4)
    v512_1024 = BatchNormalization(axis=bn_axis)(v512_1024)
    v512_1024 = Activation('relu')(v512_1024)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(v512_2048)
    o = concatenate([o, v512_1024], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f3], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f2], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f1], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, img_input], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('softmax')(o)

    model = Model(img_input, o)
    return model
def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
    assert input_height % 32 == 0
    assert input_width % 32 == 0

    img_input = Input(shape=(input_height, input_width, 3))

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), name='conv1')(x)
    f1 = x

    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
    f2 = one_side_pad(x)

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    f3 = x

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
    f4 = x

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
    f5 = x

    if pretraining:
        Model(img_input, x).load_weights(resnet50_Weights_path)

    v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(f5)
    v1024_2048 = BatchNormalization(axis=bn_axis)(v1024_2048)
    v1024_2048 = Activation('relu')(v1024_2048)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(v1024_2048)
    o = concatenate([o, f4], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f3], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f2], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, f1], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    o = concatenate([o, img_input], axis=MERGE_AXIS)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('relu')(o)

    o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
    o = BatchNormalization(axis=bn_axis)(o)
    o = Activation('softmax')(o)

    model = Model(img_input, o)

    return model
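# Example (illustrative): building the full network with the values from
# config_params.json; pretraining=True would additionally require the weights
# file at resnet50_Weights_path.
# model = resnet50_unet(n_classes=2, input_height=448, input_width=896,
#                       weight_decay=1e-6, pretraining=False)
# model.summary()  # a U-Net decoder on top of a ResNet50 encoder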
train.py
@@ -0,0 +1,192 @@
import os
import sys
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
import keras, warnings
from keras.optimizers import *
from sacred import Experiment
from models import *
from utils import *
from metrics import *


def configuration():
    keras.backend.clear_session()
    tf.reset_default_graph()
    warnings.filterwarnings('ignore')

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)

    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.95
    config.gpu_options.visible_device_list = "0"
    set_session(tf.Session(config=config))
def get_dirs_or_files(input_data):
    if os.path.isdir(input_data):
        image_input, labels_input = os.path.join(input_data, 'images/'), os.path.join(input_data, 'labels/')
        # Check if training dir exists
        assert os.path.isdir(image_input), "{} is not a directory".format(image_input)
        assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input)
        return image_input, labels_input


ex = Experiment()
@ex.config
def config_params():
    n_classes = None  # Number of classes. In the binary case set it to 2; otherwise set it to the number of classes you have.
    n_epochs = 1
    input_height = 224 * 1
    input_width = 224 * 1
    weight_decay = 1e-6  # Weight decay of l2 regularization of model layers.
    n_batch = 1  # Batch size: how many patches the model sees in each iteration.
    learning_rate = 1e-4
    patches = False  # Divide the image into patches in order to use all of its information. For page
                     # extraction this should be set to False, since the model should see the whole image.
    augmentation = False
    flip_aug = False  # Flip image (augmentation).
    elastic_aug = False  # Elastic transformation (augmentation).
    blur_aug = False  # Blur patches of image (augmentation).
    scaling = False  # Scaling of patches (augmentation) will be applied if this is set to True.
    binarization = False  # Otsu thresholding. Used for augmentation in binary cases like textline prediction; should not be applied to multiclass cases.
    dir_train = None  # Directory of training dataset (sub-folders should be named images and labels).
    dir_eval = None  # Directory of validation dataset (sub-folders should be named images and labels).
    dir_output = None  # Directory where the output model will be saved.
    pretraining = False  # Set to True to load pretrained weights of the ResNet50 encoder.
    weighted_loss = False  # Set to True if classes are unbalanced and you want to use a weighted loss function.
    scaling_bluring = False
    rotation = False
    scaling_binarization = False
    blur_k = ['blur', 'guass', 'median']  # Blur kinds used for augmentation.
    scales = [0.9, 1.1]  # Scale patches with these scales. Used for augmentation.
    flip_index = [0, 1]  # Flip image. Used for augmentation.
@ex.automain
def run(n_classes, n_epochs, input_height,
        input_width, weight_decay, weighted_loss,
        n_batch, patches, augmentation, flip_aug, blur_aug, scaling, binarization,
        blur_k, scales, dir_train,
        scaling_bluring, scaling_binarization, rotation,
        flip_index, dir_eval, dir_output, pretraining, learning_rate):

    dir_img, dir_seg = get_dirs_or_files(dir_train)
    dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)

    # First make directories in the output folder for both training and evaluation, in order to flow data from them.
    dir_train_flowing = os.path.join(dir_output, 'train')
    dir_eval_flowing = os.path.join(dir_output, 'eval')

    dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images')
    dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels')

    dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images')
    dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels')

    if os.path.isdir(dir_train_flowing):
        os.system('rm -rf ' + dir_train_flowing)
        os.makedirs(dir_train_flowing)
    else:
        os.makedirs(dir_train_flowing)

    if os.path.isdir(dir_eval_flowing):
        os.system('rm -rf ' + dir_eval_flowing)
        os.makedirs(dir_eval_flowing)
    else:
        os.makedirs(dir_eval_flowing)

    os.mkdir(dir_flow_train_imgs)
    os.mkdir(dir_flow_train_labels)

    os.mkdir(dir_flow_eval_imgs)
    os.mkdir(dir_flow_eval_labels)

    # set the gpu configuration
    configuration()

    # write patches into a sub-folder so they can be flowed from directory
    provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
                    dir_flow_train_labels,
                    input_height, input_width, blur_k, blur_aug,
                    flip_aug, binarization, scaling, scales, flip_index,
                    scaling_bluring, scaling_binarization, rotation,
                    augmentation=augmentation, patches=patches)

    provide_patches(dir_img_val, dir_seg_val, dir_flow_eval_imgs,
                    dir_flow_eval_labels,
                    input_height, input_width, blur_k, blur_aug,
                    flip_aug, binarization, scaling, scales, flip_index,
                    scaling_bluring, scaling_binarization, rotation,
                    augmentation=False, patches=patches)
    if weighted_loss:
        weights = np.zeros(n_classes)
        for obj in os.listdir(dir_seg):
            label_obj = cv2.imread(dir_seg + '/' + obj)
            label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes)
            weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0)

        # inverse pixel-frequency weights, normalized to sum to 1
        # (dividing by the minimum is undone by the final normalization)
        weights = 1.00 / weights

        weights = weights / float(np.sum(weights))
        weights = weights / float(np.min(weights))
        weights = weights / float(np.sum(weights))
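    # Toy check of the weighting scheme (illustrative counts, not from any dataset):
    # pixel counts per class [900, 100] -> inverse [1/900, 1/100]
    # normalized to sum to 1 -> [0.1, 0.9]
    # /min then /sum         -> [0.1, 0.9] (unchanged, since normalization is scale-invariant)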
    # get our model
    model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)

    # to see the model structure, just uncomment model summary.
    # model.summary()

    if not weighted_loss:
        model.compile(loss='categorical_crossentropy',
                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
    if weighted_loss:
        model.compile(loss=weighted_categorical_crossentropy(weights),
                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])

    mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
                                         save_weights_only=True, period=1)

    # generating train and evaluation data
    train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
                         input_height=input_height, input_width=input_width, n_classes=n_classes)
    val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
                       input_height=input_height, input_width=input_width, n_classes=n_classes)

    model.fit_generator(
        train_gen,
        steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch),
        validation_data=val_gen,
        validation_steps=1,
        epochs=n_epochs)

    os.system('rm -rf ' + dir_train_flowing)
    os.system('rm -rf ' + dir_eval_flowing)

    model.save(dir_output + '/' + 'model' + '.h5')
utils.py
@@ -0,0 +1,336 @@
import os
import cv2
import numpy as np
import seaborn as sns
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
import random
from tqdm import tqdm


def bluring(img_in, kind):
    if kind == 'guass':
        img_blur = cv2.GaussianBlur(img_in, (5, 5), 0)
    elif kind == "median":
        img_blur = cv2.medianBlur(img_in, 5)
    elif kind == 'blur':
        img_blur = cv2.blur(img_in, (5, 5))
    return img_blur
def color_images(seg, n_classes):
    ann_u = range(n_classes)
    if len(np.shape(seg)) == 3:
        seg = seg[:, :, 0]

    seg_img = np.zeros((np.shape(seg)[0], np.shape(seg)[1], 3)).astype(float)
    colors = sns.color_palette("hls", n_classes)

    for c in ann_u:
        c = int(c)
        segl = (seg == c)
        seg_img[:, :, 0] += segl * (colors[c][0])
        seg_img[:, :, 1] += segl * (colors[c][1])
        seg_img[:, :, 2] += segl * (colors[c][2])
    return seg_img


def resize_image(seg_in, input_height, input_width):
    return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
def get_one_hot(seg, input_height, input_width, n_classes):
    seg = seg[:, :, 0]
    seg_f = np.zeros((input_height, input_width, n_classes))
    for j in range(n_classes):
        seg_f[:, :, j] = (seg == j).astype(int)
    return seg_f
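# Tiny illustrative check for get_one_hot (not part of the original file):
# seg = np.zeros((2, 2, 3), dtype=np.uint8)
# seg[0, 0] = 1                        # pixel (0,0) belongs to class 1
# one_hot = get_one_hot(seg, 2, 2, 2)
# print(one_hot[:, :, 1])              # [[1. 0.] [0. 0.]]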
def IoU(Yi, y_predi):
    ## mean Intersection over Union
    ## Mean IoU = TP/(FN + TP + FP)

    IoUs = []
    classes_true = np.unique(Yi)
    for c in classes_true:
        TP = np.sum((Yi == c) & (y_predi == c))
        FP = np.sum((Yi != c) & (y_predi == c))
        FN = np.sum((Yi == c) & (y_predi != c))
        IoU = TP / float(TP + FP + FN)
        print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU))
        IoUs.append(IoU)
    mIoU = np.mean(IoUs)
    print("_________________")
    print("Mean IoU: {:4.3f}".format(mIoU))
    return mIoU
def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes):
    c = 0
    n = os.listdir(img_folder)  # List of training images
    random.shuffle(n)
    while True:
        img = np.zeros((batch_size, input_height, input_width, 3)).astype('float')
        mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float')

        for i in range(c, c + batch_size):  # initially from 0 to batch_size, c = 0
            filename = n[i].split('.')[0]
            train_img = cv2.imread(img_folder + '/' + n[i]) / 255.
            train_img = cv2.resize(train_img, (input_width, input_height), interpolation=cv2.INTER_NEAREST)  # Read an image from folder and resize

            img[i - c] = train_img  # add to array - img[0], img[1], and so on
            train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
            train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width, n_classes)

            mask[i - c] = train_mask

        c += batch_size
        if c + batch_size >= len(os.listdir(img_folder)):
            c = 0
            random.shuffle(n)
        yield img, mask
def otsu_copy(img):
    img_r = np.zeros(img.shape)
    img1 = img[:, :, 0]
    img2 = img[:, :, 1]
    img3 = img[:, :, 2]
    _, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    _, threshold2 = cv2.threshold(img2, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    _, threshold3 = cv2.threshold(img3, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # apply each channel's Otsu threshold to its own channel
    img_r[:, :, 0] = threshold1
    img_r[:, :, 1] = threshold2
    img_r[:, :, 2] = threshold3
    return img_r
def rotation_90(img):
    img_rot = np.zeros((img.shape[1], img.shape[0], img.shape[2]))
    img_rot[:, :, 0] = img[:, :, 0].T
    img_rot[:, :, 1] = img[:, :, 1].T
    img_rot[:, :, 2] = img[:, :, 2].T
    return img_rot
def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer):

    img_h = img.shape[0]
    img_w = img.shape[1]

    # number of patches in each direction, rounded up
    nxf = img_w / float(width)
    nyf = img_h / float(height)

    if nxf > int(nxf):
        nxf = int(nxf) + 1
    if nyf > int(nyf):
        nyf = int(nyf) + 1

    nxf = int(nxf)
    nyf = int(nyf)

    for i in range(nxf):
        for j in range(nyf):
            index_x_d = i * width
            index_x_u = (i + 1) * width

            index_y_d = j * height
            index_y_u = (j + 1) * height

            # shift edge patches inward so every patch is exactly height x width
            if index_x_u > img_w:
                index_x_u = img_w
                index_x_d = img_w - width
            if index_y_u > img_h:
                index_y_u = img_h
                index_y_d = img_h - height

            img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
            label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]

            cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
            cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
            indexer += 1
    return indexer
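# Worked example for get_patches (illustrative numbers): a 1000x2000 image with
# 448x896 patches gives nxf = ceil(2000/896) = 3 and nyf = ceil(1000/448) = 3;
# patches that would cross the right/bottom edge are shifted inward, so all
# nine written patches are exactly 448x896 (overlapping at the edges).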
def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler):

    img_h = img.shape[0]
    img_w = img.shape[1]

    height_scale = int(height * scaler)
    width_scale = int(width * scaler)

    nxf = img_w / float(width_scale)
    nyf = img_h / float(height_scale)

    if nxf > int(nxf):
        nxf = int(nxf) + 1
    if nyf > int(nyf):
        nyf = int(nyf) + 1

    nxf = int(nxf)
    nyf = int(nyf)

    for i in range(nxf):
        for j in range(nyf):
            index_x_d = i * width_scale
            index_x_u = (i + 1) * width_scale

            index_y_d = j * height_scale
            index_y_u = (j + 1) * height_scale

            if index_x_u > img_w:
                index_x_u = img_w
                index_x_d = img_w - width_scale
            if index_y_u > img_h:
                index_y_u = img_h
                index_y_d = img_h - height_scale

            img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
            label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]

            # scaled patches are resized back to the model's input size
            img_patch = resize_image(img_patch, height, width)
            label_patch = resize_image(label_patch, height, width)

            cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
            cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
            indexer += 1

    return indexer
def provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
                    dir_flow_train_labels,
                    input_height, input_width, blur_k, blur_aug,
                    flip_aug, binarization, scaling, scales, flip_index,
                    scaling_bluring, scaling_binarization, rotation,
                    augmentation=False, patches=False):

    imgs_cv_train = np.array(os.listdir(dir_img))
    segs_cv_train = np.array(os.listdir(dir_seg))

    indexer = 0
    for im, seg_i in tqdm(zip(imgs_cv_train, segs_cv_train)):
        img_name = im.split('.')[0]

        if not patches:
            cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width))
            cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
            indexer += 1

            if augmentation:
                if rotation:
                    cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                rotation_90(resize_image(cv2.imread(dir_img + '/' + im),
                                                         input_height, input_width)))

                    cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                rotation_90(resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                         input_height, input_width)))
                    indexer += 1

                if flip_aug:
                    for f_i in flip_index:
                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                    resize_image(cv2.flip(cv2.imread(dir_img + '/' + im), f_i), input_height, input_width))

                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                    resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), input_height, input_width))
                        indexer += 1

                if blur_aug:
                    for blur_i in blur_k:
                        cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                    (resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, input_width)))

                        cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                    resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
                        indexer += 1

                if binarization:
                    cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width))

                    cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
                    indexer += 1

        if patches:
            indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                  cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'),
                                  input_height, input_width, indexer=indexer)

            if augmentation:
                if rotation:
                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                          rotation_90(cv2.imread(dir_img + '/' + im)),
                                          rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')),
                                          input_height, input_width, indexer=indexer)
                if flip_aug:
                    for f_i in flip_index:
                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                              cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
                                              cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
                                              input_height, input_width, indexer=indexer)
                if blur_aug:
                    for blur_i in blur_k:
                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                              bluring(cv2.imread(dir_img + '/' + im), blur_i),
                                              cv2.imread(dir_seg + '/' + img_name + '.png'),
                                              input_height, input_width, indexer=indexer)

                if scaling:
                    for sc_ind in scales:
                        indexer = get_patches_num_scale(dir_flow_train_imgs, dir_flow_train_labels,
                                                        cv2.imread(dir_img + '/' + im),
                                                        cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                        input_height, input_width, indexer=indexer, scaler=sc_ind)
                if binarization:
                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                          otsu_copy(cv2.imread(dir_img + '/' + im)),
                                          cv2.imread(dir_seg + '/' + img_name + '.png'),
                                          input_height, input_width, indexer=indexer)

                if scaling_bluring:
                    for sc_ind in scales:
                        for blur_i in blur_k:
                            indexer = get_patches_num_scale(dir_flow_train_imgs, dir_flow_train_labels,
                                                            bluring(cv2.imread(dir_img + '/' + im), blur_i),
                                                            cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                            input_height, input_width, indexer=indexer, scaler=sc_ind)

                if scaling_binarization:
                    for sc_ind in scales:
                        indexer = get_patches_num_scale(dir_flow_train_imgs, dir_flow_train_labels,
                                                        otsu_copy(cv2.imread(dir_img + '/' + im)),
                                                        cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                        input_height, input_width, indexer=indexer, scaler=sc_ind)