first working update of branch

pull/18/head
vahidrezanezhad 8 months ago
parent 02b1436f39
commit d27647a0f1

@ -1,8 +1,9 @@
{ {
"n_classes" : 3, "model_name" : "hybrid_transformer_cnn",
"n_classes" : 2,
"n_epochs" : 2, "n_epochs" : 2,
"input_height" : 448, "input_height" : 448,
"input_width" : 672, "input_width" : 448,
"weight_decay" : 1e-6, "weight_decay" : 1e-6,
"n_batch" : 2, "n_batch" : 2,
"learning_rate": 1e-4, "learning_rate": 1e-4,
@ -18,13 +19,21 @@
"scaling_flip" : false, "scaling_flip" : false,
"rotation": false, "rotation": false,
"rotation_not_90": false, "rotation_not_90": false,
"num_patches_xy": [28, 28],
"transformer_patchsize": 1,
"blur_k" : ["blur","gauss","median"],
"scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
"brightness" : [1.3, 1.5, 1.7, 2],
"degrade_scales" : [0.2, 0.4],
"flip_index" : [0, 1, -1],
"thetha" : [10, -10],
"continue_training": false, "continue_training": false,
"index_start": 0, "index_start" : 0,
"dir_of_start_model": " ", "dir_of_start_model" : " ",
"weighted_loss": false, "weighted_loss": false,
"is_loss_soft_dice": false, "is_loss_soft_dice": false,
"data_is_provided": false, "data_is_provided": false,
"dir_train": "/train", "dir_train": "/train",
"dir_eval": "/eval", "dir_eval": "/eval",
"dir_output": "/output" "dir_output": "/out"
} }

@ -1,12 +1,80 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import * from tensorflow.keras.models import *
from tensorflow.keras.layers import * from tensorflow.keras.layers import *
from tensorflow.keras import layers from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2 from tensorflow.keras.regularizers import l2
# Hyper-parameters of the transformer part of the hybrid CNN/transformer model.
# NOTE(review): mlp_head_units is not referenced in the visible code - confirm it is still needed.
mlp_head_units = [2048, 1024]
# Width of each encoded patch vector; passed to PatchEncoder and used as the attention key_dim.
projection_dim = 64
# Number of transformer encoder blocks stacked in vit_resnet50_unet.
transformer_layers = 8
# Number of heads of every MultiHeadAttention layer.
num_heads = 4
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
IMAGE_ORDERING = 'channels_last' IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1 MERGE_AXIS = -1
# Hidden sizes of the MLP inside each transformer block (passed to mlp()).
# The last entry equals projection_dim, so the MLP output can be added back
# to the residual stream of the block.
transformer_units = [
projection_dim * 2,
projection_dim,
] # Size of the transformer layers
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of Dense(GELU) + Dropout pairs to ``x``.

    One Dense/Dropout pair is created per entry of ``hidden_units``; the
    entry gives the number of units of the Dense layer. With an empty
    ``hidden_units`` the input is returned untouched.
    """
    out = x
    for width in hidden_units:
        dense = layers.Dense(width, activation=tf.nn.gelu)
        drop = layers.Dropout(dropout_rate)
        out = drop(dense(out))
    return out
class Patches(layers.Layer):
    """Keras layer that cuts a 4-D image/feature tensor into square patches.

    Each non-overlapping ``patch_size`` x ``patch_size`` window is flattened
    into one vector, so the output has shape
    ``(batch, number_of_patches, patch_size * patch_size * channels)``.
    """

    def __init__(self, patch_size, **kwargs):
        # Forward the standard layer options (name, trainable, dtype, ...) to
        # the base class. get_config() emits them, so from_config() -- used by
        # load_model(custom_objects=...) -- can only rebuild the layer when
        # __init__ accepts them.
        super(Patches, self).__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the flattened patches of ``images``."""
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # The static last dimension is the flattened size of one patch.
        patch_dims = patches.shape[-1]
        # Collapse the two spatial patch-grid axes into a single sequence axis.
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config
class PatchEncoder(layers.Layer):
    """Keras layer that projects patches linearly and adds a position embedding.

    ``num_patches`` is the length of the patch sequence (it sizes the
    position-embedding table) and ``projection_dim`` is the width of the
    encoded patch vectors.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        # Accept the standard layer options (name, trainable, dtype, ...) so
        # that from_config(get_config()) can rebuild the layer when a saved
        # model is loaded.
        super(PatchEncoder, self).__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so that get_config() can report it.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return ``projection(patch)`` plus the learned position embeddings."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return self.projection(patch) + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer.

        Only plain constructor arguments are stored. The Dense and Embedding
        sub-layers are rebuilt by ``__init__``; putting the layer objects
        themselves into the config (as was done before) yields a config that
        ``__init__`` cannot consume.
        """
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection_dim': self.projection_dim,
        })
        return config
def one_side_pad(x): def one_side_pad(x):
x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
@ -292,3 +360,114 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-
model = Model(img_input, o) model = Model(img_input, o)
return model return model
def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
    """Build the hybrid ResNet50 / transformer / U-Net segmentation model.

    The input image goes through a ResNet50-style encoder; the last encoder
    feature map is split into patches, run through ``transformer_layers``
    transformer blocks, reshaped back to a feature map and decoded by
    upsampling stages that are concatenated with encoder skip connections.

    Args:
        n_classes: number of output channels of the final softmax.
        patch_size: side length of the patches cut from the last encoder
            feature map. NOTE(review): the reshape after the transformer
            restores the encoder grid, which looks valid only for
            patch_size == 1 - confirm.
        num_patches: accepted for interface compatibility; the value is
            overridden below with the size of the encoder output grid.
        input_height, input_width: spatial size of the input image.
            NOTE(review): the decoder upsamples by 2 five times and
            concatenates with the raw input, so the sizes presumably have to
            be multiples of 32 - confirm.
        weight_decay: L2 regularisation factor of the convolution kernels.
        pretraining: when true, load the encoder weights from
            ``resnet50_Weights_path``.

    Returns:
        A ``keras.Model`` mapping the input image to the softmax output.
    """
    inputs = layers.Input(shape=(input_height, input_width, 3))
    IMAGE_ORDERING = 'channels_last'
    bn_axis=3

    # ---- Encoder; f1..f4 are kept as skip connections for the decoder. ----
    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
    f1 = x

    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
    f2 = one_side_pad(x)

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    f3 = x

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
    f4 = x

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    if pretraining:
        # load_weights() fills the encoder layers in place and returns
        # nothing useful, so the temporary model is not kept.
        keras.Model(inputs, x).load_weights(resnet50_Weights_path)

    # ---- Transformer on the last encoder feature map. ----
    # The caller-supplied num_patches is replaced by the encoder grid size.
    num_patches = x.shape[1]*x.shape[2]
    patches = Patches(patch_size)(x)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Back from a patch sequence to a feature map on the encoder grid. The
    # channel count is the patch width, so use projection_dim rather than a
    # literal that has to be kept in sync with it by hand.
    encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], projection_dim])

    v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
    v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
    v1024_2048 = Activation('relu')(v1024_2048)

    # ---- Decoder: upsample, concatenate with the skip connection, convolve. ----
    o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
    o = (concatenate([o, f4],axis=MERGE_AXIS))
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
    o = (BatchNormalization(axis=bn_axis))(o)
    o = Activation('relu')(o)

    o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
    o = (concatenate([o ,f3], axis=MERGE_AXIS))
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
    o = (BatchNormalization(axis=bn_axis))(o)
    o = Activation('relu')(o)

    o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
    o = (concatenate([o, f2], axis=MERGE_AXIS))
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
    o = (BatchNormalization(axis=bn_axis))(o)
    o = Activation('relu')(o)

    o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
    o = (concatenate([o, f1], axis=MERGE_AXIS))
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
    o = (BatchNormalization(axis=bn_axis))(o)
    o = Activation('relu')(o)

    # Last stage is concatenated with the raw input image.
    o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
    o = (concatenate([o, inputs],axis=MERGE_AXIS))
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
    o = (BatchNormalization(axis=bn_axis))(o)
    o = Activation('relu')(o)

    # Per-pixel class scores.
    o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
    o = (BatchNormalization(axis=bn_axis))(o)
    o = (Activation('softmax'))(o)

    model = keras.Model(inputs=inputs, outputs=o)
    return model

@ -10,6 +10,7 @@ from utils import *
from metrics import * from metrics import *
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
from tqdm import tqdm from tqdm import tqdm
import json
def configuration(): def configuration():
@ -42,9 +43,13 @@ def config_params():
learning_rate = 1e-4 # Set the learning rate. learning_rate = 1e-4 # Set the learning rate.
patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false. patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
augmentation = False # To apply any kind of augmentation, this parameter must be set to true. augmentation = False # To apply any kind of augmentation, this parameter must be set to true.
flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in train.py. flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in config_params.json.
blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in train.py. blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in config_params.json.
scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in train.py. padding_white = False # If true, white padding will be applied to the image.
padding_black = False # If true, black padding will be applied to the image.
scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json.
degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json.
brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json.
binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images. binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images.
dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels". dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels".
dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels". dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels".
@ -52,13 +57,18 @@ def config_params():
pretraining = False # Set to true to load pretrained weights of ResNet50 encoder. pretraining = False # Set to true to load pretrained weights of ResNet50 encoder.
scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image. scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image.
scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image. scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image.
scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image. scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
thetha = [10, -10] # Rotate image by these angles for augmentation. thetha = None # Rotate image by these angles for augmentation.
blur_k = ['blur', 'gauss', 'median'] # Blur image for augmentation. blur_k = None # Blur image for augmentation.
scales = [0.5, 2] # Scale patches for augmentation. scales = None # Scale patches for augmentation.
flip_index = [0, 1, -1] # Flip image for augmentation. degrade_scales = None # Degrade image for augmentation.
brightness = None # Brighten image for augmentation.
flip_index = None # Flip image for augmentation.
continue_training = False # Set to true if you would like to continue training an already trained a model. continue_training = False # Set to true if you would like to continue training an already trained a model.
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3. transformer_patchsize = None # Patch size of vision transformer patches.
num_patches_xy = None # Number of patches for vision transformer.
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model. dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false. is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false. weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
@ -66,15 +76,19 @@ def config_params():
@ex.automain @ex.automain
def run(n_classes, n_epochs, input_height, def run(_config, n_classes, n_epochs, input_height,
input_width, weight_decay, weighted_loss, input_width, weight_decay, weighted_loss,
index_start, dir_of_start_model, is_loss_soft_dice, index_start, dir_of_start_model, is_loss_soft_dice,
n_batch, patches, augmentation, flip_aug, n_batch, patches, augmentation, flip_aug,
blur_aug, scaling, binarization, blur_aug, padding_white, padding_black, scaling, degrading,
blur_k, scales, dir_train, data_is_provided, brightening, binarization, blur_k, scales, degrade_scales,
scaling_bluring, scaling_binarization, rotation, brightness, dir_train, data_is_provided, scaling_bluring,
rotation_not_90, thetha, scaling_flip, continue_training, scaling_brightness, scaling_binarization, rotation, rotation_not_90,
flip_index, dir_eval, dir_output, pretraining, learning_rate): thetha, scaling_flip, continue_training, transformer_patchsize,
num_patches_xy, model_name, flip_index, dir_eval, dir_output,
pretraining, learning_rate):
num_patches = num_patches_xy[0]*num_patches_xy[1]
if data_is_provided: if data_is_provided:
dir_train_flowing = os.path.join(dir_output, 'train') dir_train_flowing = os.path.join(dir_output, 'train')
dir_eval_flowing = os.path.join(dir_output, 'eval') dir_eval_flowing = os.path.join(dir_output, 'eval')
@ -122,22 +136,27 @@ def run(n_classes, n_epochs, input_height,
# set the gpu configuration # set the gpu configuration
configuration() configuration()
imgs_list=np.array(os.listdir(dir_img))
segs_list=np.array(os.listdir(dir_seg))
imgs_list_test=np.array(os.listdir(dir_img_val))
segs_list_test=np.array(os.listdir(dir_seg_val))
# writing patches into a sub-folder in order to be flowed from directory. # writing patches into a sub-folder in order to be flowed from directory.
provide_patches(dir_img, dir_seg, dir_flow_train_imgs, provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs,
dir_flow_train_labels, dir_flow_train_labels, input_height, input_width, blur_k,
input_height, input_width, blur_k, blur_aug, blur_aug, padding_white, padding_black, flip_aug, binarization,
flip_aug, binarization, scaling, scales, flip_index, scaling, degrading, brightening, scales, degrade_scales, brightness,
scaling_bluring, scaling_binarization, rotation, flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
rotation_not_90, thetha, scaling_flip, rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation,
augmentation=augmentation, patches=patches) patches=patches)
provide_patches(dir_img_val, dir_seg_val, dir_flow_eval_imgs, provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,
dir_flow_eval_labels, dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width,
input_height, input_width, blur_k, blur_aug, blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization,
flip_aug, binarization, scaling, scales, flip_index, scaling, degrading, brightening, scales, degrade_scales, brightness,
scaling_bluring, scaling_binarization, rotation, flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
rotation_not_90, thetha, scaling_flip, rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches)
augmentation=False, patches=patches)
if weighted_loss: if weighted_loss:
weights = np.zeros(n_classes) weights = np.zeros(n_classes)
@ -166,20 +185,30 @@ def run(n_classes, n_epochs, input_height,
weights = weights / float(np.sum(weights)) weights = weights / float(np.sum(weights))
if continue_training: if continue_training:
if is_loss_soft_dice: if model_name=='resnet50_unet':
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss}) if is_loss_soft_dice:
if weighted_loss: model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
model = load_model(dir_of_start_model, compile=True, if weighted_loss:
custom_objects={'loss': weighted_categorical_crossentropy(weights)}) model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss: if not is_loss_soft_dice and not weighted_loss:
model = load_model(dir_of_start_model, compile=True) model = load_model(dir_of_start_model , compile=True)
elif model_name=='hybrid_transformer_cnn':
if is_loss_soft_dice:
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
if weighted_loss:
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss:
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
else: else:
# get our model.
index_start = 0 index_start = 0
model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining) if model_name=='resnet50_unet':
model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)
elif model_name=='hybrid_transformer_cnn':
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining)
#if you want to see the model structure just uncomment model summary.
#model.summary()
# if you want to see the model structure just uncomment model summary.
# model.summary()
if not is_loss_soft_dice and not weighted_loss: if not is_loss_soft_dice and not weighted_loss:
model.compile(loss='categorical_crossentropy', model.compile(loss='categorical_crossentropy',
@ -187,7 +216,6 @@ def run(n_classes, n_epochs, input_height,
if is_loss_soft_dice: if is_loss_soft_dice:
model.compile(loss=soft_dice_loss, model.compile(loss=soft_dice_loss,
optimizer=Adam(lr=learning_rate), metrics=['accuracy']) optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
if weighted_loss: if weighted_loss:
model.compile(loss=weighted_categorical_crossentropy(weights), model.compile(loss=weighted_categorical_crossentropy(weights),
optimizer=Adam(lr=learning_rate), metrics=['accuracy']) optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
@ -198,6 +226,9 @@ def run(n_classes, n_epochs, input_height,
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch, val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
input_height=input_height, input_width=input_width, n_classes=n_classes) input_height=input_height, input_width=input_width, n_classes=n_classes)
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
##score_best=[]
##score_best.append(0)
for i in tqdm(range(index_start, n_epochs + index_start)): for i in tqdm(range(index_start, n_epochs + index_start)):
model.fit_generator( model.fit_generator(
train_gen, train_gen,
@ -205,9 +236,12 @@ def run(n_classes, n_epochs, input_height,
validation_data=val_gen, validation_data=val_gen,
validation_steps=1, validation_steps=1,
epochs=1) epochs=1)
model.save(dir_output + '/' + 'model_' + str(i)) model.save(dir_output+'/'+'model_'+str(i))
with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp:
json.dump(_config, fp) # encode dict into JSON
# os.system('rm -rf '+dir_train_flowing) #os.system('rm -rf '+dir_train_flowing)
# os.system('rm -rf '+dir_eval_flowing) #os.system('rm -rf '+dir_eval_flowing)
# model.save(dir_output+'/'+'model'+'.h5') #model.save(dir_output+'/'+'model'+'.h5')

@ -9,6 +9,15 @@ from tqdm import tqdm
import imutils import imutils
import math import math
def do_brightening(img_in_dir, factor):
    """Load an image from disk, change its brightness and return it as an array.

    Args:
        img_in_dir: path of the image file to open.
        factor: brightness factor handed to ``ImageEnhance.Brightness.enhance``.

    Returns:
        A 3-channel numpy array with the channel axis reversed relative to
        PIL's RGB order (i.e. BGR, the order ``cv2.imread`` produces).
    """
    # Image.open keeps the underlying file open; the context manager makes
    # sure the handle is released once the enhanced copy has been produced.
    with Image.open(img_in_dir) as im:
        enhancer = ImageEnhance.Brightness(im)
        out_img = enhancer.enhance(factor).convert('RGB')
    # Reverse the channel axis (RGB -> BGR); copy() yields a contiguous array
    # instead of a negatively-strided view.
    return np.array(out_img)[:, :, ::-1].copy()
def bluring(img_in, kind): def bluring(img_in, kind):
if kind == 'gauss': if kind == 'gauss':
@ -138,11 +147,11 @@ def IoU(Yi, y_predi):
FP = np.sum((Yi != c) & (y_predi == c)) FP = np.sum((Yi != c) & (y_predi == c))
FN = np.sum((Yi == c) & (y_predi != c)) FN = np.sum((Yi == c) & (y_predi != c))
IoU = TP / float(TP + FP + FN) IoU = TP / float(TP + FP + FN)
print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU)) #print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU))
IoUs.append(IoU) IoUs.append(IoU)
mIoU = np.mean(IoUs) mIoU = np.mean(IoUs)
print("_________________") #print("_________________")
print("Mean IoU: {:4.3f}".format(mIoU)) #print("Mean IoU: {:4.3f}".format(mIoU))
return mIoU return mIoU
@ -241,9 +250,56 @@ def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer):
return indexer return indexer
def do_padding_white(img, pad=4):
    """Surround a 3-D image array with a white (255) border.

    Args:
        img: array of shape (height, width, channels).
        pad: border width in pixels added on every side (default 4).

    Returns:
        A new float array of shape (height + 2*pad, width + 2*pad, channels)
        holding ``img`` in the centre and 255 in the border.
    """
    # Pad only the two spatial axes; the channel axis is left untouched.
    border = ((pad, pad), (pad, pad), (0, 0))
    return np.pad(img.astype(float), border, mode='constant', constant_values=255)
def do_degrading(img, scale):
    """Resize ``img`` by ``scale`` and then back to its original size.

    The intermediate size is the original height and width multiplied by
    ``scale`` and truncated to integers; both steps use ``resize_image``.
    """
    full_h, full_w = img.shape[0], img.shape[1]
    shrunk = resize_image(img, int(full_h * scale), int(full_w * scale))
    restored = resize_image(shrunk, full_h, full_w)
    return restored
def do_padding_black(img, pad=4):
    """Surround a 3-D image array with a black (0) border.

    Args:
        img: array of shape (height, width, channels).
        pad: border width in pixels added on every side (default 4).

    Returns:
        A new float array of shape (height + 2*pad, width + 2*pad, channels)
        holding ``img`` in the centre and 0 in the border.
    """
    # Pad only the two spatial axes; the channel axis is left untouched.
    border = ((pad, pad), (pad, pad), (0, 0))
    return np.pad(img.astype(float), border, mode='constant', constant_values=0)
def do_padding_label(img, pad=4):
    """Surround a 3-D label array with a zero border.

    Args:
        img: label array of shape (height, width, channels).
        pad: border width in pixels added on every side (default 4).

    Returns:
        A new ``int16`` array of shape (height + 2*pad, width + 2*pad,
        channels) holding ``img`` in the centre and 0 in the border.
    """
    # Pad only the two spatial axes; the channel axis is left untouched.
    border = ((pad, pad), (pad, pad), (0, 0))
    # Go through float first, as the previous implementation did, so the
    # final int16 conversion sees the same values.
    padded = np.pad(img.astype(float), border, mode='constant', constant_values=0)
    return padded.astype(np.int16)
def do_padding(img, label, height, width): def do_padding(img, label, height, width):
height_new = img.shape[0] height_new=img.shape[0]
width_new = img.shape[1] width_new=img.shape[1]
h_start = 0 h_start = 0
w_start = 0 w_start = 0
@ -262,7 +318,7 @@ def do_padding(img, label, height, width):
img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :]) img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :])
label_new[h_start:h_start + label.shape[0], w_start:w_start + label.shape[1], :] = np.copy(label[:, :, :]) label_new[h_start:h_start + label.shape[0], w_start:w_start + label.shape[1], :] = np.copy(label[:, :, :])
return img_new, label_new return img_new,label_new
def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler): def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler):
@ -275,6 +331,7 @@ def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, index
height_scale = int(height * scaler) height_scale = int(height * scaler)
width_scale = int(width * scaler) width_scale = int(width * scaler)
nxf = img_w / float(width_scale) nxf = img_w / float(width_scale)
nyf = img_h / float(height_scale) nyf = img_h / float(height_scale)
@ -301,6 +358,7 @@ def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, index
index_y_u = img_h index_y_u = img_h
index_y_d = img_h - height_scale index_y_d = img_h - height_scale
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
@ -356,9 +414,6 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :] label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
# img_patch=resize_image(img_patch,height,width)
# label_patch=resize_image(label_patch,height,width)
cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch) cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch) cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
indexer += 1 indexer += 1
@ -366,47 +421,38 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i
return indexer return indexer
def provide_patches(dir_img, dir_seg, dir_flow_train_imgs, def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
dir_flow_train_labels, dir_flow_train_labels, input_height, input_width, blur_k, blur_aug,
input_height, input_width, blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization, scaling, degrading,
flip_aug, binarization, scaling, scales, flip_index, brightening, scales, degrade_scales, brightness, flip_index,
scaling_bluring, scaling_binarization, rotation, scaling_bluring, scaling_brightness, scaling_binarization, rotation,
rotation_not_90, thetha, scaling_flip, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=False):
augmentation=False, patches=False):
imgs_cv_train = np.array(os.listdir(dir_img))
segs_cv_train = np.array(os.listdir(dir_seg))
indexer = 0 indexer = 0
for im, seg_i in tqdm(zip(imgs_cv_train, segs_cv_train)): for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)):
img_name = im.split('.')[0] img_name = im.split('.')[0]
if not patches: if not patches:
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width))
resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width)) cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
indexer += 1 indexer += 1
if augmentation: if augmentation:
if flip_aug: if flip_aug:
for f_i in flip_index: for f_i in flip_index:
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
resize_image(cv2.flip(cv2.imread(dir_img + '/' + im), f_i), input_height, resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) )
input_width))
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), input_height, input_width))
input_height, input_width))
indexer += 1 indexer += 1
if blur_aug: if blur_aug:
for blur_i in blur_k: for blur_i in blur_k:
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
(resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, (resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, input_width)))
input_width)))
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
input_width))
indexer += 1 indexer += 1
if binarization: if binarization:
@ -417,27 +463,23 @@ def provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width)) resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
indexer += 1 indexer += 1
if patches:
if patches:
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'), cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'),
input_height, input_width, indexer=indexer) input_height, input_width, indexer=indexer)
if augmentation: if augmentation:
if rotation: if rotation:
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
rotation_90(cv2.imread(dir_img + '/' + im)), rotation_90(cv2.imread(dir_img + '/' + im)),
rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')), rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')),
input_height, input_width, indexer=indexer) input_height, input_width, indexer=indexer)
if rotation_not_90: if rotation_not_90:
for thetha_i in thetha: for thetha_i in thetha:
img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/' + im), img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/'+im),
cv2.imread( cv2.imread(dir_seg + '/'+img_name + '.png'), thetha_i)
dir_seg + '/' + img_name + '.png'),
thetha_i)
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
img_max_rotated, img_max_rotated,
label_max_rotated, label_max_rotated,
@ -454,27 +496,66 @@ def provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
bluring(cv2.imread(dir_img + '/' + im), blur_i), bluring(cv2.imread(dir_img + '/' + im), blur_i),
cv2.imread(dir_seg + '/' + img_name + '.png'), cv2.imread(dir_seg + '/' + img_name + '.png'),
input_height, input_width, indexer=indexer) input_height, input_width, indexer=indexer)
if padding_black:
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
do_padding_black(cv2.imread(dir_img + '/' + im)),
do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')),
input_height, input_width, indexer=indexer)
if padding_white:
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
do_padding_white(cv2.imread(dir_img + '/'+im)),
do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')),
input_height, input_width, indexer=indexer)
if brightening:
for factor in brightness:
try:
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
do_brightening(dir_img + '/' +im, factor),
cv2.imread(dir_seg + '/' + img_name + '.png'),
input_height, input_width, indexer=indexer)
except:
pass
if scaling: if scaling:
for sc_ind in scales: for sc_ind in scales:
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
cv2.imread(dir_img + '/' + im), cv2.imread(dir_img + '/' + im) ,
cv2.imread(dir_seg + '/' + img_name + '.png'), cv2.imread(dir_seg + '/' + img_name + '.png'),
input_height, input_width, indexer=indexer, scaler=sc_ind) input_height, input_width, indexer=indexer, scaler=sc_ind)
if degrading:
for degrade_scale_ind in degrade_scales:
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind),
cv2.imread(dir_seg + '/' + img_name + '.png'),
input_height, input_width, indexer=indexer)
if binarization: if binarization:
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels, indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
otsu_copy(cv2.imread(dir_img + '/' + im)), otsu_copy(cv2.imread(dir_img + '/' + im)),
cv2.imread(dir_seg + '/' + img_name + '.png'), cv2.imread(dir_seg + '/' + img_name + '.png'),
input_height, input_width, indexer=indexer) input_height, input_width, indexer=indexer)
if scaling_brightness:
for sc_ind in scales:
for factor in brightness:
try:
indexer = get_patches_num_scale_new(dir_flow_train_imgs,
dir_flow_train_labels,
do_brightening(dir_img + '/' + im, factor)
,cv2.imread(dir_seg + '/' + img_name + '.png')
,input_height, input_width, indexer=indexer, scaler=sc_ind)
except:
pass
if scaling_bluring: if scaling_bluring:
for sc_ind in scales: for sc_ind in scales:
for blur_i in blur_k: for blur_i in blur_k:
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
bluring(cv2.imread(dir_img + '/' + im), blur_i), bluring(cv2.imread(dir_img + '/' + im), blur_i),
cv2.imread(dir_seg + '/' + img_name + '.png'), cv2.imread(dir_seg + '/' + img_name + '.png'),
input_height, input_width, indexer=indexer, input_height, input_width, indexer=indexer, scaler=sc_ind)
scaler=sc_ind)
if scaling_binarization: if scaling_binarization:
for sc_ind in scales: for sc_ind in scales:
@ -487,8 +568,6 @@ def provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
for sc_ind in scales: for sc_ind in scales:
for f_i in flip_index: for f_i in flip_index:
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels, indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
cv2.flip(cv2.imread(dir_img + '/' + im), f_i), cv2.flip( cv2.imread(dir_img + '/' + im), f_i),
cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
f_i), input_height, input_width, indexer=indexer, scaler=sc_ind)
input_height, input_width, indexer=indexer,
scaler=sc_ind)

Loading…
Cancel
Save