Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git, synced 2025-06-08 11:20:48 +02:00
commit 5fb7552dbe ("first updates, padding, rotations")
parent 63fcb96189
3 changed files with 319 additions and 151 deletions
train.py (183 lines changed)
@@ -8,7 +8,7 @@ from sacred import Experiment
from models import *
from utils import *
from metrics import *

from keras.models import load_model

def configuration():
    keras.backend.clear_session()
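The body of configuration() past clear_session() falls outside this hunk. In Keras 2 on the TF1 backend, such a helper typically also pins the session's GPU options; a minimal sketch, where the gpu_options choices are assumptions rather than this commit's actual settings:

import keras
import tensorflow as tf

def configuration():
    # drop any graph/session left over from a previous run
    keras.backend.clear_session()
    # hypothetical session options; not taken from this commit
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # claim GPU memory on demand
    keras.backend.set_session(tf.Session(config=config))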
@@ -47,7 +47,6 @@ def config_params():
    # extraction, this should be set to false since the model should see the whole image.
    augmentation=False
    flip_aug=False # Flip image (augmentation).
    elastic_aug=False # Elastic transformation (augmentation).
    blur_aug=False # Blur patches of the image (augmentation).
    scaling=False # Scaling of patches (augmentation) will be applied if this is set to true.
    binarization=False # Otsu thresholding. Used for augmentation in binary cases such as textline prediction; should not be applied to multi-class cases.
@@ -55,110 +54,116 @@ def config_params():
    dir_eval=None # Directory of validation dataset (sub-folders should be named images and labels).
    dir_output=None # Directory of output where the model should be saved.
    pretraining=False # Set to true to load pretrained weights of the resnet50 encoder.
    weighted_loss=False # Set to true if classes are unbalanced and you want to use a weighted loss function.
    scaling_bluring=False
    rotation=False
    scaling_binarization=False
    scaling_flip=False
    thetha=[10,-10]
    blur_k=['blur','guass','median'] # Kernels used to blur the image. Used for augmentation.
    scales=[0.9, 1.1] # Scale patches with these scales. Used for augmentation.
    flip_index=[0,1] # Flip image. Used for augmentation.
    scales=[0.5, 2] # Scale patches with these scales. Used for augmentation.
    flip_index=[0,1,-1] # Flip image. Used for augmentation.
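For readers unfamiliar with sacred: assignments inside a function decorated with @ex.config become experiment parameters, and @ex.automain injects them into the decorated function by name. That is why config_params() above is plain assignments and run() below takes the same names as arguments. A minimal self-contained illustration (parameter names here are stand-ins, not the repo's full set):

from sacred import Experiment

ex = Experiment()

@ex.config
def config_params():
    n_classes = 2      # picked up as a parameter, not a local
    rotation = False

@ex.automain
def run(n_classes, rotation):
    # values arrive injected by name; override from the CLI with e.g.
    #   python this_script.py with n_classes=4 rotation=True
    print(n_classes, rotation)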
@ex.automain
def run(n_classes,n_epochs,input_height,
        input_width,weight_decay,weighted_loss,
        n_batch,patches,augmentation,flip_aug,blur_aug,scaling, binarization,
        input_width,weight_decay,
        n_batch,patches,augmentation,flip_aug
        ,blur_aug,scaling, binarization,
        blur_k,scales,dir_train,
        scaling_bluring,scaling_binarization,rotation,
        rotation_not_90,thetha,scaling_flip,
        flip_index,dir_eval ,dir_output,pretraining,learning_rate):

    dir_img,dir_seg=get_dirs_or_files(dir_train)
    dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval)
    data_is_provided = False
    # first make directories under dir_output for training and evaluation, in order to flow data from them.
    dir_train_flowing=os.path.join(dir_output,'train')
    dir_eval_flowing=os.path.join(dir_output,'eval')

    dir_flow_train_imgs=os.path.join(dir_train_flowing,'images')
    dir_flow_train_labels=os.path.join(dir_train_flowing,'labels')

    dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images')
    dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels')

    if os.path.isdir(dir_train_flowing):
        os.system('rm -rf '+dir_train_flowing)
        os.makedirs(dir_train_flowing)
    if data_is_provided:
        dir_train_flowing=os.path.join(dir_output,'train')
        dir_eval_flowing=os.path.join(dir_output,'eval')

        dir_flow_train_imgs=os.path.join(dir_train_flowing,'images')
        dir_flow_train_labels=os.path.join(dir_train_flowing,'labels')

        dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images')
        dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels')

        configuration()

    else:
        os.makedirs(dir_train_flowing)
        dir_img,dir_seg=get_dirs_or_files(dir_train)
        dir_img_val,dir_seg_val=get_dirs_or_files(dir_eval)

    if os.path.isdir(dir_eval_flowing):
        os.system('rm -rf '+dir_eval_flowing)
        os.makedirs(dir_eval_flowing)
    else:
        os.makedirs(dir_eval_flowing)
        # first make directories under dir_output for training and evaluation, in order to flow data from them.
        dir_train_flowing=os.path.join(dir_output,'train')
        dir_eval_flowing=os.path.join(dir_output,'eval')

    os.mkdir(dir_flow_train_imgs)
    os.mkdir(dir_flow_train_labels)

    os.mkdir(dir_flow_eval_imgs)
    os.mkdir(dir_flow_eval_labels)

    # set the gpu configuration
    configuration()
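The rm -rf / makedirs / mkdir sequence above, in both the old flat form and the new data_is_provided branch, only has to guarantee four empty flow directories. A sketch of the same effect using the standard library instead of os.system, assuming the train/eval layout shown above:

import os
import shutil

def reset_flow_dirs(dir_output):
    # recreate <dir_output>/{train,eval}/{images,labels} from scratch
    for split in ('train', 'eval'):
        flowing = os.path.join(dir_output, split)
        if os.path.isdir(flowing):
            shutil.rmtree(flowing)  # safer than os.system('rm -rf ...')
        for sub in ('images', 'labels'):
            os.makedirs(os.path.join(flowing, sub))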
    # write patches into a sub-folder, to be flowed from directory.
    provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
                    dir_flow_train_labels,
                    input_height,input_width,blur_k,blur_aug,
                    flip_aug,binarization,scaling,scales,flip_index,
                    scaling_bluring,scaling_binarization,rotation,
                    augmentation=augmentation,patches=patches)

    provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs,
                    dir_flow_eval_labels,
                    input_height,input_width,blur_k,blur_aug,
                    flip_aug,binarization,scaling,scales,flip_index,
                    scaling_bluring,scaling_binarization,rotation,
                    augmentation=False,patches=patches)
    dir_flow_train_imgs=os.path.join(dir_train_flowing,'images/')
    dir_flow_train_labels=os.path.join(dir_train_flowing,'labels/')
    if weighted_loss:
        weights=np.zeros(n_classes)
        for obj in os.listdir(dir_seg):
            label_obj=cv2.imread(dir_seg+'/'+obj)
            label_obj_one_hot=get_one_hot( label_obj,label_obj.shape[0],label_obj.shape[1],n_classes)
            weights+=(label_obj_one_hot.sum(axis=0)).sum(axis=0)
        dir_flow_eval_imgs=os.path.join(dir_eval_flowing,'images/')
        dir_flow_eval_labels=os.path.join(dir_eval_flowing,'labels/')

        if os.path.isdir(dir_train_flowing):
            os.system('rm -rf '+dir_train_flowing)
            os.makedirs(dir_train_flowing)
        else:
            os.makedirs(dir_train_flowing)

        if os.path.isdir(dir_eval_flowing):
            os.system('rm -rf '+dir_eval_flowing)
            os.makedirs(dir_eval_flowing)
        else:
            os.makedirs(dir_eval_flowing)

        weights=1.00/weights
        os.mkdir(dir_flow_train_imgs)
        os.mkdir(dir_flow_train_labels)

        weights=weights/float(np.sum(weights))
        weights=weights/float(np.min(weights))
        weights=weights/float(np.sum(weights))
        os.mkdir(dir_flow_eval_imgs)
        os.mkdir(dir_flow_eval_labels)
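Interleaved with the directory lines above is the weighted-loss recipe: count pixels per class over the one-hot labels, invert, and normalize. The arithmetic in isolation, worked on a toy three-class count vector; note that the divide-by-min step is undone by the final sum-normalization, so the result is just the inverse frequencies normalized to sum to 1:

import numpy as np

counts = np.array([900., 90., 10.])  # pixels per class over all label images
weights = 1.00 / counts              # rare classes get large weights
weights = weights / float(np.sum(weights))
weights = weights / float(np.min(weights))
weights = weights / float(np.sum(weights))
print(weights)  # ~[0.0099, 0.0990, 0.8911], sums to 1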
    # set the gpu configuration
    configuration()

    # write patches into a sub-folder, to be flowed from directory.
    provide_patches(dir_img,dir_seg,dir_flow_train_imgs,
                    dir_flow_train_labels,
                    input_height,input_width,blur_k,blur_aug,
                    flip_aug,binarization,scaling,scales,flip_index,
                    scaling_bluring,scaling_binarization,rotation,
                    rotation_not_90,thetha,scaling_flip,
                    augmentation=augmentation,patches=patches)

    # get our model.
    model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
    provide_patches(dir_img_val,dir_seg_val,dir_flow_eval_imgs,
                    dir_flow_eval_labels,
                    input_height,input_width,blur_k,blur_aug,
                    flip_aug,binarization,scaling,scales,flip_index,
                    scaling_bluring,scaling_binarization,rotation,
                    rotation_not_90,thetha,scaling_flip,
                    augmentation=False,patches=patches)
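provide_patches lives in utils.py and is not part of this diff; at its core, such a writer tiles each page image and its label mask into fixed-size patches and writes them into the flow directories, with the augmentation flags multiplying the variants written. A minimal sketch of the tiling step only, with hypothetical names and non-overlapping patches assumed:

import os
import cv2

def write_patches(img_path, lbl_path, out_imgs, out_lbls, h, w):
    # hypothetical helper, not the repo's provide_patches
    img = cv2.imread(img_path)
    lbl = cv2.imread(lbl_path)
    base = os.path.splitext(os.path.basename(img_path))[0]
    for y in range(0, img.shape[0] - h + 1, h):
        for x in range(0, img.shape[1] - w + 1, w):
            name = '%s_%d_%d.png' % (base, y, x)
            cv2.imwrite(os.path.join(out_imgs, name), img[y:y+h, x:x+w])
            cv2.imwrite(os.path.join(out_lbls, name), lbl[y:y+h, x:x+w])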
    continue_train = False

    if continue_train:
        model_dir_start = '/home/vahid/Documents/struktur_full_data/output_multi/model_0.h5'
        model = load_model(model_dir_start, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
        index_start = 1
    else:
        # get our model.
        index_start = 0
        model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)

    # if you want to see the model structure, just uncomment model.summary().
    #model.summary()
    if not weighted_loss:
        model.compile(loss='categorical_crossentropy',
                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
    if weighted_loss:
        model.compile(loss=weighted_categorical_crossentropy(weights),
                      optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
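weighted_categorical_crossentropy is imported from the project's own modules and not shown in this diff. The usual closure over a per-class weight vector looks like the following; this is a common formulation, not necessarily the repo's exact one:

import numpy as np
from keras import backend as K

def weighted_categorical_crossentropy(weights):
    weights = K.variable(np.asarray(weights, dtype='float32'))
    def loss(y_true, y_pred):
        # standard crossentropy, with each class term scaled by its weight
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        return -K.sum(y_true * K.log(y_pred) * weights, axis=-1)
    return loss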
    mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
                                         save_weights_only=True, period=1)

    #model.compile(loss='categorical_crossentropy',
    #              optimizer=Adam(lr=learning_rate), metrics=['accuracy'])

    model.compile(loss=soft_dice_loss,
                  optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
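soft_dice_loss comes from metrics.py, also outside this diff. Note that this last compile overrides the crossentropy compiles above, so within the lines shown, soft dice is the loss actually in effect. A common soft-Dice formulation for one-hot y_true and softmax y_pred on channels-last tensors, offered as a sketch rather than the repo's exact definition:

from keras import backend as K

def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    axes = (1, 2)  # sum over the spatial dimensions
    numerator = 2. * K.sum(y_pred * y_true, axes)
    denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
    # mean over classes and batch; 1 - Dice so that lower is better
    return 1 - K.mean((numerator + epsilon) / (denominator + epsilon))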
    # generating train and evaluation data
    train_gen = data_gen(dir_flow_train_imgs,dir_flow_train_labels, batch_size=n_batch,
@@ -166,20 +171,20 @@ def run(n_classes,n_epochs,input_height,
    val_gen = data_gen(dir_flow_eval_imgs,dir_flow_eval_labels, batch_size=n_batch,
                       input_height=input_height, input_width=input_width, n_classes=n_classes)

    for i in range(index_start, n_epochs+index_start):
        model.fit_generator(
            train_gen,
            steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1,
            validation_data=val_gen,
            validation_steps=1,
            epochs=1)
        model.save(dir_output+'/'+'model_'+str(i)+'.h5')

    model.fit_generator(
        train_gen,
        steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs))/n_batch)-1,
        validation_data=val_gen,
        validation_steps=1,
        epochs=n_epochs)

    os.system('rm -rf '+dir_train_flowing)
    os.system('rm -rf '+dir_eval_flowing)

    model.save(dir_output+'/'+'model'+'.h5')
    #model.save(dir_output+'/'+'model'+'.h5')
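data_gen (from utils.py, not in this diff) is expected to yield endless (image batch, one-hot label batch) pairs from the flow directories, which is what fit_generator consumes above; note also that within the hunks shown, the mc checkpoint callback is created but never passed to fit_generator. A minimal sketch of such a generator, with the resizing and normalization choices being assumptions:

import os
import random
import numpy as np
import cv2

def data_gen(dir_imgs, dir_labels, batch_size, input_height, input_width, n_classes):
    names = os.listdir(dir_imgs)
    while True:  # Keras generators must loop forever
        batch = random.sample(names, batch_size)
        x = np.zeros((batch_size, input_height, input_width, 3))
        y = np.zeros((batch_size, input_height, input_width, n_classes))
        for k, name in enumerate(batch):
            img = cv2.imread(os.path.join(dir_imgs, name))
            lbl = cv2.imread(os.path.join(dir_labels, name))
            x[k] = cv2.resize(img, (input_width, input_height)) / 255.0
            lbl = cv2.resize(lbl, (input_width, input_height),
                             interpolation=cv2.INTER_NEAREST)
            for c in range(n_classes):
                y[k, :, :, c] = (lbl[:, :, 0] == c).astype(float)
        yield x, y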