continue training, losses and etc

This commit is contained in:
vahid 2021-06-22 18:47:59 -04:00
parent 5fb7552dbe
commit 4bea9fd535
3 changed files with 69 additions and 24 deletions

View file

@ -9,6 +9,7 @@ from models import *
from utils import *
from metrics import *
from keras.models import load_model
from tqdm import tqdm
def configuration():
keras.backend.clear_session()
@ -61,19 +62,24 @@ def config_params():
blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation.
scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation.
flip_index=[0,1,-1] # Flip image. Used for augmentation.
continue_training = False # If
index_start = 0
dir_of_start_model = ''
is_loss_soft_dice = False
weighted_loss = False
data_is_provided = False
@ex.automain
def run(n_classes,n_epochs,input_height,
input_width,weight_decay,
input_width,weight_decay,weighted_loss,
index_start,dir_of_start_model,is_loss_soft_dice,
n_batch,patches,augmentation,flip_aug
,blur_aug,scaling, binarization,
blur_k,scales,dir_train,
blur_k,scales,dir_train,data_is_provided,
scaling_bluring,scaling_binarization,rotation,
rotation_not_90,thetha,scaling_flip,
rotation_not_90,thetha,scaling_flip,continue_training,
flip_index,dir_eval ,dir_output,pretraining,learning_rate):
data_is_provided = False
if data_is_provided:
dir_train_flowing=os.path.join(dir_output,'train')
@ -143,12 +149,43 @@ def run(n_classes,n_epochs,input_height,
augmentation=False,patches=patches)
continue_train = False
if weighted_loss:
weights=np.zeros(n_classes)
if data_is_provided:
for obj in os.listdir(dir_flow_train_labels):
try:
label_obj=cv2.imread(dir_flow_train_labels+'/'+obj)
label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
except:
pass
else:
for obj in os.listdir(dir_seg):
try:
label_obj=cv2.imread(dir_seg+'/'+obj)
label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
except:
pass
if continue_train:
model_dir_start = '/home/vahid/Documents/struktur_full_data/output_multi/model_0.h5'
model = load_model (model_dir_start, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss})
index_start = 1
weights=1.00/weights
weights=weights/float(np.sum(weights))
weights=weights/float(np.min(weights))
weights=weights/float(np.sum(weights))
if continue_training:
if is_loss_soft_dice:
model = load_model (dir_of_start_model, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss})
if weighted_loss:
model = load_model (dir_of_start_model, compile = True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss:
model = load_model (dir_of_start_model, compile = True)
else:
#get our model.
index_start = 0
@ -158,12 +195,16 @@ def run(n_classes,n_epochs,input_height,
#model.summary()
if not is_loss_soft_dice and not weighted_loss:
model.compile(loss='categorical_crossentropy',
optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
if is_loss_soft_dice:
model.compile(loss=soft_dice_loss,
optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
#model.compile(loss='categorical_crossentropy',
#optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
model.compile(loss=soft_dice_loss,
optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
if weighted_loss:
model.compile(loss=weighted_categorical_crossentropy(weights),
optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
#generating train and evaluation data
train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch,
@ -171,7 +212,7 @@ def run(n_classes,n_epochs,input_height,
val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch,
input_height=input_height, input_width=input_width,n_classes=n_classes )
for i in range(index_start, n_epochs+index_start):
for i in tqdm(range(index_start, n_epochs+index_start)):
model.fit_generator(
train_gen,
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1,
@ -181,8 +222,8 @@ def run(n_classes,n_epochs,input_height,
model.save(dir_output+'/'+'model_'+str(i)+'.h5')
os.system('rm -rf '+dir_train_flowing)
os.system('rm -rf '+dir_eval_flowing)
#os.system('rm -rf '+dir_train_flowing)
#os.system('rm -rf '+dir_eval_flowing)
#model.save(dir_output+'/'+'model'+'.h5')