mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-30 22:20:02 +02:00

Merge pull request #15 from vahidrezanezhad/master

continue training, loss functions, rotation and ...

commit 75dc5f3177

3 changed files with 354 additions and 141 deletions
@@ -1,24 +1,30 @@
{
"n_classes" : 2,
"n_classes" : 3,
"n_epochs" : 2,
"input_height" : 448,
"input_width" : 896,
"input_width" : 672,
"weight_decay" : 1e-6,
"n_batch" : 1,
"n_batch" : 2,
"learning_rate": 1e-4,
"patches" : true,
"pretraining" : true,
"augmentation" : false,
"flip_aug" : false,
"elastic_aug" : false,
"blur_aug" : false,
"scaling" : false,
"scaling" : true,
"binarization" : false,
"scaling_bluring" : false,
"scaling_binarization" : false,
"scaling_flip" : false,
"rotation": false,
"weighted_loss": true,
"dir_train": "../train",
"dir_eval": "../eval",
"dir_output": "../output"
"rotation_not_90": false,
"continue_training": false,
"index_start": 0,
"dir_of_start_model": " ",
"weighted_loss": false,
"is_loss_soft_dice": false,
"data_is_provided": false,
"dir_train": "/home/vahid/Documents/handwrittens_train/train",
"dir_eval": "/home/vahid/Documents/handwrittens_train/eval",
"dir_output": "/home/vahid/Documents/handwrittens_train/output"
}
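Not part of the diff, but useful orientation: train.py consumes this file through sacred (typically invoked as python train.py with the config file), and the keys added here drive resuming. A minimal sketch of reading the same keys directly, assuming the JSON above is saved as config_params.json (the file name of this first changed file is not captured in the mirror view):

import json

with open('config_params.json') as f:  # assumed file name
    config = json.load(f)

if config['continue_training']:
    # resume from an earlier snapshot; epoch numbering starts at index_start
    dir_of_start_model = config['dir_of_start_model']
    index_start = config['index_start']
else:
    index_start = 0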
train.py (208 changed lines)
@@ -8,7 +8,8 @@ from sacred import Experiment
from models import *
from utils import *
from metrics import *

from keras.models import load_model
from tqdm import tqdm

def configuration():
keras.backend.clear_session()
@@ -47,7 +48,6 @@ def config_params():
# extraction this should be set to false since model should see all image.
augmentation=False
flip_aug=False # Flip image (augmentation).
elastic_aug=False # Elastic transformation (augmentation).
blur_aug=False # Blur patches of image (augmentation).
scaling=False # Scaling of patches (augmentation) will be imposed if this set to true.
binarization=False # Otsu thresholding. Used for augmentation in the case of binary case like textline prediction. For multicases should not be applied.
@@ -55,82 +55,120 @@ def config_params():
dir_eval=None # Directory of validation dataset (sub-folders should be named images and labels).
dir_output=None # Directory of output where the model should be saved.
pretraining=False # Set true to load pretrained weights of resnet50 encoder.
weighted_loss=False # Set True if classes are unbalanced and you want to use weighted loss function.
scaling_bluring=False
rotation: False
scaling_binarization=False
scaling_flip=False
thetha=[10,-10]
blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation.
scales=[0.9 , 1.1 ] # Scale patches with these scales. Used for augmentation.
flip_index=[0,1] # Flip image. Used for augmentation.

scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation.
flip_index=[0,1,-1] # Flip image. Used for augmentation.
continue_training = False # If
index_start = 0
dir_of_start_model = ''
is_loss_soft_dice = False
weighted_loss = False
data_is_provided = False

@ex.automain
def run(n_classes,n_epochs,input_height,
input_width,weight_decay,weighted_loss,
n_batch,patches,augmentation,flip_aug,blur_aug,scaling, binarization,
blur_k,scales,dir_train,
index_start,dir_of_start_model,is_loss_soft_dice,
n_batch,patches,augmentation,flip_aug
,blur_aug,scaling, binarization,
blur_k,scales,dir_train,data_is_provided,
scaling_bluring,scaling_binarization,rotation,
rotation_not_90,thetha,scaling_flip,continue_training,
flip_index,dir_eval ,dir_output,pretraining,learning_rate):

dir_img,dir_seg=get_dirs_or_files(dir_train)
dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval)

# make first a directory in output for both training and evaluations in order to flow data from these directories.
dir_train_flowing=os.path.join(dir_output,'train')
dir_eval_flowing=os.path.join(dir_output,'eval')

dir_flow_train_imgs=os.path.join(dir_train_flowing,'images')
dir_flow_train_labels=os.path.join(dir_train_flowing,'labels')

dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images')
dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels')

if os.path.isdir(dir_train_flowing):
os.system('rm -rf '+dir_train_flowing)
os.makedirs(dir_train_flowing)
if data_is_provided:
dir_train_flowing=os.path.join(dir_output,'train')
dir_eval_flowing=os.path.join(dir_output,'eval')

dir_flow_train_imgs=os.path.join(dir_train_flowing,'images')
dir_flow_train_labels=os.path.join(dir_train_flowing,'labels')

dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images')
dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels')

configuration()

else:
os.makedirs(dir_train_flowing)
dir_img,dir_seg=get_dirs_or_files(dir_train)
dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval)

if os.path.isdir(dir_eval_flowing):
os.system('rm -rf '+dir_eval_flowing)
os.makedirs(dir_eval_flowing)
else:
os.makedirs(dir_eval_flowing)
# make first a directory in output for both training and evaluations in order to flow data from these directories.
dir_train_flowing=os.path.join(dir_output,'train')
dir_eval_flowing=os.path.join(dir_output,'eval')

dir_flow_train_imgs=os.path.join(dir_train_flowing,'images/')
dir_flow_train_labels=os.path.join(dir_train_flowing,'labels/')

dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images/')
dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels/')

if os.path.isdir(dir_train_flowing):
os.system('rm -rf '+dir_train_flowing)
os.makedirs(dir_train_flowing)
else:
os.makedirs(dir_train_flowing)

if os.path.isdir(dir_eval_flowing):
os.system('rm -rf '+dir_eval_flowing)
os.makedirs(dir_eval_flowing)
else:
os.makedirs(dir_eval_flowing)


os.mkdir(dir_flow_train_imgs)
os.mkdir(dir_flow_train_labels)

os.mkdir(dir_flow_eval_imgs)
os.mkdir(dir_flow_eval_labels)


#set the gpu configuration
configuration()


#writing patches into a sub-folder in order to be flowed from directory.
provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
dir_flow_train_labels,
input_height,input_width,blur_k,blur_aug,
flip_aug,binarization,scaling,scales,flip_index,
scaling_bluring,scaling_binarization,rotation,
rotation_not_90,thetha,scaling_flip,
augmentation=augmentation,patches=patches)

provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs,
dir_flow_eval_labels,
input_height,input_width,blur_k,blur_aug,
flip_aug,binarization,scaling,scales,flip_index,
scaling_bluring,scaling_binarization,rotation,
rotation_not_90,thetha,scaling_flip,
augmentation=False,patches=patches)


os.mkdir(dir_flow_train_imgs)
os.mkdir(dir_flow_train_labels)

os.mkdir(dir_flow_eval_imgs)
os.mkdir(dir_flow_eval_labels)


#set the gpu configuration
configuration()


#writing patches into a sub-folder in order to be flowed from directory.
provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
dir_flow_train_labels,
input_height,input_width,blur_k,blur_aug,
flip_aug,binarization,scaling,scales,flip_index,
scaling_bluring,scaling_binarization,rotation,
augmentation=augmentation,patches=patches)

provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs,
dir_flow_eval_labels,
input_height,input_width,blur_k,blur_aug,
flip_aug,binarization,scaling,scales,flip_index,
scaling_bluring,scaling_binarization,rotation,
augmentation=False,patches=patches)

if weighted_loss:
weights=np.zeros(n_classes)
for obj in os.listdir(dir_seg):
label_obj=cv2.imread(dir_seg+'/'+obj)
label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
if data_is_provided:
for obj in os.listdir(dir_flow_train_labels):
try:
label_obj=cv2.imread(dir_flow_train_labels+'/'+obj)
label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
except:
pass
else:

for obj in os.listdir(dir_seg):
try:
label_obj=cv2.imread(dir_seg+'/'+obj)
label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
except:
pass


weights=1.00/weights
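The weights computed above are per-class pixel counts, inverted and normalized just below. metrics.py (pulled in via from metrics import *) is not shown in this diff; a common Keras implementation consistent with how weighted_categorical_crossentropy(weights) is called here would be a sketch like:

import keras.backend as K

def weighted_categorical_crossentropy(weights):
    # weights: array of shape (n_classes,), e.g. the normalized inverse
    # pixel counts computed above (assumed implementation, not the diff's)
    weights = K.variable(weights)
    def loss(y_true, y_pred):
        # renormalize and clip predictions, then weight the per-class log loss
        y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        return -K.sum(y_true * K.log(y_pred) * weights, axis=-1)
    return loss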
@@ -138,27 +176,35 @@ def run(n_classes,n_epochs,input_height,
weights=weights/float(np.sum(weights))
weights=weights/float(np.min(weights))
weights=weights/float(np.sum(weights))




#get our model.
model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)


if continue_training:
if is_loss_soft_dice:
model = load_model (dir_of_start_model, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss})
if weighted_loss:
model = load_model (dir_of_start_model, compile = True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss:
model = load_model (dir_of_start_model, compile = True)
else:
#get our model.
index_start = 0
model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)

#if you want to see the model structure just uncomment model summary.
#model.summary()


if not weighted_loss:
if not is_loss_soft_dice and not weighted_loss:
model.compile(loss='categorical_crossentropy',
optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
if is_loss_soft_dice:
model.compile(loss=soft_dice_loss,
optimizer = Adam(lr=learning_rate),metrics=['accuracy'])

if weighted_loss:
model.compile(loss=weighted_categorical_crossentropy(weights),
optimizer = Adam(lr=learning_rate),metrics=['accuracy'])

mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
save_weights_only=True, period=1)


#generating train and evaluation data
train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch,
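soft_dice_loss likewise comes from metrics.py and is not part of this diff; a standard soft Dice formulation matching the name and the one-hot (batch, h, w, n_classes) tensors used here would be:

import keras.backend as K

def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    # assumed implementation: sum over the spatial axes,
    # then average the per-class Dice over classes and batch
    axes = (1, 2)
    numerator = 2. * K.sum(y_pred * y_true, axis=axes)
    denominator = K.sum(K.square(y_pred) + K.square(y_true), axis=axes)
    return 1 - K.mean(numerator / (denominator + epsilon))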
@@ -166,20 +212,20 @@ def run(n_classes,n_epochs,input_height,
val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch,
input_height=input_height, input_width=input_width,n_classes=n_classes )

for i in tqdm(range(index_start, n_epochs+index_start)):
model.fit_generator(
train_gen,
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1,
validation_data=val_gen,
validation_steps=1,
epochs=1)
model.save(dir_output+'/'+'model_'+str(i)+'.h5')

model.fit_generator(
train_gen,
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1,
validation_data=val_gen,
validation_steps=1,
epochs=n_epochs)

#os.system('rm -rf '+dir_train_flowing)
#os.system('rm -rf '+dir_eval_flowing)


os.system('rm -rf '+dir_train_flowing)
os.system('rm -rf '+dir_eval_flowing)

model.save(dir_output+'/'+'model'+'.h5')
#model.save(dir_output+'/'+'model'+'.h5')
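The new loop above runs one epoch per fit_generator call and saves a snapshot model_<i>.h5 after each, with i offset by index_start. To resume a run whose last completed snapshot was, say, model_7.h5 (hypothetical path), the new config keys would be set along these lines:

"continue_training": true,
"dir_of_start_model": "../output/model_7.h5",
"index_start": 8

so the next snapshots continue as model_8.h5, model_9.h5, and so on, without overwriting earlier ones.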
utils.py (263 changed lines)
@@ -6,7 +6,8 @@ from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
import random
from tqdm import tqdm

import imutils
import math
@@ -19,6 +20,79 @@ def bluring(img_in,kind):
img_blur=cv2.blur(img_in,(5,5))
return img_blur

def elastic_transform(image, alpha, sigma,seedj, random_state=None):

"""Elastic deformation of images as described in [Simard2003]_.
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
Convolutional Neural Networks applied to Visual Document Analysis", in
Proc. of the International Conference on Document Analysis and
Recognition, 2003.
"""
if random_state is None:
random_state = np.random.RandomState(seedj)

shape = image.shape
dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
dz = np.zeros_like(dx)

x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1)), np.reshape(z, (-1, 1))

distored_image = map_coordinates(image, indices, order=1, mode='reflect')
return distored_image.reshape(image.shape)

def rotation_90(img):
img_rot=np.zeros((img.shape[1],img.shape[0],img.shape[2]))
img_rot[:,:,0]=img[:,:,0].T
img_rot[:,:,1]=img[:,:,1].T
img_rot[:,:,2]=img[:,:,2].T
return img_rot

def rotatedRectWithMaxArea(w, h, angle):
"""
Given a rectangle of size wxh that has been rotated by 'angle' (in
radians), computes the width and height of the largest possible
axis-aligned rectangle (maximal area) within the rotated rectangle.
"""
if w <= 0 or h <= 0:
return 0,0

width_is_longer = w >= h
side_long, side_short = (w,h) if width_is_longer else (h,w)

# since the solutions for angle, -angle and 180-angle are all the same,
# it suffices to look at the first quadrant and the absolute values of sin,cos:
sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
if side_short <= 2.*sin_a*cos_a*side_long or abs(sin_a-cos_a) < 1e-10:
# half constrained case: two crop corners touch the longer side,
# the other two corners are on the mid-line parallel to the longer line
x = 0.5*side_short
wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a)
else:
# fully constrained case: crop touches all 4 sides
cos_2a = cos_a*cos_a - sin_a*sin_a
wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a

return wr,hr

def rotate_max_area(image,rotated, rotated_label,angle):
""" image: cv2 image matrix object
angle: in degree
"""
wr, hr = rotatedRectWithMaxArea(image.shape[1], image.shape[0],
math.radians(angle))
h, w, _ = rotated.shape
y1 = h//2 - int(hr/2)
y2 = y1 + int(hr)
x1 = w//2 - int(wr/2)
x2 = x1 + int(wr)
return rotated[y1:y2, x1:x2],rotated_label[y1:y2, x1:x2]
def rotation_not_90_func(img,label,thetha):
rotated=imutils.rotate(img,thetha)
rotated_label=imutils.rotate(label,thetha)
return rotate_max_area(img, rotated,rotated_label,thetha)

def color_images(seg, n_classes):
ann_u=range(n_classes)
if len(np.shape(seg))==3:
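A worked example of the geometry (not in the commit): for the 896x448 inputs from the config, rotated by thetha=10 degrees, rotatedRectWithMaxArea takes the fully constrained branch and rotate_max_area then crops the result from the center of the rotated image:

import math
# assumes rotatedRectWithMaxArea from utils.py above is in scope
wr, hr = rotatedRectWithMaxArea(896, 448, math.radians(10))
print(round(wr), round(hr))  # roughly 856 304: the largest upright crop free of border artefacts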
@@ -65,7 +139,7 @@ def IoU(Yi,y_predi):
return mIoU
def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_classes):
c = 0
n = os.listdir(img_folder) #List of training images
n = [f for f in os.listdir(img_folder) if not f.startswith('.')]# os.listdir(img_folder) #List of training images
random.shuffle(n)
while True:
img = np.zeros((batch_size, input_height, input_width, 3)).astype('float')
@@ -73,18 +147,26 @@ def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_classes):

for i in range(c, c+batch_size): #initially from 0 to 16, c = 0.
#print(img_folder+'/'+n[i])
filename=n[i].split('.')[0]
train_img = cv2.imread(img_folder+'/'+n[i])/255.
train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize

img[i-c] = train_img #add to array - img[0], img[1], and so on.
train_mask = cv2.imread(mask_folder+'/'+filename+'.png')
#print(mask_folder+'/'+filename+'.png')
#print(train_mask.shape)
train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes)
#train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3]

mask[i-c] = train_mask

try:
filename=n[i].split('.')[0]

train_img = cv2.imread(img_folder+'/'+n[i])/255.
train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize

img[i-c] = train_img #add to array - img[0], img[1], and so on.
train_mask = cv2.imread(mask_folder+'/'+filename+'.png')
#print(mask_folder+'/'+filename+'.png')
#print(train_mask.shape)
train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes)
#train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3]

mask[i-c] = train_mask
except:
img[i-c] = np.ones((input_height, input_width, 3)).astype('float')
mask[i-c] = np.zeros((input_height, input_width, n_classes)).astype('float')


c+=batch_size
if(c+batch_size>=len(os.listdir(img_folder))):
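Usage sketch (not in the commit): data_gen is an infinite generator, so one batch can be inspected with next(); the new try/except substitutes a white image and an all-zero mask for unreadable files instead of crashing mid-epoch. Directory names below are the flow folders train.py creates under dir_output, and the shapes follow the example config:

gen = data_gen('../output/train/images', '../output/train/labels',
               batch_size=2, input_height=448, input_width=672, n_classes=3)
imgs, masks = next(gen)
# imgs: (2, 448, 672, 3) floats in [0, 1]; masks: (2, 448, 672, 3) one-hot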
@@ -104,16 +186,10 @@ def otsu_copy(img):
img_r[:,:,1]=threshold1
img_r[:,:,2]=threshold1
return img_r

def rotation_90(img):
img_rot=np.zeros((img.shape[1],img.shape[0],img.shape[2]))
img_rot[:,:,0]=img[:,:,0].T
img_rot[:,:,1]=img[:,:,1].T
img_rot[:,:,2]=img[:,:,2].T
return img_rot

def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer):

if img.shape[0]<height or img.shape[1]<width:
img,label=do_padding(img,label,height,width)

img_h=img.shape[0]
img_w=img.shape[1]
@@ -151,12 +227,39 @@ def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer):
cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch )
cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch )
indexer+=1

return indexer


def do_padding(img,label,height,width):

height_new=img.shape[0]
width_new=img.shape[1]

h_start=0
w_start=0

if img.shape[0]<height:
h_start=int( abs(height-img.shape[0])/2. )
height_new=height

def get_patches_num_scale(dir_img_f,dir_seg_f,img,label,height,width,indexer,scaler):

if img.shape[1]<width:
w_start=int( abs(width-img.shape[1])/2. )
width_new=width

img_new=np.ones((height_new,width_new,img.shape[2])).astype(float)*255
label_new=np.zeros((height_new,width_new,label.shape[2])).astype(float)

img_new[h_start:h_start+img.shape[0],w_start:w_start+img.shape[1],:]=np.copy(img[:,:,:])
label_new[h_start:h_start+label.shape[0],w_start:w_start+label.shape[1],:]=np.copy(label[:,:,:])

return img_new,label_new


def get_patches_num_scale(dir_img_f,dir_seg_f,img,label,height,width,indexer,n_patches,scaler):


if img.shape[0]<height or img.shape[1]<width:
img,label=do_padding(img,label,height,width)

img_h=img.shape[0]
img_w=img.shape[1]
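do_padding (its body is interleaved above with the old get_patches_num_scale signature by the diff view) centers an undersized image on a white canvas and its label on a zero canvas of the requested patch size. A small check of the arithmetic, assuming the reassembled function is imported from utils.py:

import numpy as np
img = np.zeros((300, 500, 3))     # smaller than a 448x672 patch
label = np.zeros((300, 500, 3))
img_p, label_p = do_padding(img, label, 448, 672)
print(img_p.shape, label_p.shape)  # (448, 672, 3) (448, 672, 3)
# h_start = int(abs(448-300)/2.) = 74 and w_start = int(abs(672-500)/2.) = 86,
# so the original pixels land at [74:374, 86:586]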
@@ -204,6 +307,58 @@ def get_patches_num_scale(dir_img_f,dir_seg_f,img,label,height,width,indexer,scaler):

return indexer

def get_patches_num_scale_new(dir_img_f,dir_seg_f,img,label,height,width,indexer,scaler):
img=resize_image(img,int(img.shape[0]*scaler),int(img.shape[1]*scaler))
label=resize_image(label,int(label.shape[0]*scaler),int(label.shape[1]*scaler))

if img.shape[0]<height or img.shape[1]<width:
img,label=do_padding(img,label,height,width)

img_h=img.shape[0]
img_w=img.shape[1]

height_scale=int(height*1)
width_scale=int(width*1)


nxf=img_w/float(width_scale)
nyf=img_h/float(height_scale)

if nxf>int(nxf):
nxf=int(nxf)+1
if nyf>int(nyf):
nyf=int(nyf)+1

nxf=int(nxf)
nyf=int(nyf)

for i in range(nxf):
for j in range(nyf):
index_x_d=i*width_scale
index_x_u=(i+1)*width_scale

index_y_d=j*height_scale
index_y_u=(j+1)*height_scale

if index_x_u>img_w:
index_x_u=img_w
index_x_d=img_w-width_scale
if index_y_u>img_h:
index_y_u=img_h
index_y_d=img_h-height_scale


img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:]
label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:]

#img_patch=resize_image(img_patch,height,width)
#label_patch=resize_image(label_patch,height,width)

cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch )
cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch )
indexer+=1

return indexer


def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
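The nxf/nyf arithmetic above sizes the patch grid; a worked example (not in the commit) for a 1000x2500 page cut into 448x672 patches:

img_h, img_w = 1000, 2500
height_scale, width_scale = 448, 672
# nxf = 2500/672. = 3.72 -> rounded up to 4 columns
# nyf = 1000/448. = 2.23 -> rounded up to 3 rows
# the last column/row is shifted back to img_w-width_scale / img_h-height_scale,
# so edge patches overlap their neighbours instead of running past the image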
@@ -211,6 +366,7 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
input_height,input_width,blur_k,blur_aug,
flip_aug,binarization,scaling,scales,flip_index,
scaling_bluring,scaling_binarization,rotation,
rotation_not_90,thetha,scaling_flip,
augmentation=False,patches=False):

imgs_cv_train=np.array(os.listdir(dir_img))
@@ -219,24 +375,12 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
indexer=0
for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)):
img_name=im.split('.')[0]

if not patches:
cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) )
cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) )
indexer+=1

if augmentation:
if rotation:
cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png',
rotation_90( resize_image(cv2.imread(dir_img+'/'+im),
input_height,input_width) ) )


cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png',
rotation_90 ( resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),
input_height,input_width) ) )
indexer+=1

if flip_aug:
for f_i in flip_index:
cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png',
@@ -270,10 +414,10 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,


if patches:


indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels,
cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'),
input_height,input_width,indexer=indexer)
cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'),
input_height,input_width,indexer=indexer)

if augmentation:
@@ -284,29 +428,37 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
rotation_90( cv2.imread(dir_img+'/'+im) ),
rotation_90( cv2.imread(dir_seg+'/'+img_name+'.png') ),
input_height,input_width,indexer=indexer)

if rotation_not_90:

for thetha_i in thetha:
img_max_rotated,label_max_rotated=rotation_not_90_func(cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'),thetha_i)
indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels,
img_max_rotated,
label_max_rotated,
input_height,input_width,indexer=indexer)
if flip_aug:
for f_i in flip_index:

indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels,
cv2.flip( cv2.imread(dir_img+'/'+im) , f_i),
cv2.flip( cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i),
input_height,input_width,indexer=indexer)
if blur_aug:
for blur_i in blur_k:

indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels,
bluring( cv2.imread(dir_img+'/'+im) , blur_i),
cv2.imread(dir_seg+'/'+img_name+'.png'),
input_height,input_width,indexer=indexer)

input_height,input_width,indexer=indexer)


if scaling:
for sc_ind in scales:
indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels,
cv2.imread(dir_img+'/'+im) ,
cv2.imread(dir_seg+'/'+img_name+'.png'),
indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels,
cv2.imread(dir_img+'/'+im) ,
cv2.imread(dir_seg+'/'+img_name+'.png'),
input_height,input_width,indexer=indexer,scaler=sc_ind)
if binarization:

indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels,
otsu_copy( cv2.imread(dir_img+'/'+im)),
cv2.imread(dir_seg+'/'+img_name+'.png'),
@@ -317,17 +469,26 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
if scaling_bluring:
for sc_ind in scales:
for blur_i in blur_k:
indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels,
indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels,
bluring( cv2.imread(dir_img+'/'+im) , blur_i) ,
cv2.imread(dir_seg+'/'+img_name+'.png') ,
input_height,input_width,indexer=indexer,scaler=sc_ind)

if scaling_binarization:
for sc_ind in scales:
indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels,
otsu_copy( cv2.imread(dir_img+'/'+im)) ,
cv2.imread(dir_seg+'/'+img_name+'.png'),
indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels,
otsu_copy( cv2.imread(dir_img+'/'+im)) ,
cv2.imread(dir_seg+'/'+img_name+'.png'),
input_height,input_width,indexer=indexer,scaler=sc_ind)

if scaling_flip:
for sc_ind in scales:
for f_i in flip_index:
indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels,
cv2.flip( cv2.imread(dir_img+'/'+im) , f_i) ,
cv2.flip(cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i) ,
input_height,input_width,indexer=indexer,scaler=sc_ind)