From d27647a0f1c96b12b475add13142fdb91e7888de Mon Sep 17 00:00:00 2001
From: vahidrezanezhad
Date: Tue, 16 Apr 2024 01:00:48 +0200
Subject: [PATCH] first working update of branch

---
 config_params.json |  19 +++-
 models.py          | 179 +++++++++++++++++++++++++++++
 train.py           | 132 ++++++++++++++--------
 utils.py           | 273 +++++++++++++++++++++++++++++----------------
 4 files changed, 452 insertions(+), 151 deletions(-)

diff --git a/config_params.json b/config_params.json
index 7505a81..bd47a52 100644
--- a/config_params.json
+++ b/config_params.json
@@ -1,8 +1,9 @@
 {
-    "n_classes" : 3,
+    "model_name" : "hybrid_transformer_cnn",
+    "n_classes" : 2,
     "n_epochs" : 2,
     "input_height" : 448,
-    "input_width" : 672,
+    "input_width" : 448,
     "weight_decay" : 1e-6,
     "n_batch" : 2,
     "learning_rate": 1e-4,
@@ -18,13 +19,21 @@
     "scaling_flip" : false,
     "rotation": false,
     "rotation_not_90": false,
+    "num_patches_xy": [28, 28],
+    "transformer_patchsize": 1,
+    "blur_k" : ["blur", "gauss", "median"],
+    "scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
+    "brightness" : [1.3, 1.5, 1.7, 2],
+    "degrade_scales" : [0.2, 0.4],
+    "flip_index" : [0, 1, -1],
+    "thetha" : [10, -10],
     "continue_training": false,
-    "index_start": 0,
-    "dir_of_start_model": " ",
+    "index_start" : 0,
+    "dir_of_start_model" : " ",
     "weighted_loss": false,
     "is_loss_soft_dice": false,
     "data_is_provided": false,
     "dir_train": "/train",
     "dir_eval": "/eval",
-    "dir_output": "/output"
+    "dir_output": "/out"
 }
diff --git a/models.py b/models.py
index f06823e..f7a7ad8 100644
--- a/models.py
+++ b/models.py
@@ -1,13 +1,81 @@
+import tensorflow as tf
+from tensorflow import keras
 from tensorflow.keras.models import *
 from tensorflow.keras.layers import *
 from tensorflow.keras import layers
 from tensorflow.keras.regularizers import l2
 
+mlp_head_units = [2048, 1024]
+projection_dim = 64
+transformer_layers = 8
+num_heads = 4
 resnet50_Weights_path = './pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
 IMAGE_ORDERING = 'channels_last'
 MERGE_AXIS = -1
+transformer_units = [
+    projection_dim * 2,
+    projection_dim,
+]  # Size of the transformer layers
 
 
+def mlp(x, hidden_units, dropout_rate):
+    for units in hidden_units:
+        x = layers.Dense(units, activation=tf.nn.gelu)(x)
+        x = layers.Dropout(dropout_rate)(x)
+    return x
+
+
+class Patches(layers.Layer):
+    def __init__(self, patch_size):
+        super(Patches, self).__init__()
+        self.patch_size = patch_size
+
+    def call(self, images):
+        batch_size = tf.shape(images)[0]
+        patches = tf.image.extract_patches(
+            images=images,
+            sizes=[1, self.patch_size, self.patch_size, 1],
+            strides=[1, self.patch_size, self.patch_size, 1],
+            rates=[1, 1, 1, 1],
+            padding="VALID",
+        )
+        patch_dims = patches.shape[-1]
+        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
+        return patches
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'patch_size': self.patch_size,
+        })
+        return config
+
+
+class PatchEncoder(layers.Layer):
+    def __init__(self, num_patches, projection_dim):
+        super(PatchEncoder, self).__init__()
+        self.num_patches = num_patches
+        self.projection_dim = projection_dim
+        self.projection = layers.Dense(units=projection_dim)
+        self.position_embedding = layers.Embedding(
+            input_dim=num_patches, output_dim=projection_dim
+        )
+
+    def call(self, patch):
+        positions = tf.range(start=0, limit=self.num_patches, delta=1)
+        encoded = self.projection(patch) + self.position_embedding(positions)
+        return encoded
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'num_patches': self.num_patches,
+            'projection_dim': self.projection_dim,
+        })
+        return config
+
+
 def one_side_pad(x):
     x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
     if IMAGE_ORDERING == 'channels_first':
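The two layers above are the standard Keras ViT building blocks (linear patch extraction plus learned positional embeddings). A minimal shape check, assuming TensorFlow 2.x and the classes exactly as defined in this hunk; the tensor sizes are illustrative only:

```python
import tensorflow as tf
from models import Patches, PatchEncoder  # the classes defined above

# Dummy feature map shaped like the ResNet50 stage-5 output for 448x448 input.
feature_map = tf.random.uniform((2, 14, 14, 2048))

patches = Patches(patch_size=1)(feature_map)                          # -> (2, 196, 2048)
encoded = PatchEncoder(num_patches=196, projection_dim=64)(patches)   # -> (2, 196, 64)
print(patches.shape, encoded.shape)
```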
@@ -292,3 +360,114 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-
     model = Model(img_input, o)
 
     return model
+
+
+def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
+    inputs = layers.Input(shape=(input_height, input_width, 3))
+    IMAGE_ORDERING = 'channels_last'
+    bn_axis = 3
+
+    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(inputs)
+    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2), kernel_regularizer=l2(weight_decay), name='conv1')(x)
+    f1 = x
+
+    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
+    x = Activation('relu')(x)
+    x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
+
+    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
+    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
+    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
+    f2 = one_side_pad(x)
+
+    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
+    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
+    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
+    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
+    f3 = x
+
+    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
+    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
+    f4 = x
+
+    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
+    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
+    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
+    f5 = x
+
+    if pretraining:
+        model = keras.Model(inputs, x)
+        model.load_weights(resnet50_Weights_path)
+
+    num_patches = x.shape[1] * x.shape[2]
+    patches = Patches(patch_size)(x)
+    # Encode patches.
+    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+
+    for _ in range(transformer_layers):
+        # Layer normalization 1.
+        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+        # Create a multi-head attention layer.
+        attention_output = layers.MultiHeadAttention(
+            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+        )(x1, x1)
+        # Skip connection 1.
+        x2 = layers.Add()([attention_output, encoded_patches])
+        # Layer normalization 2.
+        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+        # MLP.
+        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
+        # Skip connection 2.
+        encoded_patches = layers.Add()([x3, x2])
+
+    encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2], projection_dim])
+
+    v1024_2048 = Conv2D(1024, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(encoded_patches)
+    v1024_2048 = (BatchNormalization(axis=bn_axis))(v1024_2048)
+    v1024_2048 = Activation('relu')(v1024_2048)
+
+    o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(v1024_2048)
+    o = (concatenate([o, f4], axis=MERGE_AXIS))
+    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
+    o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
+    o = (BatchNormalization(axis=bn_axis))(o)
+    o = Activation('relu')(o)
+
+    o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
+    o = (concatenate([o, f3], axis=MERGE_AXIS))
+    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
+    o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
+    o = (BatchNormalization(axis=bn_axis))(o)
+    o = Activation('relu')(o)
+
+    o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
+    o = (concatenate([o, f2], axis=MERGE_AXIS))
+    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
+    o = (Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
+    o = (BatchNormalization(axis=bn_axis))(o)
+    o = Activation('relu')(o)
+
+    o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
+    o = (concatenate([o, f1], axis=MERGE_AXIS))
+    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
+    o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
+    o = (BatchNormalization(axis=bn_axis))(o)
+    o = Activation('relu')(o)
+
+    o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
+    o = (concatenate([o, inputs], axis=MERGE_AXIS))
+    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
+    o = (Conv2D(32, (3, 3), padding='valid', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay)))(o)
+    o = (BatchNormalization(axis=bn_axis))(o)
+    o = Activation('relu')(o)
+
+    o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
+    o = (BatchNormalization(axis=bn_axis))(o)
+    o = (Activation('softmax'))(o)
+
+    model = keras.Model(inputs=inputs, outputs=o)
+
+    return model
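A quick way to sanity-check the geometry of the new hybrid model, assuming the module above is importable. `pretraining=False` avoids needing the local ResNet50 weights file, and `num_patches` is recomputed inside the function from the encoder output, so the value passed here is effectively a placeholder:

```python
from models import vit_resnet50_unet

model = vit_resnet50_unet(n_classes=2, patch_size=1, num_patches=196,
                          input_height=448, input_width=448)
print(model.output_shape)  # expected: (None, 448, 448, 2)
```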
diff --git a/train.py b/train.py
index 03faf46..6e6a172 100644
--- a/train.py
+++ b/train.py
@@ -10,6 +10,7 @@ from utils import *
 from metrics import *
 from tensorflow.keras.models import load_model
 from tqdm import tqdm
+import json
 
 
 def configuration():
@@ -42,9 +43,13 @@ def config_params():
     learning_rate = 1e-4  # Set the learning rate.
     patches = False  # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
     augmentation = False  # To apply any kind of augmentation, this parameter must be set to true.
-    flip_aug = False  # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in train.py.
-    blur_aug = False  # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in train.py.
-    scaling = False  # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in train.py.
+    flip_aug = False  # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in config_params.json.
+    blur_aug = False  # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in config_params.json.
+    padding_white = False  # If true, white padding will be applied to the image.
+    padding_black = False  # If true, black padding will be applied to the image.
+    scaling = False  # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json.
+    degrading = False  # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json.
+    brightening = False  # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json.
     binarization = False  # If true, Otsu thresholding will be applied to augment the input with binarized images.
     dir_train = None  # Directory of training dataset with subdirectories having the names "images" and "labels".
     dir_eval = None  # Directory of validation dataset with subdirectories having the names "images" and "labels".
@@ -52,13 +57,18 @@ def config_params():
     pretraining = False  # Set to true to load pretrained weights of ResNet50 encoder.
     scaling_bluring = False  # If true, a combination of scaling and blurring will be applied to the image.
     scaling_binarization = False  # If true, a combination of scaling and binarization will be applied to the image.
+    scaling_brightness = False  # If true, a combination of scaling and brightening will be applied to the image.
     scaling_flip = False  # If true, a combination of scaling and flipping will be applied to the image.
-    thetha = [10, -10]  # Rotate image by these angles for augmentation.
-    blur_k = ['blur', 'gauss', 'median']  # Blur image for augmentation.
-    scales = [0.5, 2]  # Scale patches for augmentation.
-    flip_index = [0, 1, -1]  # Flip image for augmentation.
+    thetha = None  # Rotate image by these angles for augmentation.
+    blur_k = None  # Blur image for augmentation.
+    scales = None  # Scale patches for augmentation.
+    degrade_scales = None  # Degrade image for augmentation.
+    brightness = None  # Brighten image for augmentation.
+    flip_index = None  # Flip image for augmentation.
     continue_training = False  # Set to true if you would like to continue training an already trained model.
-    index_start = 0  # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
+    transformer_patchsize = None  # Patch size of vision transformer patches.
+    num_patches_xy = None  # Number of patches for vision transformer.
+    index_start = 0  # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
     dir_of_start_model = ''  # Directory containing pretrained encoder to continue training the model.
     is_loss_soft_dice = False  # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
     weighted_loss = False  # Use weighted categorical cross entropy as loss function. When set to true, "is_loss_soft_dice" must be false.
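Since train.py is a Sacred experiment (`@ex.automain` below), all of the defaults documented above can be overridden from a JSON file on the command line, e.g.:

```
python train.py with config_params.json
```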
@@ -66,15 +76,19 @@ def config_params():
 
 
 @ex.automain
-def run(n_classes, n_epochs, input_height,
+def run(_config, n_classes, n_epochs, input_height,
         input_width, weight_decay, weighted_loss,
         index_start, dir_of_start_model, is_loss_soft_dice,
         n_batch, patches, augmentation, flip_aug,
-        blur_aug, scaling, binarization,
-        blur_k, scales, dir_train, data_is_provided,
-        scaling_bluring, scaling_binarization, rotation,
-        rotation_not_90, thetha, scaling_flip, continue_training,
-        flip_index, dir_eval, dir_output, pretraining, learning_rate):
+        blur_aug, padding_white, padding_black, scaling, degrading,
+        brightening, binarization, blur_k, scales, degrade_scales,
+        brightness, dir_train, data_is_provided, scaling_bluring,
+        scaling_brightness, scaling_binarization, rotation, rotation_not_90,
+        thetha, scaling_flip, continue_training, transformer_patchsize,
+        num_patches_xy, model_name, flip_index, dir_eval, dir_output,
+        pretraining, learning_rate):
+
+    num_patches = num_patches_xy[0] * num_patches_xy[1]
 
     if data_is_provided:
         dir_train_flowing = os.path.join(dir_output, 'train')
         dir_eval_flowing = os.path.join(dir_output, 'eval')
@@ -121,23 +135,28 @@ def run(n_classes, n_epochs, input_height,
 
         # set the gpu configuration
         configuration()
+
+        imgs_list = np.array(os.listdir(dir_img))
+        segs_list = np.array(os.listdir(dir_seg))
+
+        imgs_list_test = np.array(os.listdir(dir_img_val))
+        segs_list_test = np.array(os.listdir(dir_seg_val))
 
         # writing patches into a sub-folder in order to be flowed from directory.
-        provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
-                        dir_flow_train_labels,
-                        input_height, input_width, blur_k, blur_aug,
-                        flip_aug, binarization, scaling, scales, flip_index,
-                        scaling_bluring, scaling_binarization, rotation,
-                        rotation_not_90, thetha, scaling_flip,
-                        augmentation=augmentation, patches=patches)
-
-        provide_patches(dir_img_val, dir_seg_val, dir_flow_eval_imgs,
-                        dir_flow_eval_labels,
-                        input_height, input_width, blur_k, blur_aug,
-                        flip_aug, binarization, scaling, scales, flip_index,
-                        scaling_bluring, scaling_binarization, rotation,
-                        rotation_not_90, thetha, scaling_flip,
-                        augmentation=False, patches=patches)
+        provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs,
+                        dir_flow_train_labels, input_height, input_width, blur_k,
+                        blur_aug, padding_white, padding_black, flip_aug, binarization,
+                        scaling, degrading, brightening, scales, degrade_scales, brightness,
+                        flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
+                        rotation, rotation_not_90, thetha, scaling_flip,
+                        augmentation=augmentation, patches=patches)
+
+        provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,
+                        dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width,
+                        blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization,
+                        scaling, degrading, brightening, scales, degrade_scales, brightness,
+                        flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
+                        rotation, rotation_not_90, thetha, scaling_flip,
+                        augmentation=False, patches=patches)
 
     if weighted_loss:
         weights = np.zeros(n_classes)
@@ -166,38 +185,50 @@ def run(n_classes, n_epochs, input_height,
             weights = weights / float(np.sum(weights))
 
     if continue_training:
-        if is_loss_soft_dice:
-            model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
-        if weighted_loss:
-            model = load_model(dir_of_start_model, compile=True,
-                               custom_objects={'loss': weighted_categorical_crossentropy(weights)})
-        if not is_loss_soft_dice and not weighted_loss:
-            model = load_model(dir_of_start_model, compile=True)
+        if model_name == 'resnet50_unet':
+            if is_loss_soft_dice:
+                model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
+            if weighted_loss:
+                model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
+            if not is_loss_soft_dice and not weighted_loss:
+                model = load_model(dir_of_start_model, compile=True)
+        elif model_name == 'hybrid_transformer_cnn':
+            if is_loss_soft_dice:
+                model = load_model(dir_of_start_model, compile=True, custom_objects={'PatchEncoder': PatchEncoder, 'Patches': Patches, 'soft_dice_loss': soft_dice_loss})
+            if weighted_loss:
+                model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
+            if not is_loss_soft_dice and not weighted_loss:
+                model = load_model(dir_of_start_model, compile=True, custom_objects={'PatchEncoder': PatchEncoder, 'Patches': Patches})
     else:
-        # get our model.
         index_start = 0
-        model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
-
-    # if you want to see the model structure just uncomment model summary.
-    # model.summary()
+        if model_name == 'resnet50_unet':
+            model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
+        elif model_name == 'hybrid_transformer_cnn':
+            model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, weight_decay, pretraining)
+
+    # If you want to see the model structure, just uncomment model.summary().
+    # model.summary()
 
     if not is_loss_soft_dice and not weighted_loss:
         model.compile(loss='categorical_crossentropy',
                       optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
     if is_loss_soft_dice:
         model.compile(loss=soft_dice_loss,
                       optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
     if weighted_loss:
         model.compile(loss=weighted_categorical_crossentropy(weights),
                       optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
 
     # generating train and evaluation data
     train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
                          input_height=input_height, input_width=input_width, n_classes=n_classes)
     val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
                        input_height=input_height, input_width=input_width, n_classes=n_classes)
 
     for i in tqdm(range(index_start, n_epochs + index_start)):
         model.fit_generator(
             train_gen,
@@ -205,9 +236,12 @@ def run(n_classes, n_epochs, input_height,
             validation_data=val_gen,
             validation_steps=1,
             epochs=1)
         model.save(dir_output + '/' + 'model_' + str(i))
+
+        with open(dir_output + '/' + 'model_' + str(i) + '/' + 'config.json', 'w') as fp:
+            json.dump(_config, fp)  # encode dict into JSON
 
     # os.system('rm -rf ' + dir_train_flowing)
     # os.system('rm -rf ' + dir_eval_flowing)
 
     # model.save(dir_output + '/' + 'model' + '.h5')
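Each epoch now leaves a self-describing checkpoint: a SavedModel directory `model_<i>` with the Sacred config serialized next to it as `config.json`. Reloading the hybrid model later needs the custom layers, mirroring the `continue_training` branch above (the checkpoint path here is an assumption):

```python
from tensorflow.keras.models import load_model
from models import PatchEncoder, Patches

model = load_model('/out/model_1', compile=False,
                   custom_objects={'PatchEncoder': PatchEncoder, 'Patches': Patches})
```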
diff --git a/utils.py b/utils.py
index 7c65f18..c2786ec 100644
--- a/utils.py
+++ b/utils.py
@@ -9,6 +9,15 @@ from tqdm import tqdm
 import imutils
 import math
+from PIL import Image, ImageEnhance
+
+
+def do_brightening(img_in_dir, factor):
+    im = Image.open(img_in_dir)
+    enhancer = ImageEnhance.Brightness(im)
+    out_img = enhancer.enhance(factor)
+    out_img = out_img.convert('RGB')
+    opencv_img = np.array(out_img)
+    opencv_img = opencv_img[:, :, ::-1].copy()
+    return opencv_img
 
 
 def bluring(img_in, kind):
     if kind == 'gauss':
@@ -138,11 +147,11 @@ def IoU(Yi, y_predi):
         FP = np.sum((Yi != c) & (y_predi == c))
         FN = np.sum((Yi == c) & (y_predi != c))
         IoU = TP / float(TP + FP + FN)
-        print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU))
+        # print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, IoU))
         IoUs.append(IoU)
     mIoU = np.mean(IoUs)
-    print("_________________")
-    print("Mean IoU: {:4.3f}".format(mIoU))
+    # print("_________________")
+    # print("Mean IoU: {:4.3f}".format(mIoU))
     return mIoU
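A small behavioural sketch of the two pixel-level augmenters, assuming an image on disk (the path is made up). Note the asymmetry: `do_brightening` takes a path because it opens the file with PIL, while `bluring` works on an in-memory array; both end up in the BGR layout used elsewhere in utils.py:

```python
import cv2
from utils import bluring, do_brightening

img = cv2.imread('/train/images/page_0001.png')                   # BGR ndarray
blurred = bluring(img, 'gauss')                                   # Gaussian branch of bluring()
brightened = do_brightening('/train/images/page_0001.png', 1.5)   # factor > 1 brightens
```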
@@ -241,124 +250,170 @@ def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer):
 
     return indexer
 
-def do_padding(img, label, height, width):
-    height_new = img.shape[0]
-    width_new = img.shape[1]
+def do_padding_white(img):
+    img_org_h = img.shape[0]
+    img_org_w = img.shape[1]
+
+    index_start_h = 4
+    index_start_w = 4
+
+    img_padded = np.zeros((img.shape[0] + 2 * index_start_h, img.shape[1] + 2 * index_start_w, img.shape[2])) + 255
+    img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
+
+    return img_padded.astype(float)
+
+
+def do_degrading(img, scale):
+    img_org_h = img.shape[0]
+    img_org_w = img.shape[1]
+
+    img_res = resize_image(img, int(img_org_h * scale), int(img_org_w * scale))
+
+    return resize_image(img_res, img_org_h, img_org_w)
+
+
+def do_padding_black(img):
+    img_org_h = img.shape[0]
+    img_org_w = img.shape[1]
+
+    index_start_h = 4
+    index_start_w = 4
+
+    img_padded = np.zeros((img.shape[0] + 2 * index_start_h, img.shape[1] + 2 * index_start_w, img.shape[2]))
+    img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
+
+    return img_padded.astype(float)
+
+
+def do_padding_label(img):
+    img_org_h = img.shape[0]
+    img_org_w = img.shape[1]
+
+    index_start_h = 4
+    index_start_w = 4
+
+    img_padded = np.zeros((img.shape[0] + 2 * index_start_h, img.shape[1] + 2 * index_start_w, img.shape[2]))
+    img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
+
+    return img_padded.astype(np.int16)
+
+
+def do_padding(img, label, height, width):
+    height_new = img.shape[0]
+    width_new = img.shape[1]
 
     h_start = 0
     w_start = 0
 
     if img.shape[0] < height:
         h_start = int(abs(height - img.shape[0]) / 2.)
         height_new = height
 
     if img.shape[1] < width:
         w_start = int(abs(width - img.shape[1]) / 2.)
         width_new = width
 
     img_new = np.ones((height_new, width_new, img.shape[2])).astype(float) * 255
     label_new = np.zeros((height_new, width_new, label.shape[2])).astype(float)
 
     img_new[h_start:h_start + img.shape[0], w_start:w_start + img.shape[1], :] = np.copy(img[:, :, :])
     label_new[h_start:h_start + label.shape[0], w_start:w_start + label.shape[1], :] = np.copy(label[:, :, :])
 
     return img_new, label_new
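Shape behaviour of the new helpers, as a sketch on a synthetic image. `do_padding_white`/`do_padding_black` add a fixed 4-pixel border on every side, while `do_degrading` round-trips through a lower resolution, so the shape is preserved and only detail is lost:

```python
import numpy as np
from utils import do_padding_white, do_padding_black, do_degrading

img = np.full((100, 200, 3), 128, dtype=np.uint8)
print(do_padding_white(img).shape)   # (108, 208, 3), border filled with 255
print(do_padding_black(img).shape)   # (108, 208, 3), border filled with 0
print(do_degrading(img, 0.2).shape)  # (100, 200, 3), visibly degraded content
```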
 def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, n_patches, scaler):
     if img.shape[0] < height or img.shape[1] < width:
         img, label = do_padding(img, label, height, width)
 
     img_h = img.shape[0]
     img_w = img.shape[1]
 
     height_scale = int(height * scaler)
     width_scale = int(width * scaler)
 
     nxf = img_w / float(width_scale)
     nyf = img_h / float(height_scale)
 
     if nxf > int(nxf):
         nxf = int(nxf) + 1
     if nyf > int(nyf):
         nyf = int(nyf) + 1
 
     nxf = int(nxf)
     nyf = int(nyf)
 
     for i in range(nxf):
         for j in range(nyf):
             index_x_d = i * width_scale
             index_x_u = (i + 1) * width_scale
 
             index_y_d = j * height_scale
             index_y_u = (j + 1) * height_scale
 
             if index_x_u > img_w:
                 index_x_u = img_w
                 index_x_d = img_w - width_scale
             if index_y_u > img_h:
                 index_y_u = img_h
                 index_y_d = img_h - height_scale
 
             img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
             label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
 
             img_patch = resize_image(img_patch, height, width)
             label_patch = resize_image(label_patch, height, width)
 
             cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
             cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
             indexer += 1
 
     return indexer
 
 
 def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler):
     img = resize_image(img, int(img.shape[0] * scaler), int(img.shape[1] * scaler))
     label = resize_image(label, int(label.shape[0] * scaler), int(label.shape[1] * scaler))
 
     if img.shape[0] < height or img.shape[1] < width:
         img, label = do_padding(img, label, height, width)
 
     img_h = img.shape[0]
     img_w = img.shape[1]
 
     height_scale = int(height * 1)
     width_scale = int(width * 1)
 
     nxf = img_w / float(width_scale)
     nyf = img_h / float(height_scale)
 
     if nxf > int(nxf):
         nxf = int(nxf) + 1
     if nyf > int(nyf):
         nyf = int(nyf) + 1
 
     nxf = int(nxf)
     nyf = int(nyf)
 
     for i in range(nxf):
         for j in range(nyf):
             index_x_d = i * width_scale
             index_x_u = (i + 1) * width_scale
 
             index_y_d = j * height_scale
             index_y_u = (j + 1) * height_scale
 
             if index_x_u > img_w:
                 index_x_u = img_w
                 index_x_d = img_w - width_scale
             if index_y_u > img_h:
                 index_y_u = img_h
                 index_y_d = img_h - height_scale
 
             img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
             label_patch = label[index_y_d:index_y_u, index_x_d:index_x_u, :]
 
-            # img_patch=resize_image(img_patch,height,width)
-            # label_patch=resize_image(label_patch,height,width)
-
             cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img_patch)
             cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label_patch)
             indexer += 1
 
     return indexer
@@ -366,78 +421,65 @@ def get_patches_num_scale_new(dir_img_f, dir_seg_f, img, label, height, width, i
     return indexer
 
 
-def provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
-                    dir_flow_train_labels,
-                    input_height, input_width, blur_k, blur_aug,
-                    flip_aug, binarization, scaling, scales, flip_index,
-                    scaling_bluring, scaling_binarization, rotation,
-                    rotation_not_90, thetha, scaling_flip,
-                    augmentation=False, patches=False):
-    imgs_cv_train = np.array(os.listdir(dir_img))
-    segs_cv_train = np.array(os.listdir(dir_seg))
-
+def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
+                    dir_flow_train_labels, input_height, input_width, blur_k, blur_aug,
+                    padding_white, padding_black, flip_aug, binarization, scaling, degrading,
+                    brightening, scales, degrade_scales, brightness, flip_index,
+                    scaling_bluring, scaling_brightness, scaling_binarization, rotation,
+                    rotation_not_90, thetha, scaling_flip, augmentation=False, patches=False):
+
     indexer = 0
-    for im, seg_i in tqdm(zip(imgs_cv_train, segs_cv_train)):
+    for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)):
         img_name = im.split('.')[0]
         if not patches:
             cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                         resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width))
             cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                         resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
             indexer += 1
 
             if augmentation:
                 if flip_aug:
                     for f_i in flip_index:
                         cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                     resize_image(cv2.flip(cv2.imread(dir_img + '/' + im), f_i), input_height, input_width))
                         cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                     resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), input_height, input_width))
                         indexer += 1
 
                 if blur_aug:
                     for blur_i in blur_k:
                         cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                     resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, input_width))
                         cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                     resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
                         indexer += 1
 
                 if binarization:
                     cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
                                 resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width))
                     cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
                                 resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
                     indexer += 1
 
         if patches:
             indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                   cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'),
                                   input_height, input_width, indexer=indexer)
 
             if augmentation:
                 if rotation:
                     indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                           rotation_90(cv2.imread(dir_img + '/' + im)),
                                           rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')),
                                           input_height, input_width, indexer=indexer)
 
                 if rotation_not_90:
                     for thetha_i in thetha:
                         img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/' + im),
                                                                                   cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                                                   thetha_i)
                         indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                               img_max_rotated,
                                               label_max_rotated,
                                               input_height, input_width, indexer=indexer)
 
                 if flip_aug:
                     for f_i in flip_index:
                         indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                               cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
                                               cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
                                               input_height, input_width, indexer=indexer)
 
                 if blur_aug:
                     for blur_i in blur_k:
                         indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                               bluring(cv2.imread(dir_img + '/' + im), blur_i),
                                               cv2.imread(dir_seg + '/' + img_name + '.png'),
                                               input_height, input_width, indexer=indexer)
 
+                if padding_black:
+                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                          do_padding_black(cv2.imread(dir_img + '/' + im)),
+                                          do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')),
+                                          input_height, input_width, indexer=indexer)
+
+                if padding_white:
+                    indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                          do_padding_white(cv2.imread(dir_img + '/' + im)),
+                                          do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')),
+                                          input_height, input_width, indexer=indexer)
+
+                if brightening:
+                    for factor in brightness:
+                        try:
+                            indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                                  do_brightening(dir_img + '/' + im, factor),
+                                                  cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                                  input_height, input_width, indexer=indexer)
+                        except Exception:
+                            pass
+
                 if scaling:
                     for sc_ind in scales:
                         indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                             cv2.imread(dir_img + '/' + im),
                                                             cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                             input_height, input_width, indexer=indexer, scaler=sc_ind)
 
+                if degrading:
+                    for degrade_scale_ind in degrade_scales:
+                        indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
+                                              do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind),
+                                              cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                              input_height, input_width, indexer=indexer)
+
                 if binarization:
                     indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                           otsu_copy(cv2.imread(dir_img + '/' + im)),
                                           cv2.imread(dir_seg + '/' + img_name + '.png'),
                                           input_height, input_width, indexer=indexer)
 
+                if scaling_brightness:
+                    for sc_ind in scales:
+                        for factor in brightness:
+                            try:
+                                indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
+                                                                    do_brightening(dir_img + '/' + im, factor),
+                                                                    cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                                                    input_height, input_width, indexer=indexer, scaler=sc_ind)
+                            except Exception:
+                                pass
+
                 if scaling_bluring:
                     for sc_ind in scales:
                         for blur_i in blur_k:
                             indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                                 bluring(cv2.imread(dir_img + '/' + im), blur_i),
                                                                 cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                                 input_height, input_width, indexer=indexer, scaler=sc_ind)
 
                 if scaling_binarization:
                     for sc_ind in scales:
                         indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                             otsu_copy(cv2.imread(dir_img + '/' + im)),
                                                             cv2.imread(dir_seg + '/' + img_name + '.png'),
                                                             input_height, input_width, indexer=indexer, scaler=sc_ind)
 
                 if scaling_flip:
                     for sc_ind in scales:
                         for f_i in flip_index:
                             indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                                 cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
                                                                 cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
                                                                 input_height, input_width, indexer=indexer, scaler=sc_ind)
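For reference, a full call of the extended `provide_patches` signature, with flag values mirroring config_params.json above. The directories are assumptions for illustration; in train.py the file lists and paths come from the flowing-directory setup:

```python
import os
import numpy as np
from utils import provide_patches

dir_img, dir_seg = '/train/images', '/train/labels'
dir_flow_imgs, dir_flow_labels = '/out/train/images', '/out/train/labels'
imgs_list = np.array(os.listdir(dir_img))
segs_list = np.array(os.listdir(dir_seg))

provide_patches(imgs_list, segs_list, dir_img, dir_seg,
                dir_flow_imgs, dir_flow_labels,
                input_height=448, input_width=448,
                blur_k=['blur', 'gauss', 'median'], blur_aug=False,
                padding_white=False, padding_black=False,
                flip_aug=False, binarization=False,
                scaling=True, degrading=False, brightening=False,
                scales=[0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
                degrade_scales=[0.2, 0.4], brightness=[1.3, 1.5, 1.7, 2],
                flip_index=[0, 1, -1],
                scaling_bluring=False, scaling_brightness=False,
                scaling_binarization=False, rotation=False,
                rotation_not_90=False, thetha=[10, -10],
                scaling_flip=False, augmentation=True, patches=True)
```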