mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 11:50:04 +02:00

Commit 38db3e9289 (parent dbb84507ed): adding enhancement training
5 changed files with 119 additions and 68 deletions
config_params.json
@@ -1,15 +1,15 @@
 {
     "model_name" : "resnet50_unet",
-    "task": "classification",
-    "n_classes" : 2,
-    "n_epochs" : 7,
-    "input_height" : 224,
-    "input_width" : 224,
+    "task": "enhancement",
+    "n_classes" : 3,
+    "n_epochs" : 3,
+    "input_height" : 448,
+    "input_width" : 448,
     "weight_decay" : 1e-6,
-    "n_batch" : 6,
+    "n_batch" : 3,
     "learning_rate": 1e-4,
     "f1_threshold_classification": 0.8,
-    "patches" : false,
+    "patches" : true,
     "pretraining" : true,
     "augmentation" : false,
     "flip_aug" : false,
@@ -35,7 +35,7 @@
     "weighted_loss": false,
     "is_loss_soft_dice": false,
     "data_is_provided": false,
-    "dir_train": "/home/vahid/Downloads/image_classification_data/train",
-    "dir_eval": "/home/vahid/Downloads/image_classification_data/eval",
-    "dir_output": "/home/vahid/Downloads/image_classification_data/output"
+    "dir_train": "./training_data_sample_enhancement",
+    "dir_eval": "./eval",
+    "dir_output": "./out"
 }
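The config switches the task from classification to enhancement: for enhancement, "n_classes" is the number of output channels (3 = RGB) rather than a count of label classes, inputs grow to 448x448 patches, and the batch size shrinks accordingly. A minimal sketch of consuming these settings, assuming the file keeps the repository's usual name config_params.json (the filename is not shown in this diff):

import json

# Minimal sketch (assumed filename): load the training config and check
# the values this commit switches over to enhancement training.
with open("config_params.json") as f:
    config = json.load(f)

assert config["task"] == "enhancement"
# For enhancement, n_classes means output channels (3 = RGB),
# not a number of segmentation classes.
assert config["n_classes"] == 3
print(config["input_height"], config["input_width"])  # 448 448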
gt_for_enhancement_creator.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import cv2
+import os
+
+def resize_image(seg_in, input_height, input_width):
+    return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
+
+
+dir_imgs = './training_data_sample_enhancement/images'
+dir_out_imgs = './training_data_sample_enhancement/images_gt'
+dir_out_labs = './training_data_sample_enhancement/labels_gt'
+
+ls_imgs = os.listdir(dir_imgs)
+
+
+ls_scales = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]
+
+
+for img in ls_imgs:
+    img_name = img.split('.')[0]
+    img_type = img.split('.')[1]
+    image = cv2.imread(os.path.join(dir_imgs, img))
+    for i, scale in enumerate(ls_scales):
+        height_sc = int(image.shape[0]*scale)
+        width_sc = int(image.shape[1]*scale)
+
+        image_down_scaled = resize_image(image, height_sc, width_sc)
+        image_back_to_org_scale = resize_image(image_down_scaled, image.shape[0], image.shape[1])
+
+        cv2.imwrite(os.path.join(dir_out_imgs, img_name + '_' + str(i) + '.' + img_type), image_back_to_org_scale)
+        cv2.imwrite(os.path.join(dir_out_labs, img_name + '_' + str(i) + '.' + img_type), image)
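This script synthesizes ground truth for enhancement training: each image is downscaled by one of nine factors and resized back to its original size, so the detail-degraded copy (written to images_gt) becomes the network input and the untouched original (written to labels_gt) becomes the regression target. Note that the script assumes the two output folders already exist, since cv2.imwrite fails silently when the target directory is missing; a small preparation step (not part of the commit) could be:

import os

# Hypothetical setup step, not in the commit: create the output folders
# before running gt_for_enhancement_creator.py.
for d in ('./training_data_sample_enhancement/images_gt',
          './training_data_sample_enhancement/labels_gt'):
    os.makedirs(d, exist_ok=True)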
models.py (27 lines changed)
@@ -168,7 +168,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
     return x


-def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
+def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     assert input_height % 32 == 0
     assert input_width % 32 == 0

@@ -259,14 +259,17 @@ def resnet50_unet_light(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
     o = Activation('relu')(o)

     o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
-    o = (BatchNormalization(axis=bn_axis))(o)
-    o = (Activation('softmax'))(o)
+    if task == "segmentation":
+        o = (BatchNormalization(axis=bn_axis))(o)
+        o = (Activation('softmax'))(o)
+    else:
+        o = (Activation('sigmoid'))(o)

     model = Model(img_input, o)
     return model


-def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
+def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     assert input_height % 32 == 0
     assert input_width % 32 == 0

@@ -354,15 +357,18 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, weight_decay=1e-6, pretraining=False):
     o = Activation('relu')(o)

     o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING, kernel_regularizer=l2(weight_decay))(o)
-    o = (BatchNormalization(axis=bn_axis))(o)
-    o = (Activation('softmax'))(o)
+    if task == "segmentation":
+        o = (BatchNormalization(axis=bn_axis))(o)
+        o = (Activation('softmax'))(o)
+    else:
+        o = (Activation('sigmoid'))(o)

     model = Model(img_input, o)

     return model


-def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
+def vit_resnet50_unet(n_classes, patch_size, num_patches, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     inputs = layers.Input(shape=(input_height, input_width, 3))
     IMAGE_ORDERING = 'channels_last'
     bn_axis=3

@@ -465,8 +471,11 @@ def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
     o = Activation('relu')(o)

     o = Conv2D(n_classes, (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(o)
-    o = (BatchNormalization(axis=bn_axis))(o)
-    o = (Activation('softmax'))(o)
+    if task == "segmentation":
+        o = (BatchNormalization(axis=bn_axis))(o)
+        o = (Activation('softmax'))(o)
+    else:
+        o = (Activation('sigmoid'))(o)

     model = Model(inputs=inputs, outputs=o)
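All three model builders now take a task argument that selects the output head: segmentation keeps batch normalization followed by a softmax over the class channels, while any other task (here, enhancement) ends in a sigmoid so each output channel is a continuous intensity in [0, 1]. A standalone sketch of that head logic, with variable names following the diff:

from tensorflow.keras.layers import Activation, BatchNormalization

def output_head(o, task, bn_axis=3):
    # Sketch of the task-dependent head this commit adds.
    if task == "segmentation":
        # per-pixel class probabilities over n_classes channels
        o = BatchNormalization(axis=bn_axis)(o)
        o = Activation('softmax')(o)
    else:
        # enhancement: independent [0, 1] intensity per output channel
        o = Activation('sigmoid')(o)
    return o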
train.py (45 lines changed)
@@ -1,5 +1,6 @@
 import os
 import sys
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 import tensorflow as tf
 from tensorflow.compat.v1.keras.backend import set_session
 import warnings

@@ -91,7 +92,7 @@ def run(_config, n_classes, n_epochs, input_height,
         num_patches_xy, model_name, flip_index, dir_eval, dir_output,
         pretraining, learning_rate, task, f1_threshold_classification):

-    if task == "segmentation":
+    if task == "segmentation" or task == "enhancement":

         num_patches = num_patches_xy[0]*num_patches_xy[1]
     if data_is_provided:

@@ -153,7 +154,7 @@ def run(_config, n_classes, n_epochs, input_height,
                         blur_aug, padding_white, padding_black, flip_aug, binarization,
                         scaling, degrading, brightening, scales, degrade_scales, brightness,
                         flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
-                        rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation,
+                        rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=augmentation,
                         patches=patches)

         provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,

@@ -161,7 +162,7 @@ def run(_config, n_classes, n_epochs, input_height,
                         blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization,
                         scaling, degrading, brightening, scales, degrade_scales, brightness,
                         flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
-                        rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches)
+                        rotation, rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=patches)

     if weighted_loss:
         weights = np.zeros(n_classes)

@@ -191,45 +192,49 @@ def run(_config, n_classes, n_epochs, input_height,

     if continue_training:
         if model_name=='resnet50_unet':
-            if is_loss_soft_dice:
+            if is_loss_soft_dice and task == "segmentation":
                 model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
-            if weighted_loss:
+            if weighted_loss and task == "segmentation":
                 model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
             if not is_loss_soft_dice and not weighted_loss:
                 model = load_model(dir_of_start_model , compile=True)
         elif model_name=='hybrid_transformer_cnn':
-            if is_loss_soft_dice:
+            if is_loss_soft_dice and task == "segmentation":
                 model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
-            if weighted_loss:
+            if weighted_loss and task == "segmentation":
                 model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
             if not is_loss_soft_dice and not weighted_loss:
                 model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
     else:
         index_start = 0
         if model_name=='resnet50_unet':
-            model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)
+            model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
         elif model_name=='hybrid_transformer_cnn':
-            model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining)
+            model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)

     #if you want to see the model structure just uncomment model summary.
     #model.summary()

-    if not is_loss_soft_dice and not weighted_loss:
-        model.compile(loss='categorical_crossentropy',
-                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
-    if is_loss_soft_dice:
-        model.compile(loss=soft_dice_loss,
-                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
-    if weighted_loss:
-        model.compile(loss=weighted_categorical_crossentropy(weights),
-                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+    if task == "segmentation":
+        if not is_loss_soft_dice and not weighted_loss:
+            model.compile(loss='categorical_crossentropy',
+                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+        if is_loss_soft_dice:
+            model.compile(loss=soft_dice_loss,
+                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+        if weighted_loss:
+            model.compile(loss=weighted_categorical_crossentropy(weights),
+                          optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
+    elif task == "enhancement":
+        model.compile(loss='mean_squared_error',
+                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])

     # generating train and evaluation data
     train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
-                         input_height=input_height, input_width=input_width, n_classes=n_classes)
+                         input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)
     val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
-                       input_height=input_height, input_width=input_width, n_classes=n_classes)
+                       input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)

     ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
     ##score_best=[]
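For enhancement, training becomes a per-pixel regression: the model is compiled with mean squared error against labels in [0, 1], matching the sigmoid head, while segmentation keeps its cross-entropy and dice options. A toy illustration of the loss (not from the commit):

import numpy as np

# Toy example of why mean_squared_error fits the enhancement task: the label
# is the clean image in [0, 1] and the loss penalizes per-pixel deviation.
y_true = np.array([0.2, 0.8, 1.0])   # clean pixel values
y_pred = np.array([0.3, 0.7, 0.9])   # model output after the sigmoid head
mse = np.mean((y_true - y_pred) ** 2)
print(mse)  # ~0.01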
utils.py (62 lines changed)
@@ -268,7 +268,7 @@ def IoU(Yi, y_predi):
     return mIoU


-def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes):
+def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
     c = 0
     n = [f for f in os.listdir(img_folder) if not f.startswith('.')]  # os.listdir(img_folder) #List of training images
     random.shuffle(n)

@@ -277,8 +277,6 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
         mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float')

         for i in range(c, c + batch_size):  # initially from 0 to 16, c = 0.
-            # print(img_folder+'/'+n[i])
-
             try:
                 filename = n[i].split('.')[0]

@@ -287,11 +285,14 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes, task='segmentation'):
                     interpolation=cv2.INTER_NEAREST)  # Read an image from folder and resize

                 img[i - c] = train_img  # add to array - img[0], img[1], and so on.
-                train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
-                # print(mask_folder+'/'+filename+'.png')
-                # print(train_mask.shape)
-                train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width,
-                                         n_classes)
+                if task == "segmentation":
+                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
+                    train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width,
+                                             n_classes)
+                elif task == "enhancement":
+                    train_mask = cv2.imread(mask_folder + '/' + filename + '.png')/255.
+                    train_mask = resize_image(train_mask, input_height, input_width)

                 # train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3]

                 mask[i - c] = train_mask

@@ -539,14 +540,19 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                     padding_white, padding_black, flip_aug, binarization, scaling, degrading,
                     brightening, scales, degrade_scales, brightness, flip_index,
                     scaling_bluring, scaling_brightness, scaling_binarization, rotation,
-                    rotation_not_90, thetha, scaling_flip, augmentation=False, patches=False):
+                    rotation_not_90, thetha, scaling_flip, task, augmentation=False, patches=False):

     indexer = 0
     for im, seg_i in tqdm(zip(imgs_list_train, segs_list_train)):
         img_name = im.split('.')[0]
+        if task == "segmentation":
+            dir_of_label_file = os.path.join(dir_seg, img_name + '.png')
+        elif task == "enhancement":
+            dir_of_label_file = os.path.join(dir_seg, im)

         if not patches:
             cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_img + '/' + im), input_height, input_width))
-            cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
+            cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png', resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
             indexer += 1

         if augmentation:

@@ -556,7 +562,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                         resize_image(cv2.flip(cv2.imread(dir_img+'/'+im),f_i),input_height,input_width) )

                     cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
-                                resize_image(cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i), input_height, input_width))
+                                resize_image(cv2.flip(cv2.imread(dir_of_label_file), f_i), input_height, input_width))
                     indexer += 1

             if blur_aug:

@@ -565,7 +571,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                         (resize_image(bluring(cv2.imread(dir_img + '/' + im), blur_i), input_height, input_width)))

                     cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
-                                resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
+                                resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
                     indexer += 1

             if binarization:

@@ -573,26 +579,26 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                     resize_image(otsu_copy(cv2.imread(dir_img + '/' + im)), input_height, input_width))

                 cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
-                            resize_image(cv2.imread(dir_seg + '/' + img_name + '.png'), input_height, input_width))
+                            resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
                 indexer += 1


         if patches:
             indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
-                                  cv2.imread(dir_img + '/' + im), cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                  cv2.imread(dir_img + '/' + im), cv2.imread(dir_of_label_file),
                                   input_height, input_width, indexer=indexer)

             if augmentation:
                 if rotation:
                     indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                           rotation_90(cv2.imread(dir_img + '/' + im)),
-                                          rotation_90(cv2.imread(dir_seg + '/' + img_name + '.png')),
+                                          rotation_90(cv2.imread(dir_of_label_file)),
                                           input_height, input_width, indexer=indexer)

                 if rotation_not_90:
                     for thetha_i in thetha:
                         img_max_rotated, label_max_rotated = rotation_not_90_func(cv2.imread(dir_img + '/'+im),
-                                                                                  cv2.imread(dir_seg + '/'+img_name + '.png'), thetha_i)
+                                                                                  cv2.imread(dir_of_label_file), thetha_i)
                         indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                               img_max_rotated,
                                               label_max_rotated,

@@ -601,24 +607,24 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                     for f_i in flip_index:
                         indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                               cv2.flip(cv2.imread(dir_img + '/' + im), f_i),
-                                              cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
+                                              cv2.flip(cv2.imread(dir_of_label_file), f_i),
                                               input_height, input_width, indexer=indexer)
                 if blur_aug:
                     for blur_i in blur_k:
                         indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                               bluring(cv2.imread(dir_img + '/' + im), blur_i),
-                                              cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                              cv2.imread(dir_of_label_file),
                                               input_height, input_width, indexer=indexer)
                 if padding_black:
                     indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                           do_padding_black(cv2.imread(dir_img + '/' + im)),
-                                          do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')),
+                                          do_padding_label(cv2.imread(dir_of_label_file)),
                                           input_height, input_width, indexer=indexer)

                 if padding_white:
                     indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                           do_padding_white(cv2.imread(dir_img + '/'+im)),
-                                          do_padding_label(cv2.imread(dir_seg + '/' + img_name + '.png')),
+                                          do_padding_label(cv2.imread(dir_of_label_file)),
                                           input_height, input_width, indexer=indexer)

                 if brightening:

@@ -626,7 +632,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                     try:
                         indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                               do_brightening(dir_img + '/' + im, factor),
-                                              cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                              cv2.imread(dir_of_label_file),
                                               input_height, input_width, indexer=indexer)
                     except:
                         pass

@@ -634,20 +640,20 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                 for sc_ind in scales:
                     indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                         cv2.imread(dir_img + '/' + im) ,
-                                                        cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                                        cv2.imread(dir_of_label_file),
                                                         input_height, input_width, indexer=indexer, scaler=sc_ind)

                 if degrading:
                     for degrade_scale_ind in degrade_scales:
                         indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                               do_degrading(cv2.imread(dir_img + '/' + im), degrade_scale_ind),
-                                              cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                              cv2.imread(dir_of_label_file),
                                               input_height, input_width, indexer=indexer)

                 if binarization:
                     indexer = get_patches(dir_flow_train_imgs, dir_flow_train_labels,
                                           otsu_copy(cv2.imread(dir_img + '/' + im)),
-                                          cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                          cv2.imread(dir_of_label_file),
                                           input_height, input_width, indexer=indexer)

                 if scaling_brightness:

@@ -657,7 +663,7 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                             indexer = get_patches_num_scale_new(dir_flow_train_imgs,
                                                                 dir_flow_train_labels,
                                                                 do_brightening(dir_img + '/' + im, factor)
-                                                                ,cv2.imread(dir_seg + '/' + img_name + '.png')
+                                                                ,cv2.imread(dir_of_label_file)
                                                                 ,input_height, input_width, indexer=indexer, scaler=sc_ind)
                         except:
                             pass

@@ -667,14 +673,14 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                 for blur_i in blur_k:
                     indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                         bluring(cv2.imread(dir_img + '/' + im), blur_i),
-                                                        cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                                        cv2.imread(dir_of_label_file),
                                                         input_height, input_width, indexer=indexer, scaler=sc_ind)

                 if scaling_binarization:
                     for sc_ind in scales:
                         indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                             otsu_copy(cv2.imread(dir_img + '/' + im)),
-                                                            cv2.imread(dir_seg + '/' + img_name + '.png'),
+                                                            cv2.imread(dir_of_label_file),
                                                             input_height, input_width, indexer=indexer, scaler=sc_ind)

                 if scaling_flip:

@@ -682,5 +688,5 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow_train_imgs,
                     for f_i in flip_index:
                         indexer = get_patches_num_scale_new(dir_flow_train_imgs, dir_flow_train_labels,
                                                             cv2.flip( cv2.imread(dir_img + '/' + im), f_i),
-                                                            cv2.flip(cv2.imread(dir_seg + '/' + img_name + '.png'), f_i),
+                                                            cv2.flip(cv2.imread(dir_of_label_file), f_i),
                                                             input_height, input_width, indexer=indexer, scaler=sc_ind)
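With these changes, data_gen yields one-hot class masks for segmentation but, for enhancement, the clean RGB images scaled to [0, 1], matching the model's sigmoid output. An illustrative sanity check (not part of the commit; the directory paths are hypothetical and must contain matching image/label files):

import numpy as np

# Illustrative check of the enhancement branch of data_gen, using names
# from the diff; shapes follow mask = np.zeros((batch_size, H, W, n_classes)).
gen = data_gen('./train_images', './train_labels', batch_size=3,
               input_height=448, input_width=448, n_classes=3,
               task='enhancement')
imgs, labels = next(gen)
assert labels.shape == (3, 448, 448, 3)
assert labels.min() >= 0.0 and labels.max() <= 1.0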