Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git (synced 2025-06-07 19:05:24 +02:00)

Commit 4bea9fd535: "continue training, losses and etc"
Parent: 5fb7552dbe

3 changed files with 69 additions and 24 deletions
Configuration file (filename not shown in this diff view):

@@ -1,6 +1,6 @@
 {
     "n_classes" : 3,
-    "n_epochs" : 1,
+    "n_epochs" : 2,
     "input_height" : 448,
     "input_width" : 672,
     "weight_decay" : 1e-6,
@@ -8,16 +8,22 @@
     "learning_rate": 1e-4,
     "patches" : true,
     "pretraining" : true,
-    "augmentation" : true,
+    "augmentation" : false,
     "flip_aug" : false,
-    "blur_aug" : true,
-    "scaling" : false,
+    "blur_aug" : false,
+    "scaling" : true,
     "binarization" : false,
     "scaling_bluring" : false,
     "scaling_binarization" : false,
     "scaling_flip" : false,
     "rotation": false,
     "rotation_not_90": false,
+    "continue_training": false,
+    "index_start": 0,
+    "dir_of_start_model": " ",
+    "weighted_loss": false,
+    "is_loss_soft_dice": false,
+    "data_is_provided": false,
     "dir_train": "/home/vahid/Documents/handwrittens_train/train",
     "dir_eval": "/home/vahid/Documents/handwrittens_train/eval",
     "dir_output": "/home/vahid/Documents/handwrittens_train/output"
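The six added keys expose resuming and loss selection through the JSON config. A minimal sketch (not part of the repo) of how such a config could be read and sanity-checked before training; the filename config_params.json is an assumption, not shown in this diff:

import json

def load_config(path):
    # Read the training configuration shown in the diff above.
    with open(path) as f:
        params = json.load(f)
    # Resuming requires a saved model to start from; index_start should
    # continue the epoch numbering of the interrupted run.
    if params.get("continue_training"):
        assert params["dir_of_start_model"].strip(), \
            "dir_of_start_model must point to a saved .h5 model"
    # train.py checks the loss flags one after the other, so setting both
    # would let the last compile win; treat them as mutually exclusive.
    assert not (params.get("weighted_loss") and params.get("is_loss_soft_dice")), \
        "enable either weighted_loss or is_loss_soft_dice, not both"
    return params

params = load_config("config_params.json")  # assumed filename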
train.py (77 changed lines):
@@ -9,6 +9,7 @@ from models import *
 from utils import *
 from metrics import *
 from keras.models import load_model
+from tqdm import tqdm

 def configuration():
     keras.backend.clear_session()
@@ -61,19 +62,24 @@ def config_params():
     blur_k=['blur','guass','median'] # Used in order to blur image. Used for augmentation.
     scales= [ 0.5, 2 ] # Scale patches with these scales. Used for augmentation.
     flip_index=[0,1,-1] # Flip image. Used for augmentation.
+    continue_training = False # If
+    index_start = 0
+    dir_of_start_model = ''
+    is_loss_soft_dice = False
+    weighted_loss = False
+    data_is_provided = False

 @ex.automain
 def run(n_classes,n_epochs,input_height,
-        input_width,weight_decay,
+        input_width,weight_decay,weighted_loss,
+        index_start,dir_of_start_model,is_loss_soft_dice,
         n_batch,patches,augmentation,flip_aug
         ,blur_aug,scaling, binarization,
-        blur_k,scales,dir_train,
+        blur_k,scales,dir_train,data_is_provided,
         scaling_bluring,scaling_binarization,rotation,
-        rotation_not_90,thetha,scaling_flip,
+        rotation_not_90,thetha,scaling_flip,continue_training,
         flip_index,dir_eval ,dir_output,pretraining,learning_rate):

-    data_is_provided = False
-
     if data_is_provided:
         dir_train_flowing=os.path.join(dir_output,'train')
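train.py wires its parameters through sacred: config_params() declares the defaults, and @ex.automain injects any overrides (including the new continue_training, index_start, dir_of_start_model, weighted_loss, is_loss_soft_dice and data_is_provided flags) into run() by argument name. A minimal sketch of the same pattern, reduced to two of the new parameters:

from sacred import Experiment

ex = Experiment()

@ex.config
def config_params():
    continue_training = False  # resume from a saved model instead of building a new one
    index_start = 0            # epoch number to resume counting from

@ex.automain
def run(continue_training, index_start):
    # sacred matches argument names against config values, so overrides
    # from the command line or a JSON file land here.
    print(continue_training, index_start)

Invoked as `python train.py with config_params.json`, sacred merges the JSON values over these defaults (standard sacred CLI syntax; the exact invocation used by the authors is an assumption).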
@@ -143,12 +149,43 @@ def run(n_classes,n_epochs,input_height,
                           augmentation=False,patches=patches)

-    continue_train = False
-
-    if continue_train:
-        model_dir_start = '/home/vahid/Documents/struktur_full_data/output_multi/model_0.h5'
-        model = load_model (model_dir_start, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss})
-        index_start = 1
+    if weighted_loss:
+        weights=np.zeros(n_classes)
+        if data_is_provided:
+            for obj in os.listdir(dir_flow_train_labels):
+                try:
+                    label_obj=cv2.imread(dir_flow_train_labels+'/'+obj)
+                    label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
+                    weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
+                except:
+                    pass
+        else:
+            for obj in os.listdir(dir_seg):
+                try:
+                    label_obj=cv2.imread(dir_seg+'/'+obj)
+                    label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
+                    weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
+                except:
+                    pass
+
+        weights=1.00/weights
+        weights=weights/float(np.sum(weights))
+        weights=weights/float(np.min(weights))
+        weights=weights/float(np.sum(weights))
+
+    if continue_training:
+        if is_loss_soft_dice:
+            model = load_model (dir_of_start_model, compile = True, custom_objects={'soft_dice_loss': soft_dice_loss})
+        if weighted_loss:
+            model = load_model (dir_of_start_model, compile = True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
+        if not is_loss_soft_dice and not weighted_loss:
+            model = load_model (dir_of_start_model, compile = True)
     else:
         #get our model.
         index_start = 0
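The weighted_loss branch estimates class weights by summing, per class, the one-hot pixel counts over every label image, then inverting and repeatedly normalizing the counts (by sum, then min, then sum again, as in the diff). A self-contained numpy sketch of the same computation on a single label map of integer class ids:

import numpy as np

def inverse_frequency_weights(label_map, n_classes):
    # label_map: 2D array of integer class ids per pixel.
    counts = np.zeros(n_classes)
    for c in range(n_classes):
        counts[c] += np.sum(label_map == c)
    # Rare classes get large weights, frequent ones small weights.
    weights = 1.0 / np.maximum(counts, 1)  # guard against empty classes (the diff divides directly)
    weights = weights / weights.sum()      # normalize to sum to 1
    weights = weights / weights.min()      # rescale so the smallest weight is 1
    weights = weights / weights.sum()      # renormalize, matching the diff
    return weights

labels = np.random.randint(0, 3, size=(448, 672))  # dummy 3-class label map
print(inverse_frequency_weights(labels, n_classes=3))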
@@ -158,12 +195,16 @@ def run(n_classes,n_epochs,input_height,
         #model.summary()

-    #model.compile(loss='categorical_crossentropy',
-    #optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
-    model.compile(loss=soft_dice_loss,
-                  optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
+    if not is_loss_soft_dice and not weighted_loss:
+        model.compile(loss='categorical_crossentropy',
+                      optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
+    if is_loss_soft_dice:
+        model.compile(loss=soft_dice_loss,
+                      optimizer = Adam(lr=learning_rate),metrics=['accuracy'])
+    if weighted_loss:
+        model.compile(loss=weighted_categorical_crossentropy(weights),
+                      optimizer = Adam(lr=learning_rate),metrics=['accuracy'])

     #generating train and evaluation data
     train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size = n_batch,
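weighted_categorical_crossentropy(weights) comes in via `from metrics import *` and is not shown in this diff. A common Keras-backend implementation of a per-class weighted crossentropy looks like the sketch below; the repo's actual definition may differ in detail:

import numpy as np
from keras import backend as K

def weighted_categorical_crossentropy(weights):
    # weights: 1D array with one multiplier per class.
    weights = K.variable(np.asarray(weights, dtype='float32'))
    def loss(y_true, y_pred):
        # Renormalize predictions and avoid log(0).
        y_pred = y_pred / K.sum(y_pred, axis=-1, keepdims=True)
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        # Standard crossentropy, scaled per class before summing.
        return -K.sum(y_true * K.log(y_pred) * weights, axis=-1)
    return loss

Because Keras serializes only the inner function's name, a model saved with this loss must be reloaded with custom_objects={'loss': weighted_categorical_crossentropy(weights)}, exactly as the continue_training branch above does.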
@@ -171,7 +212,7 @@ def run(n_classes,n_epochs,input_height,
     val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size = n_batch,
                        input_height=input_height, input_width=input_width,n_classes=n_classes )

-    for i in range(index_start, n_epochs+index_start):
+    for i in tqdm(range(index_start, n_epochs+index_start)):
         model.fit_generator(
             train_gen,
             steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1,
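Wrapping the epoch loop in tqdm only adds a progress bar; the substantive detail is that the loop starts at index_start, so checkpoint names continue the numbering of an interrupted run instead of overwriting model_0.h5. A small sketch of that naming behaviour (the helper and f-string are illustrative; the diff concatenates strings):

import os

def checkpoint_names(dir_output, index_start, n_epochs):
    # One checkpoint per epoch, numbered from index_start onward.
    return [os.path.join(dir_output, f'model_{i}.h5')
            for i in range(index_start, n_epochs + index_start)]

# Fresh run: model_0.h5, model_1.h5
print(checkpoint_names('/tmp/out', index_start=0, n_epochs=2))
# Resumed run picking up after those two: model_2.h5, model_3.h5
print(checkpoint_names('/tmp/out', index_start=2, n_epochs=2))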
@@ -181,8 +222,8 @@ def run(n_classes,n_epochs,input_height,
         model.save(dir_output+'/'+'model_'+str(i)+'.h5')

-    os.system('rm -rf '+dir_train_flowing)
-    os.system('rm -rf '+dir_eval_flowing)
+    #os.system('rm -rf '+dir_train_flowing)
+    #os.system('rm -rf '+dir_eval_flowing)

     #model.save(dir_output+'/'+'model'+'.h5')
utils.py (2 changed lines):
@@ -374,9 +374,7 @@ def provide_patches(dir_img,dir_seg,dir_flow_train_imgs,

     indexer=0
     for im, seg_i in tqdm(zip(imgs_cv_train,segs_cv_train)):
-        #print(im, seg_i)
         img_name=im.split('.')[0]
-        print(img_name,'img_name')
         if not patches:
             cv2.imwrite(dir_flow_train_imgs+'/img_'+str(indexer)+'.png', resize_image(cv2.imread(dir_img+'/'+im),input_height,input_width ) )
             cv2.imwrite(dir_flow_train_labels+'/img_'+str(indexer)+'.png' , resize_image(cv2.imread(dir_seg+'/'+img_name+'.png'),input_height,input_width ) )