diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000..c63cc22 --- /dev/null +++ b/metrics.py @@ -0,0 +1,338 @@ +from keras import backend as K +import tensorflow as tf +import numpy as np + +def focal_loss(gamma=2., alpha=4.): + + gamma = float(gamma) + alpha = float(alpha) + + def focal_loss_fixed(y_true, y_pred): + """Focal loss for multi-classification + FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t) + Notice: y_pred is probability after softmax + gradient is d(Fl)/d(p_t) not d(Fl)/d(x) as described in paper + d(Fl)/d(p_t) * [p_t(1-p_t)] = d(Fl)/d(x) + Focal Loss for Dense Object Detection + https://arxiv.org/abs/1708.02002 + + Arguments: + y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls] + y_pred {tensor} -- model's output, shape of [batch_size, num_cls] + + Keyword Arguments: + gamma {float} -- (default: {2.0}) + alpha {float} -- (default: {4.0}) + + Returns: + [tensor] -- loss. + """ + epsilon = 1.e-9 + y_true = tf.convert_to_tensor(y_true, tf.float32) + y_pred = tf.convert_to_tensor(y_pred, tf.float32) + + model_out = tf.add(y_pred, epsilon) + ce = tf.multiply(y_true, -tf.log(model_out)) + weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma)) + fl = tf.multiply(alpha, tf.multiply(weight, ce)) + reduced_fl = tf.reduce_max(fl, axis=1) + return tf.reduce_mean(reduced_fl) + return focal_loss_fixed + +def weighted_categorical_crossentropy(weights=None): + """ weighted_categorical_crossentropy + + Args: + * weights: crossentropy weights + Returns: + * weighted categorical crossentropy function + """ + + def loss(y_true, y_pred): + labels_floats = tf.cast(y_true, tf.float32) + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred) + + if weights is not None: + weight_mask = tf.maximum(tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) + per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] + return tf.reduce_mean(per_pixel_loss) + return loss +def image_categorical_cross_entropy(y_true, y_pred, weights=None): + """ + :param y_true: tensor of shape (batch_size, height, width) representing the ground truth. + :param y_pred: tensor of shape (batch_size, height, width) representing the prediction. + :return: The mean cross-entropy on softmaxed tensors. + """ + + labels_floats = tf.cast(y_true, tf.float32) + per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred) + + if weights is not None: + weight_mask = tf.maximum( + tf.reduce_max(tf.constant( + np.array(weights, dtype=np.float32)[None, None, None]) + * labels_floats, axis=-1), 1.0) + per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] + + return tf.reduce_mean(per_pixel_loss) +def class_tversky(y_true, y_pred): + smooth = 1.0#1.00 + + y_true = K.permute_dimensions(y_true, (3,1,2,0)) + y_pred = K.permute_dimensions(y_pred, (3,1,2,0)) + + y_true_pos = K.batch_flatten(y_true) + y_pred_pos = K.batch_flatten(y_pred) + true_pos = K.sum(y_true_pos * y_pred_pos, 1) + false_neg = K.sum(y_true_pos * (1-y_pred_pos), 1) + false_pos = K.sum((1-y_true_pos)*y_pred_pos, 1) + alpha = 0.2#0.5 + beta=0.8 + return (true_pos + smooth)/(true_pos + alpha*false_neg + (beta)*false_pos + smooth) + +def focal_tversky_loss(y_true,y_pred): + pt_1 = class_tversky(y_true, y_pred) + gamma =1.3#4./3.0#1.3#4.0/3.00# 0.75 + return K.sum(K.pow((1-pt_1), gamma)) + +def generalized_dice_coeff2(y_true, y_pred): + n_el = 1 + for dim in y_true.shape: + n_el *= int(dim) + n_cl = y_true.shape[-1] + w = K.zeros(shape=(n_cl,)) + w = (K.sum(y_true, axis=(0,1,2)))/(n_el) + w = 1/(w**2+0.000001) + numerator = y_true*y_pred + numerator = w*K.sum(numerator,(0,1,2)) + numerator = K.sum(numerator) + denominator = y_true+y_pred + denominator = w*K.sum(denominator,(0,1,2)) + denominator = K.sum(denominator) + return 2*numerator/denominator +def generalized_dice_coeff(y_true, y_pred): + axes = tuple(range(1, len(y_pred.shape)-1)) + Ncl = y_pred.shape[-1] + w = K.zeros(shape=(Ncl,)) + w = K.sum(y_true, axis=axes) + w = 1/(w**2+0.000001) + # Compute gen dice coef: + numerator = y_true*y_pred + numerator = w*K.sum(numerator,axes) + numerator = K.sum(numerator) + + denominator = y_true+y_pred + denominator = w*K.sum(denominator,axes) + denominator = K.sum(denominator) + + gen_dice_coef = 2*numerator/denominator + + return gen_dice_coef + +def generalized_dice_loss(y_true, y_pred): + return 1 - generalized_dice_coeff2(y_true, y_pred) +def soft_dice_loss(y_true, y_pred, epsilon=1e-6): + ''' + Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions. + Assumes the `channels_last` format. + + # Arguments + y_true: b x X x Y( x Z...) x c One hot encoding of ground truth + y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) + epsilon: Used for numerical stability to avoid divide by zero errors + + # References + V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation + https://arxiv.org/abs/1606.04797 + More details on Dice loss formulation + https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72) + + Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022 + ''' + + # skip the batch and class axis for calculating Dice score + axes = tuple(range(1, len(y_pred.shape)-1)) + + numerator = 2. * K.sum(y_pred * y_true, axes) + + denominator = K.sum(K.square(y_pred) + K.square(y_true), axes) + return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch + +def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last = True, mean_per_class=False, verbose=False): + """ + Compute mean metrics of two segmentation masks, via Keras. + + IoU(A,B) = |A & B| / (| A U B|) + Dice(A,B) = 2*|A & B| / (|A| + |B|) + + Args: + y_true: true masks, one-hot encoded. + y_pred: predicted masks, either softmax outputs, or one-hot encoded. + metric_name: metric to be computed, either 'iou' or 'dice'. + metric_type: one of 'standard' (default), 'soft', 'naive'. + In the standard version, y_pred is one-hot encoded and the mean + is taken only over classes that are present (in y_true or y_pred). + The 'soft' version of the metrics are computed without one-hot + encoding y_pred. + The 'naive' version return mean metrics where absent classes contribute + to the class mean as 1.0 (instead of being dropped from the mean). + drop_last = True: boolean flag to drop last class (usually reserved + for background class in semantic segmentation) + mean_per_class = False: return mean along batch axis for each class. + verbose = False: print intermediate results such as intersection, union + (as number of pixels). + Returns: + IoU/Dice of y_true and y_pred, as a float, unless mean_per_class == True + in which case it returns the per-class metric, averaged over the batch. + + Inputs are B*W*H*N tensors, with + B = batch size, + W = width, + H = height, + N = number of classes + """ + + flag_soft = (metric_type == 'soft') + flag_naive_mean = (metric_type == 'naive') + + # always assume one or more classes + num_classes = K.shape(y_true)[-1] + + if not flag_soft: + # get one-hot encoded masks from y_pred (true masks should already be one-hot) + y_pred = K.one_hot(K.argmax(y_pred), num_classes) + y_true = K.one_hot(K.argmax(y_true), num_classes) + + # if already one-hot, could have skipped above command + # keras uses float32 instead of float64, would give error down (but numpy arrays or keras.to_categorical gives float64) + y_true = K.cast(y_true, 'float32') + y_pred = K.cast(y_pred, 'float32') + + # intersection and union shapes are batch_size * n_classes (values = area in pixels) + axes = (1,2) # W,H axes of each image + intersection = K.sum(K.abs(y_true * y_pred), axis=axes) + mask_sum = K.sum(K.abs(y_true), axis=axes) + K.sum(K.abs(y_pred), axis=axes) + union = mask_sum - intersection # or, np.logical_or(y_pred, y_true) for one-hot + + smooth = .001 + iou = (intersection + smooth) / (union + smooth) + dice = 2 * (intersection + smooth)/(mask_sum + smooth) + + metric = {'iou': iou, 'dice': dice}[metric_name] + + # define mask to be 0 when no pixels are present in either y_true or y_pred, 1 otherwise + mask = K.cast(K.not_equal(union, 0), 'float32') + + if drop_last: + metric = metric[:,:-1] + mask = mask[:,:-1] + + if verbose: + print('intersection, union') + print(K.eval(intersection), K.eval(union)) + print(K.eval(intersection/union)) + + # return mean metrics: remaining axes are (batch, classes) + if flag_naive_mean: + return K.mean(metric) + + # take mean only over non-absent classes + class_count = K.sum(mask, axis=0) + non_zero = tf.greater(class_count, 0) + non_zero_sum = tf.boolean_mask(K.sum(metric * mask, axis=0), non_zero) + non_zero_count = tf.boolean_mask(class_count, non_zero) + + if verbose: + print('Counts of inputs with class present, metrics for non-absent classes') + print(K.eval(class_count), K.eval(non_zero_sum / non_zero_count)) + + return K.mean(non_zero_sum / non_zero_count) + +def mean_iou(y_true, y_pred, **kwargs): + """ + Compute mean Intersection over Union of two segmentation masks, via Keras. + + Calls metrics_k(y_true, y_pred, metric_name='iou'), see there for allowed kwargs. + """ + return seg_metrics(y_true, y_pred, metric_name='iou', **kwargs) +def Mean_IOU(y_true, y_pred): + nb_classes = K.int_shape(y_pred)[-1] + iou = [] + true_pixels = K.argmax(y_true, axis=-1) + pred_pixels = K.argmax(y_pred, axis=-1) + void_labels = K.equal(K.sum(y_true, axis=-1), 0) + for i in range(0, nb_classes): # exclude first label (background) and last label (void) + true_labels = K.equal(true_pixels, i)# & ~void_labels + pred_labels = K.equal(pred_pixels, i)# & ~void_labels + inter = tf.to_int32(true_labels & pred_labels) + union = tf.to_int32(true_labels | pred_labels) + legal_batches = K.sum(tf.to_int32(true_labels), axis=1)>0 + ious = K.sum(inter, axis=1)/K.sum(union, axis=1) + iou.append(K.mean(tf.gather(ious, indices=tf.where(legal_batches)))) # returns average IoU of the same objects + iou = tf.stack(iou) + legal_labels = ~tf.debugging.is_nan(iou) + iou = tf.gather(iou, indices=tf.where(legal_labels)) + return K.mean(iou) + +def iou_vahid(y_true, y_pred): + nb_classes = tf.shape(y_true)[-1]+tf.to_int32(1) + true_pixels = K.argmax(y_true, axis=-1) + pred_pixels = K.argmax(y_pred, axis=-1) + iou = [] + + for i in tf.range(nb_classes): + tp=K.sum( tf.to_int32( K.equal(true_pixels, i) & K.equal(pred_pixels, i) ) ) + fp=K.sum( tf.to_int32( K.not_equal(true_pixels, i) & K.equal(pred_pixels, i) ) ) + fn=K.sum( tf.to_int32( K.equal(true_pixels, i) & K.not_equal(pred_pixels, i) ) ) + iouh=tp/(tp+fp+fn) + iou.append(iouh) + return K.mean(iou) + + +def IoU_metric(Yi,y_predi): + ## mean Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + y_predi = np.argmax(y_predi, axis=3) + y_testi = np.argmax(Yi, axis=3) + IoUs = [] + Nclass = int(np.max(Yi)) + 1 + for c in range(Nclass): + TP = np.sum( (Yi == c)&(y_predi==c) ) + FP = np.sum( (Yi != c)&(y_predi==c) ) + FN = np.sum( (Yi == c)&(y_predi != c)) + IoU = TP/float(TP + FP + FN) + IoUs.append(IoU) + return K.cast( np.mean(IoUs) ,dtype='float32' ) + + +def IoU_metric_keras(y_true, y_pred): + ## mean Intersection over Union + ## Mean IoU = TP/(FN + TP + FP) + init = tf.global_variables_initializer() + sess = tf.Session() + sess.run(init) + + return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess)) + +def jaccard_distance_loss(y_true, y_pred, smooth=100): + """ + Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) + = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|)) + + The jaccard distance loss is usefull for unbalanced datasets. This has been + shifted so it converges on 0 and is smoothed to avoid exploding or disapearing + gradient. + + Ref: https://en.wikipedia.org/wiki/Jaccard_index + + @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96 + @author: wassname + """ + intersection = K.sum(K.abs(y_true * y_pred), axis=-1) + sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) + jac = (intersection + smooth) / (sum_ - intersection + smooth) + return (1 - jac) * smooth + + diff --git a/models.py b/models.py new file mode 100644 index 0000000..7c806b4 --- /dev/null +++ b/models.py @@ -0,0 +1,317 @@ +from keras.models import * +from keras.layers import * +from keras import layers +from keras.regularizers import l2 + +resnet50_Weights_path='./pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' +IMAGE_ORDERING ='channels_last' +MERGE_AXIS=-1 + + +def one_side_pad( x ): + x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) + if IMAGE_ORDERING == 'channels_first': + x = Lambda(lambda x : x[: , : , :-1 , :-1 ] )(x) + elif IMAGE_ORDERING == 'channels_last': + x = Lambda(lambda x : x[: , :-1 , :-1 , : ] )(x) + return x + +def identity_block(input_tensor, kernel_size, filters, stage, block): + """The identity block is the block that has no conv layer at shortcut. + # Arguments + input_tensor: input tensor + kernel_size: defualt 3, the kernel size of middle conv layer at main path + filters: list of integers, the filterss of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + # Returns + Output tensor for the block. + """ + filters1, filters2, filters3 = filters + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + conv_name_base = 'res' + str(stage) + block + '_branch' + bn_name_base = 'bn' + str(stage) + block + '_branch' + + x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2a')(input_tensor) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) + x = Activation('relu')(x) + + x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , + padding='same', name=conv_name_base + '2b')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) + x = Activation('relu')(x) + + x = Conv2D(filters3 , (1, 1), data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) + + x = layers.add([x, input_tensor]) + x = Activation('relu')(x) + return x + + +def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): + """conv_block is the block that has a conv layer at shortcut + # Arguments + input_tensor: input tensor + kernel_size: defualt 3, the kernel size of middle conv layer at main path + filters: list of integers, the filterss of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + # Returns + Output tensor for the block. + Note that from stage 3, the first conv layer at main path is with strides=(2,2) + And the shortcut should have strides=(2,2) as well + """ + filters1, filters2, filters3 = filters + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + conv_name_base = 'res' + str(stage) + block + '_branch' + bn_name_base = 'bn' + str(stage) + block + '_branch' + + x = Conv2D(filters1, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, + name=conv_name_base + '2a')(input_tensor) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) + x = Activation('relu')(x) + + x = Conv2D(filters2, kernel_size , data_format=IMAGE_ORDERING , padding='same', + name=conv_name_base + '2b')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) + x = Activation('relu')(x) + + x = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x) + x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) + + shortcut = Conv2D(filters3, (1, 1) , data_format=IMAGE_ORDERING , strides=strides, + name=conv_name_base + '1')(input_tensor) + shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) + + x = layers.add([x, shortcut]) + x = Activation('relu')(x) + return x + + +def resnet50_unet_light(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + assert input_height%32 == 0 + assert input_width%32 == 0 + + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) + + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x ) + + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + + if pretraining: + model=Model( img_input , x ).load_weights(resnet50_Weights_path) + + + v512_2048 = Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 ) + v512_2048 = ( BatchNormalization(axis=bn_axis))(v512_2048) + v512_2048 = Activation('relu')(v512_2048) + + + + v512_1024=Conv2D( 512 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f4 ) + v512_1024 = ( BatchNormalization(axis=bn_axis))(v512_1024) + v512_1024 = Activation('relu')(v512_1024) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v512_2048) + o = ( concatenate([ o ,v512_1024],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f2],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f1],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,img_input],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + + o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) + o = ( BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + + + model = Model( img_input , o ) + return model + +def resnet50_unet(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False): + assert input_height%32 == 0 + assert input_width%32 == 0 + + + img_input = Input(shape=(input_height,input_width , 3 )) + + if IMAGE_ORDERING == 'channels_last': + bn_axis = 3 + else: + bn_axis = 1 + + x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) + x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x) + f1 = x + + x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) + x = Activation('relu')(x) + x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x) + + + x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) + x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') + x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') + f2 = one_side_pad(x ) + + + x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') + x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') + f3 = x + + x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') + x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') + f4 = x + + x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') + x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') + f5 = x + + if pretraining: + Model( img_input , x ).load_weights(resnet50_Weights_path) + + v1024_2048 = Conv2D( 1024 , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( f5 ) + v1024_2048 = ( BatchNormalization(axis=bn_axis))(v1024_2048) + v1024_2048 = Activation('relu')(v1024_2048) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(v1024_2048) + o = ( concatenate([ o ,f4],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([ o ,f3],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f2],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING))(o) + o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay) ) )(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,f1],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 64 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o) + o = ( concatenate([o,img_input],axis=MERGE_AXIS ) ) + o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o) + o = ( Conv2D( 32 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) ))(o) + o = ( BatchNormalization(axis=bn_axis))(o) + o = Activation('relu')(o) + + + o = Conv2D( n_classes , (1, 1) , padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay) )( o ) + o = ( BatchNormalization(axis=bn_axis))(o) + o = (Activation('softmax'))(o) + + model = Model( img_input , o ) + + + + + return model