mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-08 19:30:07 +02:00
first working update of branch
This commit is contained in:
parent
02b1436f39
commit
d27647a0f1
4 changed files with 452 additions and 151 deletions
|
@ -1,8 +1,9 @@
|
|||
{
|
||||
"n_classes" : 3,
|
||||
"model_name" : "hybrid_transformer_cnn",
|
||||
"n_classes" : 2,
|
||||
"n_epochs" : 2,
|
||||
"input_height" : 448,
|
||||
"input_width" : 672,
|
||||
"input_width" : 448,
|
||||
"weight_decay" : 1e-6,
|
||||
"n_batch" : 2,
|
||||
"learning_rate": 1e-4,
|
||||
|
@ -18,13 +19,21 @@
|
|||
"scaling_flip" : false,
|
||||
"rotation": false,
|
||||
"rotation_not_90": false,
|
||||
"num_patches_xy": [28, 28],
|
||||
"transformer_patchsize": 1,
|
||||
"blur_k" : ["blur","guass","median"],
|
||||
"scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
|
||||
"brightness" : [1.3, 1.5, 1.7, 2],
|
||||
"degrade_scales" : [0.2, 0.4],
|
||||
"flip_index" : [0, 1, -1],
|
||||
"thetha" : [10, -10],
|
||||
"continue_training": false,
|
||||
"index_start": 0,
|
||||
"dir_of_start_model": " ",
|
||||
"index_start" : 0,
|
||||
"dir_of_start_model" : " ",
|
||||
"weighted_loss": false,
|
||||
"is_loss_soft_dice": false,
|
||||
"data_is_provided": false,
|
||||
"dir_train": "/train",
|
||||
"dir_eval": "/eval",
|
||||
"dir_output": "/output"
|
||||
"dir_output": "/out"
|
||||
}
|
||||
|
|
179
models.py
179
models.py
|
@ -1,13 +1,81 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras.models import *
|
||||
from tensorflow.keras.layers import *
|
||||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.regularizers import l2
|
||||
|
||||
mlp_head_units = [2048, 1024]
|
||||
projection_dim = 64
|
||||
transformer_layers = 8
|
||||
num_heads = 4
|
||||
resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
MERGE_AXIS = -1
|
||||
|
||||
transformer_units = [
|
||||
projection_dim * 2,
|
||||
projection_dim,
|
||||
] # Size of the transformer layers
|
||||
def mlp(x, hidden_units, dropout_rate):
|
||||
for units in hidden_units:
|
||||
x = layers.Dense(units, activation=tf.nn.gelu)(x)
|
||||
x = layers.Dropout(dropout_rate)(x)
|
||||
return x
|
||||
|
||||
|
||||
class Patches(layers.Layer):
|
||||
def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
|
||||
super(Patches, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
def call(self, images):
|
||||
print(tf.shape(images)[1],'images')
|
||||
print(self.patch_size,'self.patch_size')
|
||||
batch_size = tf.shape(images)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=images,
|
||||
sizes=[1, self.patch_size, self.patch_size, 1],
|
||||
strides=[1, self.patch_size, self.patch_size, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
)
|
||||
patch_dims = patches.shape[-1]
|
||||
print(patches.shape,patch_dims,'patch_dims')
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'patch_size': self.patch_size,
|
||||
})
|
||||
return config
|
||||
|
||||
class PatchEncoder(layers.Layer):
|
||||
def __init__(self, num_patches, projection_dim):
|
||||
super(PatchEncoder, self).__init__()
|
||||
self.num_patches = num_patches
|
||||
self.projection = layers.Dense(units=projection_dim)
|
||||
self.position_embedding = layers.Embedding(
|
||||
input_dim=num_patches, output_dim=projection_dim
|
||||
)
|
||||
|
||||
def call(self, patch):
|
||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||
return encoded
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'num_patches': self.num_patches,
|
||||
'projection': self.projection,
|
||||
'position_embedding': self.position_embedding,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
def one_side_pad(x):
|
||||
x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
|
||||
if IMAGE_ORDERING == 'channels_first':
|
||||
|
@ -292,3 +360,114 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-
|
|||
model = Model(img_input, o)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||
inputs = layers.Input(shape=(input_height, input_width, 3))
|
||||
IMAGE_ORDERING = 'channels_last'
|
||||
bn_axis=3
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x)
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = keras.Model(inputs, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
num_patches = x.shape[1]*x.shape[2]
|
||||
patches = Patches(patch_size)(x)
|
||||
# Encode patches.
|
||||
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
||||
|
||||
for _ in range(transformer_layers):
|
||||
# Layer normalization 1.
|
||||
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
|
||||
# Create a multi-head attention layer.
|
||||
attention_output = layers.MultiHeadAttention(
|
||||
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
|
||||
)(x1, x1)
|
||||
# Skip connection 1.
|
||||
x2 = layers.Add()([attention_output, encoded_patches])
|
||||
# Layer normalization 2.
|
||||
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
|
||||
# MLP.
|
||||
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
|
||||
# Skip connection 2.
|
||||
encoded_patches = layers.Add()([x3, x2])
|
||||
|
||||
encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], 64])
|
||||
|
||||
v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
|
||||
v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
|
||||
v1024_2048 = Activation('relu')(v1024_2048)
|
||||
|
||||
o = (UpSampling2D( (2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
|
||||
o = (concatenate([o, f4],axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o ,f3], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f2], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, f1], axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
|
||||
o = (concatenate([o, inputs],axis=MERGE_AXIS))
|
||||
o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
|
||||
o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay)))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = Activation('relu')(o)
|
||||
|
||||
o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
|
||||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
|
||||
model = keras.Model(inputs=inputs, outputs=o)
|
||||
|
||||
return model
|
||||
|
|
132
train.py
132
train.py
|
@ -10,6 +10,7 @@ from utils import *
|
|||
from metrics import *
|
||||
from tensorflow.keras.models import load_model
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
|
||||
|
||||
def configuration():
|
||||
|
@ -42,9 +43,13 @@ def config_params():
|
|||
learning_rate = 1e-4 # Set the learning rate.
|
||||
patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
|
||||
augmentation = False # To apply any kind of augmentation, this parameter must be set to true.
|
||||
flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in train.py.
|
||||
blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in train.py.
|
||||
scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in train.py.
|
||||
flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in config_params.json.
|
||||
blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in config_params.json.
|
||||
padding_white = False # If true, white padding will be applied to the image.
|
||||
padding_black = False # If true, black padding will be applied to the image.
|
||||
scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json.
|
||||
degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json.
|
||||
brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json.
|
||||
binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images.
|
||||
dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels".
|
||||
dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels".
|
||||
|
@ -52,13 +57,18 @@ def config_params():
|
|||
pretraining = False # Set to true to load pretrained weights of ResNet50 encoder.
|
||||
scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image.
|
||||
scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image.
|
||||
scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
|
||||
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
|
||||
thetha = [10, -10] # Rotate image by these angles for augmentation.
|
||||
blur_k = ['blur', 'gauss', 'median'] # Blur image for augmentation.
|
||||
scales = [0.5, 2] # Scale patches for augmentation.
|
||||
flip_index = [0, 1, -1] # Flip image for augmentation.
|
||||
thetha = None # Rotate image by these angles for augmentation.
|
||||
blur_k = None # Blur image for augmentation.
|
||||
scales = None # Scale patches for augmentation.
|
||||
degrade_scales = None # Degrade image for augmentation.
|
||||
brightness = None # Brighten image for augmentation.
|
||||
flip_index = None # Flip image for augmentation.
|
||||
continue_training = False # Set to true if you would like to continue training an already trained a model.
|
||||
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
|
||||
transformer_patchsize = None # Patch size of vision transformer patches.
|
||||
num_patches_xy = None # Number of patches for vision transformer.
|
||||
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
|
||||
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
|
||||
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
||||
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
|
||||
|
@ -66,15 +76,19 @@ def config_params():
|
|||
|
||||
|
||||
@ex.automain
|
||||
def run(n_classes, n_epochs, input_height,
|
||||
def run(_config, n_classes, n_epochs, input_height,
|
||||
input_width, weight_decay, weighted_loss,
|
||||
index_start, dir_of_start_model, is_loss_soft_dice,
|
||||
n_batch, patches, augmentation, flip_aug,
|
||||
blur_aug, scaling, binarization,
|
||||
blur_k, scales, dir_train, data_is_provided,
|
||||
scaling_bluring, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip, continue_training,
|
||||
flip_index, dir_eval, dir_output, pretraining, learning_rate):
|
||||
blur_aug, padding_white, padding_black, scaling, degrading,
|
||||
brightening, binarization, blur_k, scales, degrade_scales,
|
||||
brightness, dir_train, data_is_provided, scaling_bluring,
|
||||
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
|
||||
thetha, scaling_flip, continue_training, transformer_patchsize,
|
||||
num_patches_xy, model_name, flip_index, dir_eval, dir_output,
|
||||
pretraining, learning_rate):
|
||||
|
||||
num_patches = num_patches_xy[0]*num_patches_xy[1]
|
||||
if data_is_provided:
|
||||
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||
|
@ -121,23 +135,28 @@ def run(n_classes, n_epochs, input_height,
|
|||
|
||||
# set the gpu configuration
|
||||
configuration()
|
||||
|
||||
imgs_list=np.array(os.listdir(dir_img))
|
||||
segs_list=np.array(os.listdir(dir_seg))
|
||||
|
||||
imgs_list_test=np.array(os.listdir(dir_img_val))
|
||||
segs_list_test=np.array(os.listdir(dir_seg_val))
|
||||
|
||||
# writing patches into a sub-folder in order to be flowed from directory.
|
||||
provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
|
||||
dir_flow_train_labels,
|
||||
input_height, input_width, blur_k, blur_aug,
|
||||
flip_aug, binarization, scaling, scales, flip_index,
|
||||
scaling_bluring, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip,
|
||||
augmentation=augmentation, patches=patches)
|
||||
|
||||
provide_patches(dir_img_val, dir_seg_val, dir_flow_eval_imgs,
|
||||
dir_flow_eval_labels,
|
||||
input_height, input_width, blur_k, blur_aug,
|
||||
flip_aug, binarization, scaling, scales, flip_index,
|
||||
scaling_bluring, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip,
|
||||
augmentation=False, patches=patches)
|
||||
provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs,
|
||||
dir_flow_train_labels, input_height, input_width, blur_k,
|
||||
blur_aug, padding_white, padding_black, flip_aug, binarization,
|
||||
scaling, degrading, brightening, scales, degrade_scales, brightness,
|
||||
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||
rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation,
|
||||
patches=patches)
|
||||
|
||||
provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,
|
||||
dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width,
|
||||
blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization,
|
||||
scaling, degrading, brightening, scales, degrade_scales, brightness,
|
||||
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||
rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches)
|
||||
|
||||
if weighted_loss:
|
||||
weights = np.zeros(n_classes)
|
||||
|
@ -166,38 +185,50 @@ def run(n_classes, n_epochs, input_height,
|
|||
weights = weights / float(np.sum(weights))
|
||||
|
||||
if continue_training:
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True,
|
||||
custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True)
|
||||
if model_name=='resnet50_unet':
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model , compile=True)
|
||||
elif model_name=='hybrid_transformer_cnn':
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
else:
|
||||
# get our model.
|
||||
index_start = 0
|
||||
model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
|
||||
|
||||
# if you want to see the model structure just uncomment model summary.
|
||||
# model.summary()
|
||||
if model_name=='resnet50_unet':
|
||||
model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)
|
||||
elif model_name=='hybrid_transformer_cnn':
|
||||
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining)
|
||||
|
||||
#if you want to see the model structure just uncomment model summary.
|
||||
#model.summary()
|
||||
|
||||
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
if is_loss_soft_dice:
|
||||
if is_loss_soft_dice:
|
||||
model.compile(loss=soft_dice_loss,
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
|
||||
if weighted_loss:
|
||||
model.compile(loss=weighted_categorical_crossentropy(weights),
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
|
||||
|
||||
# generating train and evaluation data
|
||||
train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
|
||||
input_height=input_height, input_width=input_width, n_classes=n_classes)
|
||||
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
|
||||
input_height=input_height, input_width=input_width, n_classes=n_classes)
|
||||
|
||||
|
||||
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
|
||||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
model.fit_generator(
|
||||
train_gen,
|
||||
|
@ -205,9 +236,12 @@ def run(n_classes, n_epochs, input_height,
|
|||
validation_data=val_gen,
|
||||
validation_steps=1,
|
||||
epochs=1)
|
||||
model.save(dir_output + '/' + 'model_' + str(i))
|
||||
model.save(dir_output+'/'+'model_'+str(i))
|
||||
|
||||
with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
# os.system('rm -rf '+dir_train_flowing)
|
||||
# os.system('rm -rf '+dir_eval_flowing)
|
||||
#os.system('rm -rf '+dir_train_flowing)
|
||||
#os.system('rm -rf '+dir_eval_flowing)
|
||||
|
||||
# model.save(dir_output+'/'+'model'+'.h5')
|
||||
#model.save(dir_output+'/'+'model'+'.h5')
|
||||
|
|
273
utils.py
273
utils.py
|
@ -9,6 +9,15 @@ from tqdm import tqdm
|
|||
import imutils
|
||||
import math
|
||||
|
||||
def do_brightening(img_in_dir, factor):
|
||||
im = Image.open(img_in_dir)
|
||||
enhancer = ImageEnhance.Brightness(im)
|
||||
out_img = enhancer.enhance(factor)
|
||||
out_img = out_img.convert('RGB')
|
||||
opencv_img = np.array(out_img)
|
||||
opencv_img = opencv_img[:,:,::-1].copy()
|
||||
return opencv_img
|
||||
|
||||
|
||||
def bluring(img_in, kind):
|
||||
if kind == 'gauss':
|
||||
|
@ -138,11 +147,11 @@ def IoU(Yi, y_predi):
|
|||
FP = np.sum((Yi != c) & (y_predi == c))
|
||||
FN = np.sum((Yi == c) & (y_predi != c))
|
||||
IoU = TP / float(TP + FP + FN)
|
||||
print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU))
|
||||
#print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU))
|
||||
IoUs.append(IoU)
|
||||
mIoU = np.mean(IoUs)
|
||||
print("_________________")
|
||||
print("Mean IoU: {:4.3f}".format(mIoU))
|
||||
#print("_________________")
|
||||
#print("Mean IoU: {:4.3f}".format(mIoU))
|
||||
return mIoU
|
||||
|
||||
|
||||
|
@ -241,124 +250,170 @@ def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer):
|
|||
return indexer
|
||||
|
||||
|
||||
def do_padding(img, label, height, width):
|
||||
height_new = img.shape[0]
|
||||
width_new = img.shape[1]
|
||||
def do_padding_white(img):
|
||||
img_org_h = img.shape[0]
|
||||
img_org_w = img.shape[1]
|
||||
|
||||
index_start_h = 4
|
||||
index_start_w = 4
|
||||
|
||||
img_padded = np.zeros((img.shape[0] + 2*index_start_h, img.shape[1]+ 2*index_start_w, img.shape[2])) + 255
|
||||
img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
|
||||
|
||||
return img_padded.astype(float)
|
||||
|
||||
|
||||
def do_degrading(img, scale):
|
||||
img_org_h = img.shape[0]
|
||||
img_org_w = img.shape[1]
|
||||
|
||||
img_res = resize_image(img, int(img_org_h * scale), int(img_org_w * scale))
|
||||
|
||||
return resize_image(img_res, img_org_h, img_org_w)
|
||||
|
||||
|
||||
def do_padding_black(img):
|
||||
img_org_h = img.shape[0]
|
||||
img_org_w = img.shape[1]
|
||||
|
||||
index_start_h = 4
|
||||
index_start_w = 4
|
||||
|
||||
img_padded = np.zeros((img.shape[0] + 2*index_start_h, img.shape[1] + 2*index_start_w, img.shape[2]))
|
||||
img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
|
||||
|
||||
return img_padded.astype(float)
|
||||
|
||||
|
||||
def do_padding_label(img):
|
||||
img_org_h = img.shape[0]
|
||||
img_org_w = img.shape[1]
|
||||
|
||||
index_start_h = 4
|
||||
index_start_w = 4
|
||||
|
||||
img_padded = np.zeros((img.shape[0] + 2*index_start_h, img.shape[1] + 2*index_start_w, img.shape[2]))
|
||||
img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
|
||||
|
||||
return img_padded.astype(np.int16)
|
||||
|
||||
def do_padding(img, label, height, width):
|
||||
height_new=img.shape[0]
|
||||
width_new=img.shape[1]
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
|
||||
|
||||
if img.shape[0] < height:
|
||||
h_start = int(abs(height - img.shape[0]) / 2.)
|
||||
height_new = height
|
||||
|
||||
|
||||
if img.shape[1] < width:
|
||||
w_start = int(abs(width - img.shape[1]) / 2.)
|
||||
width_new = width
|
||||
|
||||
|
||||
img_new = np.ones((height_new, width_new, img.shape[2])).astype(float) * 255
|
||||
label_new = np.zeros((height_new, width_new, label.shape[2])).astype(float)
|
||||
|
||||
|
||||
img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :])
|
||||
label_new[h_start:h_start + label.shape[0], w_start:w_start + label.shape[1], :] = np.copy(label[:, :, :])
|
||||
|
||||
return img_new, label_new
|
||||
|
||||
return img_new,label_new
|
||||
|
||||
|
||||
def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler):
|
||||
if img.shape[0] < height or img.shape[1] < width:
|
||||
img, label = do_padding(img, label, height, width)
|
||||
|
||||
|
||||
img_h = img.shape[0]
|
||||
img_w = img.shape[1]
|
||||
|
||||
|
||||
height_scale = int(height * scaler)
|
||||
width_scale = int(width * scaler)
|
||||
|
||||
|
||||
|
||||
nxf = img_w / float(width_scale)
|
||||
nyf = img_h / float(height_scale)
|
||||
|
||||
|
||||
if nxf > int(nxf):
|
||||
nxf = int(nxf) + 1
|
||||
if nyf > int(nyf):
|
||||
nyf = int(nyf) + 1
|
||||
|
||||
|
||||
nxf = int(nxf)
|
||||
nyf = int(nyf)
|
||||
|
||||
|
||||
for i in range(nxf):
|
||||
for j in range(nyf):
|
||||
index_x_d = i * width_scale
|
||||
index_x_u = (i + 1) * width_scale
|
||||
|
||||
|
||||
index_y_d = j * height_scale
|
||||
index_y_u = (j + 1) * height_scale
|
||||
|
||||
|
||||
if index_x_u > img_w:
|
||||
index_x_u = img_w
|
||||
index_x_d = img_w - width_scale
|
||||
if index_y_u > img_h:
|
||||
index_y_u = img_h
|
||||
index_y_d = img_h - height_scale
|
||||
|
||||
|
||||
|
||||
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
|
||||
|
||||
img_patch = resize_image(img_patch, height, width)
|
||||
label_patch = resize_image(label_patch, height, width)
|
||||
|
||||
|
||||
cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
|
||||
cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
|
||||
indexer += 1
|
||||
|
||||
|
||||
return indexer
|
||||
|
||||
|
||||
def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler):
|
||||
img = resize_image(img, int(img.shape[0] * scaler), int(img.shape[1] * scaler))
|
||||
label = resize_image(label, int(label.shape[0] * scaler), int(label.shape[1] * scaler))
|
||||
|
||||
|
||||
if img.shape[0] < height or img.shape[1] < width:
|
||||
img, label = do_padding(img, label, height, width)
|
||||
|
||||
|
||||
img_h = img.shape[0]
|
||||
img_w = img.shape[1]
|
||||
|
||||
|
||||
height_scale = int(height * 1)
|
||||
width_scale = int(width * 1)
|
||||
|
||||
|
||||
nxf = img_w / float(width_scale)
|
||||
nyf = img_h / float(height_scale)
|
||||
|
||||
|
||||
if nxf > int(nxf):
|
||||
nxf = int(nxf) + 1
|
||||
if nyf > int(nyf):
|
||||
nyf = int(nyf) + 1
|
||||
|
||||
|
||||
nxf = int(nxf)
|
||||
nyf = int(nyf)
|
||||
|
||||
|
||||
for i in range(nxf):
|
||||
for j in range(nyf):
|
||||
index_x_d = i * width_scale
|
||||
index_x_u = (i + 1) * width_scale
|
||||
|
||||
|
||||
index_y_d = j * height_scale
|
||||
index_y_u = (j + 1) * height_scale
|
||||
|
||||
|
||||
if index_x_u > img_w:
|
||||
index_x_u = img_w
|
||||
index_x_d = img_w - width_scale
|
||||
if index_y_u > img_h:
|
||||
index_y_u = img_h
|
||||
index_y_d = img_h - height_scale
|
||||
|
||||
|
||||
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
|
||||
# img_patch=resize_image(img_patch,height,width)
|
||||
# label_patch=resize_image(label_patch,height,width)
|
||||
|
||||
|
||||
cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
|
||||
cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
|
||||
indexer += 1
|
||||
|
@ -366,78 +421,65 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i
|
|||
return indexer
|
||||
|
||||
|
||||
def provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
|
||||
dir_flow_train_labels,
|
||||
input_height, input_width, blur_k, blur_aug,
|
||||
flip_aug, binarization, scaling, scales, flip_index,
|
||||
scaling_bluring, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip,
|
||||
augmentation=False, patches=False):
|
||||
imgs_cv_train = np.array(os.listdir(dir_img))
|
||||
segs_cv_train = np.array(os.listdir(dir_seg))
|
||||
|
||||
def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
|
||||
dir_flow_train_labels, input_height, input_width, blur_k, blur_aug,
|
||||
padding_white, padding_black, flip_aug, binarization, scaling, degrading,
|
||||
brightening, scales, degrade_scales, brightness, flip_index,
|
||||
scaling_bluring, scaling_brightness, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip, augmentation=False, patches=False):
|
||||
|
||||
indexer = 0
|
||||
for im, seg_i in tqdm(zip(imgs_cv_train, segs_cv_train)):
|
||||
for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)):
|
||||
img_name = im.split('.')[0]
|
||||
if not patches:
|
||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width))
|
||||
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
|
||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width))
|
||||
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
|
||||
indexer += 1
|
||||
|
||||
|
||||
if augmentation:
|
||||
if flip_aug:
|
||||
for f_i in flip_index:
|
||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.flip(cv2.imread(dir_img + '/' + im), f_i), input_height,
|
||||
input_width))
|
||||
|
||||
resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) )
|
||||
|
||||
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
|
||||
input_height, input_width))
|
||||
resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), input_height, input_width))
|
||||
indexer += 1
|
||||
|
||||
if blur_aug:
|
||||
|
||||
if blur_aug:
|
||||
for blur_i in blur_k:
|
||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||
(resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height,
|
||||
input_width)))
|
||||
|
||||
(resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, input_width)))
|
||||
|
||||
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height,
|
||||
input_width))
|
||||
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
|
||||
indexer += 1
|
||||
|
||||
|
||||
if binarization:
|
||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||
resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width))
|
||||
|
||||
|
||||
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
|
||||
resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
|
||||
indexer += 1
|
||||
|
||||
|
||||
|
||||
if patches:
|
||||
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
|
||||
if augmentation:
|
||||
|
||||
if rotation:
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
rotation_90(cv2.imread(dir_img + '/' + im)),
|
||||
rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
rotation_90(cv2.imread(dir_img + '/' + im)),
|
||||
rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
if rotation_not_90:
|
||||
|
||||
for thetha_i in thetha:
|
||||
img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/' + im),
|
||||
cv2.imread(
|
||||
dir_seg + '/' + img_name + '.png'),
|
||||
thetha_i)
|
||||
img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/'+im),
|
||||
cv2.imread(dir_seg + '/'+img_name + '.png'), thetha_i)
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
img_max_rotated,
|
||||
label_max_rotated,
|
||||
|
@ -448,47 +490,84 @@ def provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
|
|||
cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
|
||||
cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
|
||||
input_height, input_width, indexer=indexer)
|
||||
if blur_aug:
|
||||
if blur_aug:
|
||||
for blur_i in blur_k:
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
bluring(cv2.imread(dir_img + '/' + im), blur_i),
|
||||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
if scaling:
|
||||
input_height, input_width, indexer=indexer)
|
||||
if padding_black:
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
do_padding_black(cv2.imread(dir_img + '/' + im)),
|
||||
do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
if padding_white:
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
do_padding_white(cv2.imread(dir_img + '/'+im)),
|
||||
do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
if brightening:
|
||||
for factor in brightness:
|
||||
try:
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
do_brightening(dir_img + '/' +im, factor),
|
||||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer)
|
||||
except:
|
||||
pass
|
||||
if scaling:
|
||||
for sc_ind in scales:
|
||||
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
cv2.imread(dir_img + '/' + im),
|
||||
cv2.imread(dir_img + '/' + im) ,
|
||||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer, scaler=sc_ind)
|
||||
|
||||
if degrading:
|
||||
for degrade_scale_ind in degrade_scales:
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind),
|
||||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
if binarization:
|
||||
indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
otsu_copy(cv2.imread(dir_img + '/' + im)),
|
||||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer)
|
||||
|
||||
if scaling_bluring:
|
||||
if scaling_brightness:
|
||||
for sc_ind in scales:
|
||||
for factor in brightness:
|
||||
try:
|
||||
indexer = get_patches_num_scale_new(dir_flow_train_imgs,
|
||||
dir_flow_train_labels,
|
||||
do_brightening(dir_img + '/' + im, factor)
|
||||
,cv2.imread(dir_seg + '/' + img_name + '.png')
|
||||
,input_height, input_width, indexer=indexer, scaler=sc_ind)
|
||||
except:
|
||||
pass
|
||||
|
||||
if scaling_bluring:
|
||||
for sc_ind in scales:
|
||||
for blur_i in blur_k:
|
||||
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
bluring(cv2.imread(dir_img + '/' + im), blur_i),
|
||||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer,
|
||||
scaler=sc_ind)
|
||||
input_height, input_width, indexer=indexer, scaler=sc_ind)
|
||||
|
||||
if scaling_binarization:
|
||||
if scaling_binarization:
|
||||
for sc_ind in scales:
|
||||
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
otsu_copy(cv2.imread(dir_img + '/' + im)),
|
||||
cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
input_height, input_width, indexer=indexer, scaler=sc_ind)
|
||||
|
||||
if scaling_flip:
|
||||
|
||||
if scaling_flip:
|
||||
for sc_ind in scales:
|
||||
for f_i in flip_index:
|
||||
indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
|
||||
cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
|
||||
cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'),
|
||||
f_i),
|
||||
input_height, input_width, indexer=indexer,
|
||||
scaler=sc_ind)
|
||||
cv2.flip( cv2.imread(dir_img + '/' + im), f_i),
|
||||
cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
|
||||
input_height, input_width, indexer=indexer, scaler=sc_ind)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue