From 5fb7552dbe6534daf29ed06b3b0ee8858a5ad0ea Mon Sep 17 00:00:00 2001 From: vahid Date: Tue, 22 Jun 2021 14:20:51 -0400 Subject: [PATCH 1/2] first updates, padding, rotations --- config_params.json | 22 ++-- train.py | 185 ++++++++++++++++--------------- utils.py | 265 ++++++++++++++++++++++++++++++++++++--------- 3 files changed, 320 insertions(+), 152 deletions(-) diff --git a/config_params.json b/config_params.json index 5066444..d8f1ac5 100644 --- a/config_params.json +++ b/config_params.json @@ -1,24 +1,24 @@ { - "n_classes" : 2, - "n_epochs" : 2, + "n_classes" : 3, + "n_epochs" : 1, "input_height" : 448, - "input_width" : 896, + "input_width" : 672, "weight_decay" : 1e-6, - "n_batch" : 1, + "n_batch" : 2, "learning_rate": 1e-4, "patches" : true, "pretraining" : true, - "augmentation" : false, + "augmentation" : true, "flip_aug" : false, - "elastic_aug" : false, - "blur_aug" : false, + "blur_aug" : true, "scaling" : false, "binarization" : false, "scaling_bluring" : false, "scaling_binarization" : false, + "scaling_flip" : false, "rotation": false, - "weighted_loss": true, - "dir_train": "../train", - "dir_eval": "../eval", - "dir_output": "../output" + "rotation_not_90": false, + "dir_train": "/home/vahid/Documents/handwrittens_train/train", + "dir_eval": "/home/vahid/Documents/handwrittens_train/eval", + "dir_output": "/home/vahid/Documents/handwrittens_train/output" } diff --git a/train.py b/train.py index baeb847..c256d83 100644 --- a/train.py +++ b/train.py @@ -8,7 +8,7 @@ from sacred import Experiment from models import * from utils import * from metrics import * - +from keras.models import load_model def configuration(): keras.backend.clear_session() @@ -47,7 +47,6 @@ def config_params(): # extraction this should be set to false since model should see all image. augmentation=False flip_aug=False # Flip image (augmentation). - elastic_aug=False # Elastic transformation (augmentation). blur_aug=False # Blur patches of image (augmentation). scaling=False # Scaling of patches (augmentation) will be imposed if this set to true. binarization=False # Otsu thresholding. Used for augmentation in the case of binary case like textline prediction. For multicases should not be applied. @@ -55,110 +54,116 @@ def config_params(): dir_eval=None # Directory of validation dataset (sub-folders should be named images and labels). dir_output=None # Directory of output where the model should be saved. pretraining=False # Set true to load pretrained weights of resnet50 encoder. - weighted_loss=False # Set True if classes are unbalanced and you want to use weighted loss function. scaling_bluring=False - rotation: False scaling_binarization=False + scaling_flip=False + thetha=[10,-10] blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation. - scales=[0.9 , 1.1 ] # Scale patches with these scales. Used for augmentation. - flip_index=[0,1] # Flip image. Used for augmentation. + scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation. + flip_index=[0,1,-1] # Flip image. Used for augmentation. @ex.automain def run(n_classes,n_epochs,input_height, - input_width,weight_decay,weighted_loss, - n_batch,patches,augmentation,flip_aug,blur_aug,scaling, binarization, + input_width,weight_decay, + n_batch,patches,augmentation,flip_aug + ,blur_aug,scaling, binarization, blur_k,scales,dir_train, scaling_bluring,scaling_binarization,rotation, + rotation_not_90,thetha,scaling_flip, flip_index,dir_eval ,dir_output,pretraining,learning_rate): - dir_img,dir_seg=get_dirs_or_files(dir_train) - dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval) - - # make first a directory in output for both training and evaluations in order to flow data from these directories. - dir_train_flowing=os.path.join(dir_output,'train') - dir_eval_flowing=os.path.join(dir_output,'eval') - - dir_flow_train_imgs=os.path.join(dir_train_flowing,'images') - dir_flow_train_labels=os.path.join(dir_train_flowing,'labels') + data_is_provided = False - dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images') - dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels') - - if os.path.isdir(dir_train_flowing): - os.system('rm -rf '+dir_train_flowing) - os.makedirs(dir_train_flowing) - else: - os.makedirs(dir_train_flowing) + if data_is_provided: + dir_train_flowing=os.path.join(dir_output,'train') + dir_eval_flowing=os.path.join(dir_output,'eval') + + dir_flow_train_imgs=os.path.join(dir_train_flowing,'images') + dir_flow_train_labels=os.path.join(dir_train_flowing,'labels') + + dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images') + dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels') + + configuration() - if os.path.isdir(dir_eval_flowing): - os.system('rm -rf '+dir_eval_flowing) - os.makedirs(dir_eval_flowing) else: - os.makedirs(dir_eval_flowing) + dir_img,dir_seg=get_dirs_or_files(dir_train) + dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval) - - os.mkdir(dir_flow_train_imgs) - os.mkdir(dir_flow_train_labels) - - os.mkdir(dir_flow_eval_imgs) - os.mkdir(dir_flow_eval_labels) - - - - #set the gpu configuration - configuration() - - - #writing patches into a sub-folder in order to be flowed from directory. - provide_patches(dir_img,dir_seg,dir_flow_train_imgs, - dir_flow_train_labels, - input_height,input_width,blur_k,blur_aug, - flip_aug,binarization,scaling,scales,flip_index, - scaling_bluring,scaling_binarization,rotation, - augmentation=augmentation,patches=patches) - - provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs, - dir_flow_eval_labels, - input_height,input_width,blur_k,blur_aug, - flip_aug,binarization,scaling,scales,flip_index, - scaling_bluring,scaling_binarization,rotation, - augmentation=False,patches=patches) + # make first a directory in output for both training and evaluations in order to flow data from these directories. + dir_train_flowing=os.path.join(dir_output,'train') + dir_eval_flowing=os.path.join(dir_output,'eval') + + dir_flow_train_imgs=os.path.join(dir_train_flowing,'images/') + dir_flow_train_labels=os.path.join(dir_train_flowing,'labels/') - if weighted_loss: - weights=np.zeros(n_classes) - for obj in os.listdir(dir_seg): - label_obj=cv2.imread(dir_seg+'/'+obj) - label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) - weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images/') + dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels/') + + if os.path.isdir(dir_train_flowing): + os.system('rm -rf '+dir_train_flowing) + os.makedirs(dir_train_flowing) + else: + os.makedirs(dir_train_flowing) + + if os.path.isdir(dir_eval_flowing): + os.system('rm -rf '+dir_eval_flowing) + os.makedirs(dir_eval_flowing) + else: + os.makedirs(dir_eval_flowing) - weights=1.00/weights + os.mkdir(dir_flow_train_imgs) + os.mkdir(dir_flow_train_labels) - weights=weights/float(np.sum(weights)) - weights=weights/float(np.min(weights)) - weights=weights/float(np.sum(weights)) - - - + os.mkdir(dir_flow_eval_imgs) + os.mkdir(dir_flow_eval_labels) + + + #set the gpu configuration + configuration() + + + #writing patches into a sub-folder in order to be flowed from directory. + provide_patches(dir_img,dir_seg,dir_flow_train_imgs, + dir_flow_train_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + rotation_not_90,thetha,scaling_flip, + augmentation=augmentation,patches=patches) + + provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs, + dir_flow_eval_labels, + input_height,input_width,blur_k,blur_aug, + flip_aug,binarization,scaling,scales,flip_index, + scaling_bluring,scaling_binarization,rotation, + rotation_not_90,thetha,scaling_flip, + augmentation=False,patches=patches) - #get our model. - model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) + + continue_train = False + + if continue_train: + model_dir_start = '/home/vahid/Documents/struktur_full_data/output_multi/model_0.h5' + model = load_model (model_dir_start, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss}) + index_start = 1 + else: + #get our model. + index_start = 0 + model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining) #if you want to see the model structure just uncomment model summary. #model.summary() - if not weighted_loss: - model.compile(loss='categorical_crossentropy', - optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - if weighted_loss: - model.compile(loss=weighted_categorical_crossentropy(weights), - optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - - mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5', - save_weights_only=True, period=1) - + + #model.compile(loss='categorical_crossentropy', + #optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + + model.compile(loss=soft_dice_loss, + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) #generating train and evaluation data train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch, @@ -166,20 +171,20 @@ def run(n_classes,n_epochs,input_height, val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch, input_height=input_height, input_width=input_width,n_classes=n_classes ) + for i in range(index_start, n_epochs+index_start): + model.fit_generator( + train_gen, + steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1, + validation_data=val_gen, + validation_steps=1, + epochs=1) + model.save(dir_output+'/'+'model_'+str(i)+'.h5') - model.fit_generator( - train_gen, - steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1, - validation_data=val_gen, - validation_steps=1, - epochs=n_epochs) - - os.system('rm -rf '+dir_train_flowing) os.system('rm -rf '+dir_eval_flowing) - model.save(dir_output+'/'+'model'+'.h5') + #model.save(dir_output+'/'+'model'+'.h5') diff --git a/utils.py b/utils.py index afdc9e5..a77444e 100644 --- a/utils.py +++ b/utils.py @@ -6,7 +6,8 @@ from scipy.ndimage.interpolation import map_coordinates from scipy.ndimage.filters import gaussian_filter import random from tqdm import tqdm - +import imutils +import math @@ -19,6 +20,79 @@ def bluring(img_in,kind): img_blur=cv2.blur(img_in,(5,5)) return img_blur +def elastic_transform(image, alpha, sigma,seedj, random_state=None): + + """Elastic deformation of images as described in [Simard2003]_. + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + """ + if random_state is None: + random_state = np.random.RandomState(seedj) + + shape = image.shape + dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha + dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha + dz = np.zeros_like(dx) + + x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])) + indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1)), np.reshape(z, (-1, 1)) + + distored_image = map_coordinates(image, indices, order=1, mode='reflect') + return distored_image.reshape(image.shape) + +def rotation_90(img): + img_rot=np.zeros((img.shape[1],img.shape[0],img.shape[2])) + img_rot[:,:,0]=img[:,:,0].T + img_rot[:,:,1]=img[:,:,1].T + img_rot[:,:,2]=img[:,:,2].T + return img_rot + +def rotatedRectWithMaxArea(w, h, angle): + """ + Given a rectangle of size wxh that has been rotated by 'angle' (in + radians), computes the width and height of the largest possible + axis-aligned rectangle (maximal area) within the rotated rectangle. + """ + if w <= 0 or h <= 0: + return 0,0 + + width_is_longer = w >= h + side_long, side_short = (w,h) if width_is_longer else (h,w) + + # since the solutions for angle, -angle and 180-angle are all the same, + # if suffices to look at the first quadrant and the absolute values of sin,cos: + sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) + if side_short <= 2.*sin_a*cos_a*side_long or abs(sin_a-cos_a) < 1e-10: + # half constrained case: two crop corners touch the longer side, + # the other two corners are on the mid-line parallel to the longer line + x = 0.5*side_short + wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a) + else: + # fully constrained case: crop touches all 4 sides + cos_2a = cos_a*cos_a - sin_a*sin_a + wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a + + return wr,hr + +def rotate_max_area(image,rotated, rotated_label,angle): + """ image: cv2 image matrix object + angle: in degree + """ + wr, hr = rotatedRectWithMaxArea(image.shape[1], image.shape[0], + math.radians(angle)) + h, w, _ = rotated.shape + y1 = h//2 - int(hr/2) + y2 = y1 + int(hr) + x1 = w//2 - int(wr/2) + x2 = x1 + int(wr) + return rotated[y1:y2, x1:x2],rotated_label[y1:y2, x1:x2] +def rotation_not_90_func(img,label,thetha): + rotated=imutils.rotate(img,thetha) + rotated_label=imutils.rotate(label,thetha) + return rotate_max_area(img, rotated,rotated_label,thetha) + def color_images(seg, n_classes): ann_u=range(n_classes) if len(np.shape(seg))==3: @@ -65,7 +139,7 @@ def IoU(Yi,y_predi): return mIoU def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_classes): c = 0 - n = os.listdir(img_folder) #List of training images + n = [f for f in os.listdir(img_folder) if not f.startswith('.')]# os.listdir(img_folder) #List of training images random.shuffle(n) while True: img = np.zeros((batch_size, input_height, input_width, 3)).astype('float') @@ -73,18 +147,26 @@ def data_gen(img_folder, mask_folder, batch_size,input_height, input_width,n_cla for i in range(c, c+batch_size): #initially from 0 to 16, c = 0. #print(img_folder+'/'+n[i]) - filename=n[i].split('.')[0] - train_img = cv2.imread(img_folder+'/'+n[i])/255. - train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize - - img[i-c] = train_img #add to array - img[0], img[1], and so on. - train_mask = cv2.imread(mask_folder+'/'+filename+'.png') - #print(mask_folder+'/'+filename+'.png') - #print(train_mask.shape) - train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes) - #train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] - - mask[i-c] = train_mask + + try: + filename=n[i].split('.')[0] + + train_img = cv2.imread(img_folder+'/'+n[i])/255. + train_img = cv2.resize(train_img, (input_width, input_height),interpolation=cv2.INTER_NEAREST)# Read an image from folder and resize + + img[i-c] = train_img #add to array - img[0], img[1], and so on. + train_mask = cv2.imread(mask_folder+'/'+filename+'.png') + #print(mask_folder+'/'+filename+'.png') + #print(train_mask.shape) + train_mask = get_one_hot( resize_image(train_mask,input_height,input_width),input_height,input_width,n_classes) + #train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] + + mask[i-c] = train_mask + except: + img[i-c] = np.ones((input_height, input_width, 3)).astype('float') + mask[i-c] = np.zeros((input_height, input_width, n_classes)).astype('float') + + c+=batch_size if(c+batch_size>=len(os.listdir(img_folder))): @@ -104,16 +186,10 @@ def otsu_copy(img): img_r[:,:,1]=threshold1 img_r[:,:,2]=threshold1 return img_r - -def rotation_90(img): - img_rot=np.zeros((img.shape[1],img.shape[0],img.shape[2])) - img_rot[:,:,0]=img[:,:,0].T - img_rot[:,:,1]=img[:,:,1].T - img_rot[:,:,2]=img[:,:,2].T - return img_rot - def get_patches(dir_img_f,dir_seg_f,img,label,height,width,indexer): + if img.shape[0]int(nxf): + nxf=int(nxf)+1 + if nyf>int(nyf): + nyf=int(nyf)+1 + + nxf=int(nxf) + nyf=int(nyf) + + for i in range(nxf): + for j in range(nyf): + index_x_d=i*width_scale + index_x_u=(i+1)*width_scale + + index_y_d=j*height_scale + index_y_u=(j+1)*height_scale + + if index_x_u>img_w: + index_x_u=img_w + index_x_d=img_w-width_scale + if index_y_u>img_h: + index_y_u=img_h + index_y_d=img_h-height_scale + + + img_patch=img[index_y_d:index_y_u,index_x_d:index_x_u,:] + label_patch=label[index_y_d:index_y_u,index_x_d:index_x_u,:] + + #img_patch=resize_image(img_patch,height,width) + #label_patch=resize_image(label_patch,height,width) + + cv2.imwrite(dir_img_f+'/img_'+str(indexer)+'.png', img_patch ) + cv2.imwrite(dir_seg_f+'/img_'+str(indexer)+'.png' , label_patch ) + indexer+=1 + + return indexer def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, @@ -211,6 +366,7 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, input_height,input_width,blur_k,blur_aug, flip_aug,binarization,scaling,scales,flip_index, scaling_bluring,scaling_binarization,rotation, + rotation_not_90,thetha,scaling_flip, augmentation=False,patches=False): imgs_cv_train=np.array(os.listdir(dir_img)) @@ -218,25 +374,15 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, indexer=0 for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)): + #print(im, seg_i) img_name=im.split('.')[0] - + print(img_name,'img_name') if not patches: cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) ) cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) ) indexer+=1 if augmentation: - if rotation: - cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', - rotation_90( resize_image(cv2.imread(dir_img+'/'+im), - input_height,input_width) ) ) - - - cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png', - rotation_90 ( resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width) ) ) - indexer+=1 - if flip_aug: for f_i in flip_index: cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', @@ -270,10 +416,10 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, if patches: - + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, - cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width,indexer=indexer) + cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'), + input_height,input_width,indexer=indexer) if augmentation: @@ -284,29 +430,37 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, rotation_90( cv2.imread(dir_img+'/'+im) ), rotation_90( cv2.imread(dir_seg+'/'+img_name+'.png') ), input_height,input_width,indexer=indexer) + + if rotation_not_90: + + for thetha_i in thetha: + img_max_rotated,label_max_rotated=rotation_not_90_func(cv2.imread(dir_img+'/'+im),cv2.imread(dir_seg+'/'+img_name+'.png'),thetha_i) + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, + img_max_rotated, + label_max_rotated, + input_height,input_width,indexer=indexer) if flip_aug: for f_i in flip_index: - indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, cv2.flip( cv2.imread(dir_img+'/'+im) , f_i), cv2.flip( cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i), input_height,input_width,indexer=indexer) if blur_aug: for blur_i in blur_k: + indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, bluring( cv2.imread(dir_img+'/'+im) , blur_i), cv2.imread(dir_seg+'/'+img_name+'.png'), - input_height,input_width,indexer=indexer) - + input_height,input_width,indexer=indexer) + if scaling: for sc_ind in scales: - indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, - cv2.imread(dir_img+'/'+im) , - cv2.imread(dir_seg+'/'+img_name+'.png'), + indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, + cv2.imread(dir_img+'/'+im) , + cv2.imread(dir_seg+'/'+img_name+'.png'), input_height,input_width,indexer=indexer,scaler=sc_ind) if binarization: - indexer=get_patches(dir_flow_train_imgs,dir_flow_train_labels, otsu_copy( cv2.imread(dir_img+'/'+im)), cv2.imread(dir_seg+'/'+img_name+'.png'), @@ -317,17 +471,26 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, if scaling_bluring: for sc_ind in scales: for blur_i in blur_k: - indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, + indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, bluring( cv2.imread(dir_img+'/'+im) , blur_i) , cv2.imread(dir_seg+'/'+img_name+'.png') , input_height,input_width,indexer=indexer,scaler=sc_ind) if scaling_binarization: for sc_ind in scales: - indexer=get_patches_num_scale(dir_flow_train_imgs,dir_flow_train_labels, - otsu_copy( cv2.imread(dir_img+'/'+im)) , - cv2.imread(dir_seg+'/'+img_name+'.png'), + indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, + otsu_copy( cv2.imread(dir_img+'/'+im)) , + cv2.imread(dir_seg+'/'+img_name+'.png'), input_height,input_width,indexer=indexer,scaler=sc_ind) + + if scaling_flip: + for sc_ind in scales: + for f_i in flip_index: + indexer=get_patches_num_scale_new(dir_flow_train_imgs,dir_flow_train_labels, + cv2.flip( cv2.imread(dir_img+'/'+im) , f_i) , + cv2.flip(cv2.imread(dir_seg+'/'+img_name+'.png') ,f_i) , + input_height,input_width,indexer=indexer,scaler=sc_ind) + From 4bea9fd5354dfe1f78f9b84d20a97714b54ab37d Mon Sep 17 00:00:00 2001 From: vahid Date: Tue, 22 Jun 2021 18:47:59 -0400 Subject: [PATCH 2/2] continue training, losses and etc --- config_params.json | 14 ++++++--- train.py | 77 +++++++++++++++++++++++++++++++++++----------- utils.py | 2 -- 3 files changed, 69 insertions(+), 24 deletions(-) diff --git a/config_params.json b/config_params.json index d8f1ac5..eaa50e1 100644 --- a/config_params.json +++ b/config_params.json @@ -1,6 +1,6 @@ { "n_classes" : 3, - "n_epochs" : 1, + "n_epochs" : 2, "input_height" : 448, "input_width" : 672, "weight_decay" : 1e-6, @@ -8,16 +8,22 @@ "learning_rate": 1e-4, "patches" : true, "pretraining" : true, - "augmentation" : true, + "augmentation" : false, "flip_aug" : false, - "blur_aug" : true, - "scaling" : false, + "blur_aug" : false, + "scaling" : true, "binarization" : false, "scaling_bluring" : false, "scaling_binarization" : false, "scaling_flip" : false, "rotation": false, "rotation_not_90": false, + "continue_training": false, + "index_start": 0, + "dir_of_start_model": " ", + "weighted_loss": false, + "is_loss_soft_dice": false, + "data_is_provided": false, "dir_train": "/home/vahid/Documents/handwrittens_train/train", "dir_eval": "/home/vahid/Documents/handwrittens_train/eval", "dir_output": "/home/vahid/Documents/handwrittens_train/output" diff --git a/train.py b/train.py index c256d83..0cc5ef3 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ from models import * from utils import * from metrics import * from keras.models import load_model +from tqdm import tqdm def configuration(): keras.backend.clear_session() @@ -61,19 +62,24 @@ def config_params(): blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation. scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation. flip_index=[0,1,-1] # Flip image. Used for augmentation. - + continue_training = False # If + index_start = 0 + dir_of_start_model = '' + is_loss_soft_dice = False + weighted_loss = False + data_is_provided = False @ex.automain def run(n_classes,n_epochs,input_height, - input_width,weight_decay, + input_width,weight_decay,weighted_loss, + index_start,dir_of_start_model,is_loss_soft_dice, n_batch,patches,augmentation,flip_aug ,blur_aug,scaling, binarization, - blur_k,scales,dir_train, + blur_k,scales,dir_train,data_is_provided, scaling_bluring,scaling_binarization,rotation, - rotation_not_90,thetha,scaling_flip, + rotation_not_90,thetha,scaling_flip,continue_training, flip_index,dir_eval ,dir_output,pretraining,learning_rate): - data_is_provided = False if data_is_provided: dir_train_flowing=os.path.join(dir_output,'train') @@ -143,12 +149,43 @@ def run(n_classes,n_epochs,input_height, augmentation=False,patches=patches) - continue_train = False + + if weighted_loss: + weights=np.zeros(n_classes) + if data_is_provided: + for obj in os.listdir(dir_flow_train_labels): + try: + label_obj=cv2.imread(dir_flow_train_labels+'/'+obj) + label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) + weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + except: + pass + else: + + for obj in os.listdir(dir_seg): + try: + label_obj=cv2.imread(dir_seg+'/'+obj) + label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes) + weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0) + except: + pass + - if continue_train: - model_dir_start = '/home/vahid/Documents/struktur_full_data/output_multi/model_0.h5' - model = load_model (model_dir_start, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss}) - index_start = 1 + weights=1.00/weights + + weights=weights/float(np.sum(weights)) + weights=weights/float(np.min(weights)) + weights=weights/float(np.sum(weights)) + + + + if continue_training: + if is_loss_soft_dice: + model = load_model (dir_of_start_model, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss}) + if weighted_loss: + model = load_model (dir_of_start_model, compile = True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) + if not is_loss_soft_dice and not weighted_loss: + model = load_model (dir_of_start_model, compile = True) else: #get our model. index_start = 0 @@ -158,12 +195,16 @@ def run(n_classes,n_epochs,input_height, #model.summary() + if not is_loss_soft_dice and not weighted_loss: + model.compile(loss='categorical_crossentropy', + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + if is_loss_soft_dice: + model.compile(loss=soft_dice_loss, + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - #model.compile(loss='categorical_crossentropy', - #optimizer = Adam(lr=learning_rate),metrics=['accuracy']) - - model.compile(loss=soft_dice_loss, - optimizer = Adam(lr=learning_rate),metrics=['accuracy']) + if weighted_loss: + model.compile(loss=weighted_categorical_crossentropy(weights), + optimizer = Adam(lr=learning_rate),metrics=['accuracy']) #generating train and evaluation data train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch, @@ -171,7 +212,7 @@ def run(n_classes,n_epochs,input_height, val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch, input_height=input_height, input_width=input_width,n_classes=n_classes ) - for i in range(index_start, n_epochs+index_start): + for i in tqdm(range(index_start, n_epochs+index_start)): model.fit_generator( train_gen, steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1, @@ -181,8 +222,8 @@ def run(n_classes,n_epochs,input_height, model.save(dir_output+'/'+'model_'+str(i)+'.h5') - os.system('rm -rf '+dir_train_flowing) - os.system('rm -rf '+dir_eval_flowing) + #os.system('rm -rf '+dir_train_flowing) + #os.system('rm -rf '+dir_eval_flowing) #model.save(dir_output+'/'+'model'+'.h5') diff --git a/utils.py b/utils.py index a77444e..19ab46e 100644 --- a/utils.py +++ b/utils.py @@ -374,9 +374,7 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs, indexer=0 for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)): - #print(im, seg_i) img_name=im.split('.')[0] - print(img_name,'img_name') if not patches: cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) ) cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) )