From 4bea9fd5354dfe1f78f9b84d20a97714b54ab37d Mon Sep 17 00:00:00 2001
From: vahid
Date: Tue, 22 Jun 2021 18:47:59 -0400
Subject: [PATCH] continue training, losses, etc.

---
 config_params.json | 14 ++++++---
 train.py           | 77 +++++++++++++++++++++++++++++++++++-----------
 utils.py           |  2 --
 3 files changed, 69 insertions(+), 24 deletions(-)

diff --git a/config_params.json b/config_params.json
index d8f1ac5..eaa50e1 100644
--- a/config_params.json
+++ b/config_params.json
@@ -1,6 +1,6 @@
 {
     "n_classes" : 3,
-    "n_epochs" : 1,
+    "n_epochs" : 2,
     "input_height" : 448,
     "input_width" : 672,
     "weight_decay" : 1e-6,
@@ -8,16 +8,22 @@
     "learning_rate": 1e-4,
     "patches" : true,
     "pretraining" : true,
-    "augmentation" : true,
+    "augmentation" : false,
     "flip_aug" : false,
-    "blur_aug" : true,
-    "scaling" : false,
+    "blur_aug" : false,
+    "scaling" : true,
     "binarization" : false,
     "scaling_bluring" : false,
     "scaling_binarization" : false,
     "scaling_flip" : false,
     "rotation": false,
     "rotation_not_90": false,
+    "continue_training": false,
+    "index_start": 0,
+    "dir_of_start_model": " ",
+    "weighted_loss": false,
+    "is_loss_soft_dice": false,
+    "data_is_provided": false,
     "dir_train": "/home/vahid/Documents/handwrittens_train/train",
     "dir_eval": "/home/vahid/Documents/handwrittens_train/eval",
     "dir_output": "/home/vahid/Documents/handwrittens_train/output"
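The six new keys mirror, one for one, the defaults added to config_params() in train.py below: continue_training, index_start and dir_of_start_model control resuming from a saved checkpoint; weighted_loss and is_loss_soft_dice pick the loss function; data_is_provided skips patch generation when the flowed train/eval folders already exist in dir_output. The JSON is typically merged over those defaults by sacred (python train.py with config_params.json). As a minimal sketch of the consistency rules the flags imply (check_config is a hypothetical helper, not part of this patch):

    import json

    def check_config(path="config_params.json"):
        # Hypothetical pre-flight check; train.py itself does not validate.
        with open(path) as f:
            cfg = json.load(f)
        if cfg["continue_training"]:
            # Resuming needs a checkpoint and a non-negative epoch offset.
            assert cfg["dir_of_start_model"].strip(), "set dir_of_start_model"
            assert cfg["index_start"] >= 0, "index_start must be >= 0"
        if cfg["is_loss_soft_dice"] and cfg["weighted_loss"]:
            # train.py compiles once per enabled flag, so the weighted loss
            # would silently win; treat the two flags as mutually exclusive.
            raise ValueError("enable only one of is_loss_soft_dice / weighted_loss")
        return cfg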
diff --git a/train.py b/train.py
index c256d83..0cc5ef3 100644
--- a/train.py
+++ b/train.py
@@ -9,6 +9,7 @@ from models import *
 from utils import *
 from metrics import *
 from keras.models import load_model
+from tqdm import tqdm
 
 def configuration():
     keras.backend.clear_session()
@@ -61,19 +62,24 @@ def config_params():
     blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation.
     scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation.
     flip_index=[0,1,-1] # Flip image. Used for augmentation.
-
+    continue_training = False # If true, training resumes from the model given in dir_of_start_model.
+    index_start = 0
+    dir_of_start_model = ''
+    is_loss_soft_dice = False
+    weighted_loss = False
+    data_is_provided = False
 
 @ex.automain
 def run(n_classes,n_epochs,input_height,
-        input_width,weight_decay,
+        input_width,weight_decay,weighted_loss,
+        index_start,dir_of_start_model,is_loss_soft_dice,
         n_batch,patches,augmentation,flip_aug
         ,blur_aug,scaling, binarization,
-        blur_k,scales,dir_train,
+        blur_k,scales,dir_train,data_is_provided,
         scaling_bluring,scaling_binarization,rotation,
-        rotation_not_90,thetha,scaling_flip,
+        rotation_not_90,thetha,scaling_flip,continue_training,
         flip_index,dir_eval ,dir_output,pretraining,learning_rate):
-
-    data_is_provided = False
 
     if data_is_provided:
         dir_train_flowing=os.path.join(dir_output,'train')
@@ -143,12 +149,43 @@ def run(n_classes,n_epochs,input_height,
                             augmentation=False,patches=patches)
 
-    continue_train = False
+
+    if weighted_loss:
+        weights=np.zeros(n_classes)
+        if data_is_provided:
+            for obj in os.listdir(dir_flow_train_labels):
+                try:
+                    label_obj=cv2.imread(dir_flow_train_labels+'/'+obj)
+                    label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
+                    weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
+                except:
+                    pass
+        else:
+
+            for obj in os.listdir(dir_seg):
+                try:
+                    label_obj=cv2.imread(dir_seg+'/'+obj)
+                    label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
+                    weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
+                except:
+                    pass
+
-    if continue_train:
-        model_dir_start = '/home/vahid/Documents/struktur_full_data/output_multi/model_0.h5'
-        model = load_model (model_dir_start, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss})
-        index_start = 1
+        weights=1.00/weights
+
+        weights=weights/float(np.sum(weights))
+        weights=weights/float(np.min(weights))
+        weights=weights/float(np.sum(weights))
+
+
+
+    if continue_training:
+        if is_loss_soft_dice:
+            model = load_model (dir_of_start_model, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss})
+        if weighted_loss:
+            model = load_model (dir_of_start_model, compile = True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
+        if not is_loss_soft_dice and not weighted_loss:
+            model = load_model (dir_of_start_model, compile = True)
     else:
         #get our model.
         index_start = 0
@@ -158,12 +195,16 @@ def run(n_classes,n_epochs,input_height,
 
     #model.summary()
 
+    if not is_loss_soft_dice and not weighted_loss:
+        model.compile(loss='categorical_crossentropy',
+                      optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
+    if is_loss_soft_dice:
+        model.compile(loss=soft_dice_loss,
+                      optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
 
-    #model.compile(loss='categorical_crossentropy',
-    #optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
-
-    model.compile(loss=soft_dice_loss,
-                  optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
+    if weighted_loss:
+        model.compile(loss=weighted_categorical_crossentropy(weights),
+                      optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
 
     #generating train and evaluation data
     train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch,
@@ -171,7 +212,7 @@ def run(n_classes,n_epochs,input_height,
     val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels,  batch_size = n_batch,
                        input_height=input_height, input_width=input_width,n_classes=n_classes  )
 
-    for i in range(index_start, n_epochs+index_start):
+    for i in tqdm(range(index_start, n_epochs+index_start)):
         model.fit_generator(
             train_gen,
             steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1,
@@ -181,8 +222,8 @@ def run(n_classes,n_epochs,input_height,
 
         model.save(dir_output+'/'+'model_'+str(i)+'.h5')
 
-    os.system('rm -rf '+dir_train_flowing)
-    os.system('rm -rf '+dir_eval_flowing)
+    #os.system('rm -rf '+dir_train_flowing)
+    #os.system('rm -rf '+dir_eval_flowing)
 
     #model.save(dir_output+'/'+'model'+'.h5')
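Two remarks on the hunks above. First, the class weights start as inverse pixel frequencies (weights=1.00/weights), and the three follow-up lines each rescale the vector uniformly, so they collapse to a single division by the sum: the final weights are (1/freq)/sum(1/freq). Second, weighted_categorical_crossentropy comes in through `from metrics import *`, so its body is not visible in this patch; the sketch below is a common Keras formulation of such a loss factory, consistent with the custom_objects={'loss': ...} call above but not necessarily identical to the repository's metrics.py:

    from keras import backend as K

    def weighted_categorical_crossentropy(weights):
        # weights: 1-D array of length n_classes, as computed above.
        weights = K.variable(weights)

        def loss(y_true, y_pred):
            # Rescale so the class scores of every pixel sum to 1, then
            # clip away exact 0/1 values to keep log() finite.
            y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
            y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
            # Cross-entropy with a per-class weight on each term.
            return -K.sum(y_true * K.log(y_pred) * weights, axis=-1)

        return loss

Returning a closure literally named loss is what makes the custom_objects={'loss': ...} mapping line up when a weighted-loss checkpoint is reloaded under continue_training.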
diff --git a/utils.py b/utils.py
index a77444e..19ab46e 100644
--- a/utils.py
+++ b/utils.py
@@ -374,9 +374,7 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
     indexer=0
     for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)):
-        #print(im, seg_i)
         img_name=im.split('.')[0]
-        print(img_name,'img_name')
         if not patches:
             cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) )
             cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) )
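The utils.py hunk only drops two leftover debug prints from provide_patches. Combined with the train.py changes, the intended resume workflow is: a first run writes model_0.h5 ... model_{n_epochs-1}.h5 into dir_output; a follow-up run sets continue_training to true, points dir_of_start_model at the newest checkpoint, and sets index_start to the next free index, so the loop for i in tqdm(range(index_start, n_epochs+index_start)) continues the numbering instead of overwriting earlier epochs. A small sketch that derives those settings from an existing output directory (resume_params is a hypothetical helper, not part of the patch):

    import os
    import re

    def resume_params(dir_output):
        # Collect epoch indices from checkpoints named model_<i>.h5,
        # the naming used by the save call in train.py's epoch loop.
        indices = []
        for f in os.listdir(dir_output):
            m = re.fullmatch(r"model_(\d+)\.h5", f)
            if m:
                indices.append(int(m.group(1)))
        if not indices:
            return {"continue_training": False, "index_start": 0,
                    "dir_of_start_model": " "}
        last = max(indices)
        return {"continue_training": True,
                "index_start": last + 1,
                "dir_of_start_model": os.path.join(dir_output, "model_%d.h5" % last)}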