Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git (synced 2025-10-31 01:14:19 +01:00)
continue training, losses, etc.
This commit is contained in:

parent 5fb7552dbe
commit 4bea9fd535

3 changed files with 69 additions and 24 deletions
config_params.json (14 changed lines)

@@ -1,6 +1,6 @@
 {
     "n_classes" : 3,
-    "n_epochs" : 1,
+    "n_epochs" : 2,
     "input_height" : 448,
     "input_width" : 672,
     "weight_decay" : 1e-6,
@@ -8,16 +8,22 @@
     "learning_rate": 1e-4,
     "patches" : true,
     "pretraining" : true,
-    "augmentation" : true,
+    "augmentation" : false,
     "flip_aug" : false,
-    "blur_aug" : true,
-    "scaling" : false,
+    "blur_aug" : false,
+    "scaling" : true,
     "binarization" : false,
     "scaling_bluring" : false,
     "scaling_binarization" : false,
     "scaling_flip" : false,
     "rotation": false,
     "rotation_not_90": false,
+    "continue_training": false,
+    "index_start": 0,
+    "dir_of_start_model": " ",
+    "weighted_loss": false,
+    "is_loss_soft_dice": false,
+    "data_is_provided": false,
     "dir_train": "/home/vahid/Documents/handwrittens_train/train",
     "dir_eval": "/home/vahid/Documents/handwrittens_train/eval",
     "dir_output": "/home/vahid/Documents/handwrittens_train/output"
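
The six new keys mirror the defaults added to config_params() in train.py below. A minimal sketch (not part of this commit; check_config is a hypothetical helper) of how one might sanity-check them before launching a run, given that in train.py a weighted-loss compile runs last and would silently win if both loss flags were set:

import json
import os

def check_config(path='config_params.json'):
    with open(path) as f:
        cfg = json.load(f)
    if cfg.get('continue_training'):
        # A resumed run needs an existing start model, and index_start should
        # point past the epochs already trained so model_<i>.h5 names don't collide.
        assert os.path.isfile(cfg['dir_of_start_model'].strip()), 'dir_of_start_model not found'
        assert cfg['index_start'] > 0, 'index_start should be > 0 when resuming'
    if cfg.get('is_loss_soft_dice') and cfg.get('weighted_loss'):
        # train.py compiles the weighted loss after soft dice, so it would win silently.
        raise ValueError('enable only one of is_loss_soft_dice / weighted_loss')
    return cfg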

train.py (77 changed lines)
@@ -9,6 +9,7 @@ from models import *
 from utils import *
 from metrics import *
 from keras.models import load_model
+from tqdm import tqdm
 
 def configuration():
     keras.backend.clear_session()
@@ -61,19 +62,24 @@ def config_params():
     blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation.
     scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation.
     flip_index=[0,1,-1] # Flip image. Used for augmentation.
 
+    continue_training = False # If true, training resumes from the model in dir_of_start_model.
+    index_start = 0
+    dir_of_start_model = ''
+    is_loss_soft_dice = False
+    weighted_loss = False
+    data_is_provided = False
 
 @ex.automain
 def run(n_classes,n_epochs,input_height,
-        input_width,weight_decay,
+        input_width,weight_decay,weighted_loss,
+        index_start,dir_of_start_model,is_loss_soft_dice,
         n_batch,patches,augmentation,flip_aug
         ,blur_aug,scaling, binarization,
-        blur_k,scales,dir_train,
+        blur_k,scales,dir_train,data_is_provided,
         scaling_bluring,scaling_binarization,rotation,
-        rotation_not_90,thetha,scaling_flip,
+        rotation_not_90,thetha,scaling_flip,continue_training,
         flip_index,dir_eval ,dir_output,pretraining,learning_rate):
 
-    data_is_provided = False
-
     if data_is_provided:
         dir_train_flowing=os.path.join(dir_output,'train')
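
With these defaults registered on the sacred experiment, every new option can be overridden from the JSON config (sacred also accepts config-file updates on the command line, e.g. python train.py with config_params.json) or programmatically. A small sketch, assuming the Experiment object in train.py is named ex, as the @ex.automain decorator suggests:

# Hypothetical resume invocation via sacred's config_updates; importing train
# is safe because @ex.automain only fires when train.py runs as __main__.
from train import ex

ex.run(config_updates={
    'continue_training': True,
    'dir_of_start_model': '/path/to/model_1.h5',  # placeholder path
    'index_start': 2,  # next checkpoint will be saved as model_2.h5
})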
@@ -143,12 +149,43 @@ def run(n_classes,n_epochs,input_height,
                         augmentation=False,patches=patches)
     
 
-    continue_train = False
-    
-    if continue_train:
-        model_dir_start = '/home/vahid/Documents/struktur_full_data/output_multi/model_0.h5'
-        model = load_model (model_dir_start, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss})
-        index_start = 1
+    if weighted_loss:
+        weights=np.zeros(n_classes)
+        if data_is_provided:
+            for obj in os.listdir(dir_flow_train_labels):
+                try:
+                    label_obj=cv2.imread(dir_flow_train_labels+'/'+obj)
+                    label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
+                    weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
+                except:
+                    pass
+        else:
+            for obj in os.listdir(dir_seg):
+                try:
+                    label_obj=cv2.imread(dir_seg+'/'+obj)
+                    label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
+                    weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
+                except:
+                    pass
+
+        weights=1.00/weights
+        weights=weights/float(np.sum(weights))
+        weights=weights/float(np.min(weights))
+        weights=weights/float(np.sum(weights))
+
+    if continue_training:
+        if is_loss_soft_dice:
+            model = load_model (dir_of_start_model, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss})
+        if weighted_loss:
+            model = load_model (dir_of_start_model, compile = True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
+        if not is_loss_soft_dice and not weighted_loss:
+            model = load_model (dir_of_start_model, compile = True)
     else:
         #get our model.
         index_start = 0
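
The weighted-loss branch above derives class weights from pixel counts: one-hot label maps are summed per class, inverted so rare classes weigh more, then normalized. The same logic as a standalone function, for illustration (compute_class_weights is a hypothetical name; get_one_hot is assumed to be the repo's utils.py helper used in the diff), with an explicit readability check instead of the bare except:

import os
import cv2
import numpy as np
from utils import get_one_hot  # repo helper, as used in the diff above

def compute_class_weights(label_dir, n_classes):
    # Assumes every class occurs at least once, as the original code does.
    weights = np.zeros(n_classes)
    for name in os.listdir(label_dir):
        label = cv2.imread(os.path.join(label_dir, name))
        if label is None:  # skip unreadable or non-image files
            continue
        one_hot = get_one_hot(label, label.shape[0], label.shape[1], n_classes)
        weights += one_hot.sum(axis=0).sum(axis=0)  # per-class pixel counts
    weights = 1.0 / weights            # invert: rare classes get larger weights
    weights = weights / weights.sum()  # normalize
    weights = weights / weights.min()  # rescale so the smallest weight is 1
    return weights / weights.sum()     # final normalization, as in train.py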
@@ -158,12 +195,16 @@ def run(n_classes,n_epochs,input_height,
     #model.summary()
     
 
-    #model.compile(loss='categorical_crossentropy',
-                        #optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
-    model.compile(loss=soft_dice_loss,
-                        optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
+    if not is_loss_soft_dice and not weighted_loss:
+        model.compile(loss='categorical_crossentropy',
+                            optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
+    if is_loss_soft_dice:
+        model.compile(loss=soft_dice_loss,
+                            optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
+    if weighted_loss:
+        model.compile(loss=weighted_categorical_crossentropy(weights),
+                            optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
     
     #generating train and evaluation data
     train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size =  n_batch,
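
weighted_categorical_crossentropy itself is defined elsewhere in this repo (train.py star-imports utils and metrics). For readers who want the idea, here is a generic Keras-backend version of such a loss factory; an illustration only, not necessarily the repo's exact implementation:

import numpy as np
import keras.backend as K

def weighted_categorical_crossentropy_sketch(class_weights):
    w = K.variable(np.asarray(class_weights, dtype='float32'))
    def loss(y_true, y_pred):
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        # per-pixel cross-entropy, each true class scaled by its weight
        return -K.sum(y_true * K.log(y_pred) * w, axis=-1)
    return loss

Note that the inner closure is named loss, which matches the custom_objects={'loss': ...} key used when reloading a saved model above.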
@@ -171,7 +212,7 @@ def run(n_classes,n_epochs,input_height,
     val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size =  n_batch,
                          input_height=input_height, input_width=input_width,n_classes=n_classes  )
     
-    for i in range(index_start, n_epochs+index_start):
+    for i in tqdm(range(index_start, n_epochs+index_start)):
         model.fit_generator(
             train_gen,
             steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1,
@@ -181,8 +222,8 @@ def run(n_classes,n_epochs,input_height,
         model.save(dir_output+'/'+'model_'+str(i)+'.h5')
     
 
-    os.system('rm -rf '+dir_train_flowing)
-    os.system('rm -rf '+dir_eval_flowing)
+    #os.system('rm -rf '+dir_train_flowing)
+    #os.system('rm -rf '+dir_eval_flowing)
 
     #model.save(dir_output+'/'+'model'+'.h5')
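
The epoch loop now runs from index_start to n_epochs+index_start, so a resumed run trains n_epochs more epochs while continuing the checkpoint numbering instead of overwriting earlier files; the rm -rf cleanup of the flow directories is commented out, presumably so the prepared patches survive for a later resumed run. A toy illustration of the numbering (checkpoint_names is hypothetical):

def checkpoint_names(index_start, n_epochs):
    # mirrors train.py: one model_<i>.h5 per epoch, numbering starts at index_start
    return ['model_%d.h5' % i for i in range(index_start, n_epochs + index_start)]

print(checkpoint_names(0, 2))  # fresh run:   ['model_0.h5', 'model_1.h5']
print(checkpoint_names(2, 2))  # resumed run: ['model_2.h5', 'model_3.h5']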

utils.py (2 changed lines)
@@ -374,9 +374,7 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
     
     indexer=0
     for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)):
-        #print(im, seg_i)
         img_name=im.split('.')[0]
-        print(img_name,'img_name')
         if not patches:
             cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) )
             cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' ,  resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) )