From 038d776c2cbd6662c51d5f0cb68705134ef1bc56 Mon Sep 17 00:00:00 2001 From: b-vr103 Date: Thu, 5 Dec 2019 14:05:07 +0100 Subject: [PATCH] add files needed for training --- README | 23 ++++ config_params.json | 24 ++++ train.py | 192 ++++++++++++++++++++++++++ utils.py | 336 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 575 insertions(+) create mode 100644 README create mode 100644 config_params.json create mode 100644 train.py create mode 100644 utils.py diff --git a/README b/README new file mode 100644 index 0000000..7d8d790 --- /dev/null +++ b/README @@ -0,0 +1,23 @@ +how to train: + just run: python train.py with config_params.json + + +format of ground truth: + + The label of each pixel is identified by a number. So if you have a binary case, n_classes should be set to 2 and the label of every pixel should be either 0 or 1. + In the multiclass case just set n_classes to the number of classes you have and label the pixels with the values 0, 1, 2, ..., n_classes-1. + The labels format should be png. + + If you have an image label for the binary case it should look like this: + + Label: [ [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]], [[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ,[[1 0 0 1], [1 0 0 1] ,[1 0 0 1]] ] this means that you have an image of size 3*4*3 (height 3, width 4, and 3 channels that all hold the same values) and pixel[0,0] belongs to class 1 and pixel[0,1] to class 0. + +training, evaluation and output: + The train and evaluation folders should each have the subfolders images and labels. + The output folder should be an empty folder into which the output model will be written. + +patches: + + if you want to train your model with patches, the height and width of the patches should be defined and also the number of batches (how many patches should be seen by the model in each iteration). + In the case that the model should see the whole image at once, like page extraction, patches should be set to false. 
# ---------------------------------------------------------------------------
# config_params.json -- sample configuration added together with train.py
# (used as: python train.py with config_params.json):
#
# {
#     "n_classes" : 2,
#     "n_epochs" : 2,
#     "input_height" : 448,
#     "input_width" : 896,
#     "weight_decay" : 1e-6,
#     "n_batch" : 1,
#     "learning_rate": 1e-4,
#     "patches" : true,
#     "pretraining" : true,
#     "augmentation" : false,
#     "flip_aug" : false,
#     "elastic_aug" : false,
#     "blur_aug" : false,
#     "scaling" : false,
#     "binarization" : false,
#     "scaling_bluring" : false,
#     "scaling_binarization" : false,
#     "rotation": false,
#     "weighted_loss": true,
#     "dir_train": "/home/vahid/textline_gt_images/train_light",
#     "dir_eval": "/home/vahid/textline_gt_images/eval",
#     "dir_output": "/home/vahid/textline_gt_images/output"
# }
# ---------------------------------------------------------------------------
# train.py
import os
import sys
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
import keras
import warnings
from keras.optimizers import *
from sacred import Experiment
from models import *
from utils import *
from metrics import *


def configuration():
    """Reset Keras/TensorFlow state and install a GPU session.

    Clears the Keras session and the default TensorFlow graph, silences
    Python warnings, and registers a ``tf.Session`` as the Keras session.
    The session enables soft placement and GPU memory growth, caps the
    per-process GPU memory fraction at 0.95, and exposes only device "0"
    (with CUDA devices ordered by PCI bus id).
    """
    keras.backend.clear_session()
    tf.reset_default_graph()
    warnings.filterwarnings('ignore')

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)

    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.95
    config.gpu_options.visible_device_list = "0"
    set_session(tf.Session(config=config))


def get_dirs_or_files(input_data):
    """Return the ``images/`` and ``labels/`` sub-directories of a dataset.

    Args:
        input_data: Path of a dataset directory that must contain the
            sub-folders ``images`` and ``labels``.

    Returns:
        Tuple ``(image_input, labels_input)`` with both paths ending in a
        slash.

    Raises:
        ValueError: If ``input_data`` or one of the two sub-folders is not
            an existing directory.  (Previously a non-directory silently
            returned ``None`` and the sub-folder checks were ``assert``
            statements, which disappear under ``python -O``.)
    """
    if not os.path.isdir(input_data):
        raise ValueError("{} is not a directory".format(input_data))

    image_input = os.path.join(input_data, 'images/')
    labels_input = os.path.join(input_data, 'labels/')
    for sub_dir in (image_input, labels_input):
        if not os.path.isdir(sub_dir):
            raise ValueError("{} is not a directory".format(sub_dir))
    return image_input, labels_input


ex = Experiment()


# Sacred configuration scope: the values assigned below are the defaults of
# the experiment and can be overridden, e.g. with
# `python train.py with config_params.json`.  Only plain assignments belong in
# this function (no helper variables, no docstring) so that nothing but real
# configuration entries is defined here.
@ex.config
def config_params():
    n_classes=None # Number of classes. If your case study is binary case the set it to 2 and otherwise give your number of cases.
    n_epochs=1
    input_height=224*1
    input_width=224*1
    weight_decay=1e-6 # Weight decay of l2 regularization of model layers.
    n_batch=1 # Number of batches at each iteration.
    learning_rate=1e-4
    patches=False # Make patches of image in order to use all information of image. In the case of page
    # extraction this should be set to false since model should see all image.
    augmentation=False
    flip_aug=False # Flip image (augmentation).
    elastic_aug=False # Elastic transformation (augmentation).
    blur_aug=False # Blur patches of image (augmentation).
    scaling=False # Scaling of patches (augmentation) will be imposed if this set to true.
    binarization=False # Otsu thresholding. Used for augmentation in the case of binary case like textline prediction. For multicases should not be applied.
    dir_train=None # Directory of training dataset (sub-folders should be named images and labels).
    dir_eval=None # Directory of validation dataset (sub-folders should be named images and labels).
    dir_output=None # Directory of output where the model should be saved.
    pretraining=False # Set true to load pretrained weights of resnet50 encoder.
    weighted_loss=False # Set True if classes are unbalanced and you want to use weighted loss function.
    scaling_bluring=False
    # Was written as `rotation: False`, which is only an annotation and never
    # assigns the value, so `rotation` had no default at all.
    rotation=False
    scaling_binarization=False
    blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation.
    scales=[0.9 , 1.1 ] # Scale patches with these scales. Used for augmentation.
    flip_index=[0,1] # Flip image. Used for augmentation.
# ---------------------------------------------------------------------------
# train.py (continued)
# ---------------------------------------------------------------------------
@ex.automain
def run(n_classes,n_epochs,input_height,
        input_width,weight_decay,weighted_loss,
        n_batch,patches,augmentation,flip_aug,blur_aug,scaling, binarization,
        blur_k,scales,dir_train,
        scaling_bluring,scaling_binarization,rotation,
        flip_index,dir_eval ,dir_output,pretraining,learning_rate):
    """Prepare the data, train the ResNet50-U-Net and save it.

    Steps:
      1. (Re)create ``<dir_output>/train`` and ``<dir_output>/eval``, each with
         the sub-folders ``images`` and ``labels``.
      2. Write the (optionally augmented / patched) training and evaluation
         samples into those folders with ``provide_patches``.  Augmentation is
         always switched off for the evaluation set.
      3. If ``weighted_loss`` is set, derive inverse-frequency class weights
         from the training labels.
      4. Build, compile and fit the model, then delete the temporary folders
         and save the model as ``<dir_output>/model.h5``.

    All parameters are injected by sacred from the experiment configuration
    (see ``config_params``).

    Raises:
        ValueError: If ``weighted_loss`` is set and a label image cannot be
            read, or a class never occurs in the training labels.
    """
    # Function-scope import keeps this fix independent of the module header.
    import shutil

    dir_img, dir_seg = get_dirs_or_files(dir_train)
    dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)

    # Temporary folders in the output directory from which the data is flowed.
    dir_train_flowing = os.path.join(dir_output, 'train')
    dir_eval_flowing = os.path.join(dir_output, 'eval')

    dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images')
    dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels')

    dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images')
    dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels')

    # shutil.rmtree instead of os.system('rm -rf ...'): no shell is involved,
    # so paths containing spaces or shell metacharacters are handled safely.
    for flow_dir in (dir_train_flowing, dir_eval_flowing):
        if os.path.isdir(flow_dir):
            shutil.rmtree(flow_dir)
        os.makedirs(flow_dir)

    for sub_dir in (dir_flow_train_imgs, dir_flow_train_labels,
                    dir_flow_eval_imgs, dir_flow_eval_labels):
        os.mkdir(sub_dir)

    # Set the GPU configuration.
    configuration()

    # Write the samples into the sub-folders in order to be flowed from there.
    provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
                    dir_flow_train_labels,
                    input_height, input_width, blur_k, blur_aug,
                    flip_aug, binarization, scaling, scales, flip_index,
                    scaling_bluring, scaling_binarization, rotation,
                    augmentation=augmentation, patches=patches)

    provide_patches(dir_img_val, dir_seg_val, dir_flow_eval_imgs,
                    dir_flow_eval_labels,
                    input_height, input_width, blur_k, blur_aug,
                    flip_aug, binarization, scaling, scales, flip_index,
                    scaling_bluring, scaling_binarization, rotation,
                    augmentation=False, patches=patches)

    if weighted_loss:
        # Pixel count per class over all training labels.
        weights = np.zeros(n_classes)
        for obj in os.listdir(dir_seg):
            label_path = os.path.join(dir_seg, obj)
            label_obj = cv2.imread(label_path)
            if label_obj is None:
                raise ValueError("could not read label image {}".format(label_path))
            label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0],
                                            label_obj.shape[1], n_classes)
            weights += label_obj_one_hot.sum(axis=0).sum(axis=0)

        if np.any(weights == 0):
            # 1/0 would give inf and turn every weight into NaN below.
            raise ValueError("weighted_loss requires every one of the {} classes "
                             "to occur in the training labels".format(n_classes))

        # Inverse frequency, normalised to sum to 1.  The original divided by
        # the minimum and re-normalised afterwards, which cancels out and
        # yields the same values.
        weights = 1.00 / weights
        weights = weights / float(np.sum(weights))
        loss = weighted_categorical_crossentropy(weights)
    else:
        loss = 'categorical_crossentropy'

    # Get our model.
    model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)

    # If you want to see the model structure just uncomment model summary.
    #model.summary()

    model.compile(loss=loss,
                  optimizer=Adam(lr=learning_rate), metrics=['accuracy'])

    # The checkpoint callback used to be created but never handed to
    # fit_generator, so no per-epoch weights were written.  It is now
    # registered and writes into the output directory.
    mc = keras.callbacks.ModelCheckpoint(
        os.path.join(dir_output, 'weights{epoch:08d}.h5'),
        save_weights_only=True, period=1)

    # Generating train and evaluation data.
    train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
                         input_height=input_height, input_width=input_width, n_classes=n_classes)
    val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
                       input_height=input_height, input_width=input_width, n_classes=n_classes)

    model.fit_generator(
        train_gen,
        steps_per_epoch=len(os.listdir(dir_flow_train_imgs)) // n_batch,
        validation_data=val_gen,
        # NOTE(review): only a single validation batch is evaluated per epoch.
        validation_steps=1,
        epochs=n_epochs,
        callbacks=[mc])

    # The flowed samples are only needed during training.
    shutil.rmtree(dir_train_flowing, ignore_errors=True)
    shutil.rmtree(dir_eval_flowing, ignore_errors=True)

    model.save(os.path.join(dir_output, 'model.h5'))


# ---------------------------------------------------------------------------
# utils.py
# ---------------------------------------------------------------------------
import os
import cv2
import numpy as np
import seaborn as sns
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
import random
from tqdm import tqdm


def bluring(img_in, kind):
    """Blur an image with a 5x5 kernel.

    Args:
        img_in: Image array as accepted by the OpenCV blur functions.
        kind: ``'guass'`` (spelling used by the configuration; ``'gauss'`` is
            accepted as well) for a Gaussian blur, ``'median'`` for a median
            blur or ``'blur'`` for a box blur.

    Returns:
        The blurred image.

    Raises:
        ValueError: If ``kind`` is not one of the supported names
            (previously this ended in an ``UnboundLocalError``).
    """
    if kind in ('guass', 'gauss'):
        return cv2.GaussianBlur(img_in, (5, 5), 0)
    if kind == "median":
        return cv2.medianBlur(img_in, 5)
    if kind == 'blur':
        return cv2.blur(img_in, (5, 5))
    raise ValueError("unknown blur kind {!r}; expected 'guass', 'median' or 'blur'".format(kind))


def color_images(seg, n_classes):
    """Turn a label map into a float RGB image for visualisation.

    Args:
        seg: 2-D label map, or a 3-D array of which only channel 0 is used.
        n_classes: Number of classes; class ``c`` gets colour ``c`` of the
            seaborn "hls" palette with ``n_classes`` entries.

    Returns:
        Float array of shape ``(height, width, 3)``.
    """
    if len(np.shape(seg)) == 3:
        seg = seg[:, :, 0]

    seg_img = np.zeros((np.shape(seg)[0], np.shape(seg)[1], 3)).astype(float)
    colors = sns.color_palette("hls", n_classes)

    for c in range(n_classes):
        class_mask = (seg == c)
        for channel in range(3):
            seg_img[:, :, channel] += class_mask * (colors[c][channel])
    return seg_img


def resize_image(seg_in, input_height, input_width):
    """Resize to ``(input_height, input_width)`` with nearest-neighbour
    interpolation, so that label values are never blended."""
    return cv2.resize(seg_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)


def get_one_hot(seg, input_height, input_width, n_classes):
    """One-hot encode a label image.

    Args:
        seg: 3-D label image; only channel 0 is read and it must have the
            shape ``(input_height, input_width)``.
        input_height: Height of the label image.
        input_width: Width of the label image.
        n_classes: Number of classes, i.e. depth of the result.

    Returns:
        Float array of shape ``(input_height, input_width, n_classes)`` whose
        slice ``j`` is 1 where the label equals ``j`` and 0 elsewhere.
    """
    seg = seg[:, :, 0]
    seg_f = np.zeros((input_height, input_width, n_classes))
    for j in range(n_classes):
        seg_f[:, :, j] = (seg == j).astype(int)
    return seg_f


def IoU(Yi, y_predi):
    """Print the per-class IoU and return the mean IoU.

    IoU = TP / (TP + FP + FN), computed for every class that occurs in the
    ground truth ``Yi``.

    Args:
        Yi: Array of ground-truth class ids.
        y_predi: Array of predicted class ids, same shape as ``Yi``.

    Returns:
        Mean of the per-class IoU values.
    """
    class_ious = []
    for c in np.unique(Yi):
        TP = np.sum((Yi == c) & (y_predi == c))
        FP = np.sum((Yi != c) & (y_predi == c))
        FN = np.sum((Yi == c) & (y_predi != c))
        # Local renamed: it used to shadow the function name `IoU`.
        class_iou = TP / float(TP + FP + FN)
        print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,class_iou))
        class_ious.append(class_iou)
    mIoU = np.mean(class_ious)
    print("_________________")
    print("Mean IoU: {:4.3f}".format(mIoU))
    return mIoU


def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_classes):
    """Endlessly yield ``(images, one_hot_masks)`` batches.

    The file list of ``img_folder`` is read once and shuffled; it is
    reshuffled every time the list has been consumed.  The mask of an image
    is expected in ``mask_folder`` under the same base name with the
    extension ``.png``.

    Args:
        img_folder: Folder with the input images.
        mask_folder: Folder with the label images.
        batch_size: Number of samples per batch.
        input_height: Height every sample is resized to.
        input_width: Width every sample is resized to.
        n_classes: Number of classes for the one-hot encoding.

    Yields:
        Tuple ``(img, mask)`` of float arrays with the shapes
        ``(batch_size, input_height, input_width, 3)`` and
        ``(batch_size, input_height, input_width, n_classes)``; image values
        are scaled by 1/255.

    Raises:
        ValueError: On the first ``next()`` if the folder holds fewer images
            than one batch (previously an ``IndexError``).
    """
    names = os.listdir(img_folder)  # list of training images
    if len(names) < batch_size:
        raise ValueError("{} holds {} images, fewer than batch_size={}".format(
            img_folder, len(names), batch_size))
    random.shuffle(names)

    c = 0
    while True:
        img = np.zeros((batch_size, input_height, input_width, 3)).astype('float')
        mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float')

        for i in range(c, c + batch_size):
            # splitext keeps dots inside the base name ("a.b.jpg" -> "a.b").
            filename = os.path.splitext(names[i])[0]
            train_img = cv2.imread(os.path.join(img_folder, names[i])) / 255.
            train_img = cv2.resize(train_img, (input_width, input_height),
                                   interpolation=cv2.INTER_NEAREST)
            img[i - c] = train_img

            train_mask = cv2.imread(os.path.join(mask_folder, filename + '.png'))
            train_mask = get_one_hot(resize_image(train_mask, input_height, input_width),
                                     input_height, input_width, n_classes)
            mask[i - c] = train_mask

        c += batch_size
        # Restart only when the NEXT batch would not fit any more.  The
        # original `>=` also restarted when it fitted exactly and thereby
        # dropped the last full batch of every pass.
        if c + batch_size > len(names):
            c = 0
            random.shuffle(names)
        yield img, mask


def otsu_copy(img):
    """Binarise an image with Otsu's threshold.

    The threshold is computed on channel 0 only and the resulting binary
    mask (values 0 / 255) is copied into all three channels.  (The original
    also thresholded channels 1 and 2 but discarded those results.)

    Args:
        img: Image with at least three channels; channel 0 must be a type
            accepted by ``cv2.threshold`` with ``THRESH_OTSU``.

    Returns:
        Float array with the shape of ``img``.
    """
    img_r = np.zeros(img.shape)
    _, threshold1 = cv2.threshold(img[:, :, 0], 0, 255,
                                  cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    img_r[:, :, 0] = threshold1
    img_r[:, :, 1] = threshold1
    img_r[:, :, 2] = threshold1
    return img_r


def rotation_90(img):
    """Swap the height and width axes of the first three channels.

    NOTE: this is a transpose of every channel (a 90 degree rotation combined
    with a flip), not a pure rotation.

    Args:
        img: Array of shape ``(height, width, channels)``.

    Returns:
        Float array of shape ``(width, height, channels)``; channels beyond
        the third stay zero.
    """
    img_rot = np.zeros((img.shape[1], img.shape[0], img.shape[2]))
    for channel in range(3):
        img_rot[:, :, channel] = img[:, :, channel].T
    return img_rot


def _write_pair(dir_img_f, dir_seg_f, img, label, indexer):
    """Write one image/label pair as ``img_<indexer>.png``; return the next index."""
    cv2.imwrite(dir_img_f + '/img_' + str(indexer) + '.png', img)
    cv2.imwrite(dir_seg_f + '/img_' + str(indexer) + '.png', label)
    return indexer + 1


def _write_patches(dir_img_f, dir_seg_f, img, label, tile_h, tile_w, indexer,
                   resize_to=None):
    """Cut ``img``/``label`` into tiles of ``tile_h`` x ``tile_w`` and write them.

    The tiles cover the whole image; a tile that would cross the right or
    bottom border is shifted back so that it ends at the border.  If
    ``resize_to`` is a ``(height, width)`` tuple every tile is resized to it
    before writing.  Returns the next free index.
    """
    img_h = img.shape[0]
    img_w = img.shape[1]

    # Ceiling division: number of tiles needed per axis.
    n_cols = (img_w + tile_w - 1) // tile_w
    n_rows = (img_h + tile_h - 1) // tile_h

    for i in range(n_cols):
        for j in range(n_rows):
            x_start, x_end = i * tile_w, (i + 1) * tile_w
            y_start, y_end = j * tile_h, (j + 1) * tile_h

            # Clamp at 0: with an image smaller than the tile the start used
            # to become negative, which numpy interprets as "from the end"
            # and which cut away part of the image.
            if x_end > img_w:
                x_end = img_w
                x_start = max(img_w - tile_w, 0)
            if y_end > img_h:
                y_end = img_h
                y_start = max(img_h - tile_h, 0)

            img_patch = img[y_start:y_end, x_start:x_end, :]
            label_patch = label[y_start:y_end, x_start:x_end, :]

            if resize_to is not None:
                img_patch = resize_image(img_patch, resize_to[0], resize_to[1])
                label_patch = resize_image(label_patch, resize_to[0], resize_to[1])

            indexer = _write_pair(dir_img_f, dir_seg_f, img_patch, label_patch, indexer)
    return indexer


def get_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer):
    """Write ``height`` x ``width`` patches of an image and its label.

    Args:
        dir_img_f: Output folder of the image patches.
        dir_seg_f: Output folder of the label patches.
        img: Image array ``(h, w, channels)``.
        label: Label array with the same height and width as ``img``.
        height: Patch height.
        width: Patch width.
        indexer: Index used in the file name of the first patch.

    Returns:
        The next free index.
    """
    return _write_patches(dir_img_f, dir_seg_f, img, label, height, width, indexer)


def get_patches_num_scale(dir_img_f, dir_seg_f, img, label, height, width, indexer, scaler):
    """Write scaled patches of an image and its label.

    Patches of ``int(height*scaler)`` x ``int(width*scaler)`` are cut out and
    resized (nearest neighbour) to ``height`` x ``width`` before writing.
    The remaining arguments and the return value are those of
    ``get_patches``.
    """
    return _write_patches(dir_img_f, dir_seg_f, img, label,
                          int(height * scaler), int(width * scaler), indexer,
                          resize_to=(height, width))


def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
                    dir_flow_train_labels,
                    input_height,input_width,blur_k,blur_aug,
                    flip_aug,binarization,scaling,scales,flip_index,
                    scaling_bluring,scaling_binarization,rotation,
                    augmentation=False,patches=False):
    """Write the training samples (and their augmentations) to disk.

    For every file in ``dir_img`` the label ``<base name>.png`` is read from
    ``dir_seg``.  Samples are written as ``img_<n>.png`` with a running index
    into ``dir_flow_train_imgs`` / ``dir_flow_train_labels``.

    * ``patches`` false: the whole image is resized to
      ``input_height`` x ``input_width``.  With ``augmentation`` the
      ``rotation``, ``flip_aug``, ``blur_aug`` and ``binarization`` variants
      are added; the scaling options are ignored in this mode.
    * ``patches`` true: the image is cut into ``input_height`` x
      ``input_width`` patches.  With ``augmentation`` the variants
      ``rotation``, ``flip_aug``, ``blur_aug``, ``scaling``, ``binarization``,
      ``scaling_bluring`` and ``scaling_binarization`` are added.

    Changes compared to the original: images are iterated in sorted order
    instead of being zipped with an unordered listing of the label folder
    (which silently truncated the data set); every image and label is read
    once instead of once per augmentation; unreadable files raise a clear
    error.

    Args:
        dir_img: Folder with the source images.
        dir_seg: Folder with the source labels.
        dir_flow_train_imgs: Output folder for images.
        dir_flow_train_labels: Output folder for labels.
        input_height: Target / patch height.
        input_width: Target / patch width.
        blur_k: Blur kinds passed to ``bluring``.
        blur_aug, flip_aug, binarization, scaling, scaling_bluring,
        scaling_binarization, rotation: Switches of the single augmentations.
        scales: Scale factors for the scaling augmentations.
        flip_index: Flip codes passed to ``cv2.flip``.
        augmentation: Master switch for all augmentations.
        patches: Cut the images into patches instead of resizing them.

    Raises:
        ValueError: If an image or its label cannot be read.
    """
    out = (dir_flow_train_imgs, dir_flow_train_labels)
    size = (input_height, input_width)

    indexer = 0
    for im in tqdm(sorted(os.listdir(dir_img))):
        img_name = os.path.splitext(im)[0]
        img_path = os.path.join(dir_img, im)
        label_path = os.path.join(dir_seg, img_name + '.png')

        img = cv2.imread(img_path)
        label = cv2.imread(label_path)
        if img is None:
            raise ValueError("could not read image {}".format(img_path))
        if label is None:
            raise ValueError("could not read label {}".format(label_path))

        if not patches:
            label_resized = resize_image(label, *size)
            indexer = _write_pair(out[0], out[1],
                                  resize_image(img, *size), label_resized, indexer)

            if augmentation:
                if rotation:
                    # Resize first, then transpose (as in the original).
                    indexer = _write_pair(out[0], out[1],
                                          rotation_90(resize_image(img, *size)),
                                          rotation_90(label_resized), indexer)

                if flip_aug:
                    for f_i in flip_index:
                        indexer = _write_pair(out[0], out[1],
                                              resize_image(cv2.flip(img, f_i), *size),
                                              resize_image(cv2.flip(label, f_i), *size),
                                              indexer)

                if blur_aug:
                    for blur_i in blur_k:
                        indexer = _write_pair(out[0], out[1],
                                              resize_image(bluring(img, blur_i), *size),
                                              label_resized, indexer)

                if binarization:
                    indexer = _write_pair(out[0], out[1],
                                          resize_image(otsu_copy(img), *size),
                                          label_resized, indexer)
        else:
            indexer = get_patches(out[0], out[1], img, label,
                                  input_height, input_width, indexer=indexer)

            if augmentation:
                if rotation:
                    indexer = get_patches(out[0], out[1],
                                          rotation_90(img), rotation_90(label),
                                          input_height, input_width, indexer=indexer)

                if flip_aug:
                    for f_i in flip_index:
                        indexer = get_patches(out[0], out[1],
                                              cv2.flip(img, f_i), cv2.flip(label, f_i),
                                              input_height, input_width, indexer=indexer)

                if blur_aug:
                    for blur_i in blur_k:
                        indexer = get_patches(out[0], out[1],
                                              bluring(img, blur_i), label,
                                              input_height, input_width, indexer=indexer)

                if scaling:
                    for sc_ind in scales:
                        indexer = get_patches_num_scale(out[0], out[1], img, label,
                                                        input_height, input_width,
                                                        indexer=indexer, scaler=sc_ind)

                if binarization:
                    indexer = get_patches(out[0], out[1],
                                          otsu_copy(img), label,
                                          input_height, input_width, indexer=indexer)

                if scaling_bluring:
                    for sc_ind in scales:
                        for blur_i in blur_k:
                            indexer = get_patches_num_scale(out[0], out[1],
                                                            bluring(img, blur_i), label,
                                                            input_height, input_width,
                                                            indexer=indexer, scaler=sc_ind)

                if scaling_binarization:
                    for sc_ind in scales:
                        indexer = get_patches_num_scale(out[0], out[1],
                                                        otsu_copy(img), label,
                                                        input_height, input_width,
                                                        indexer=indexer, scaler=sc_ind)