mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
first working update of branch
This commit is contained in:
parent
02b1436f39
commit
d27647a0f1
4 changed files with 452 additions and 151 deletions
132
train.py
132
train.py
|
@ -10,6 +10,7 @@ from utils import *
|
|||
from metrics import *
|
||||
from tensorflow.keras.models import load_model
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
|
||||
|
||||
def configuration():
|
||||
|
@ -42,9 +43,13 @@ def config_params():
|
|||
learning_rate = 1e-4 # Set the learning rate.
|
||||
patches = False # Divides input image into smaller patches (input size of the model) when set to true. For the model to see the full image, like page extraction, set this to false.
|
||||
augmentation = False # To apply any kind of augmentation, this parameter must be set to true.
|
||||
flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in train.py.
|
||||
blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in train.py.
|
||||
scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in train.py.
|
||||
flip_aug = False # If true, different types of flipping will be applied to the image. Types of flips are defined with "flip_index" in config_params.json.
|
||||
blur_aug = False # If true, different types of blurring will be applied to the image. Types of blur are defined with "blur_k" in config_params.json.
|
||||
padding_white = False # If true, white padding will be applied to the image.
|
||||
padding_black = False # If true, black padding will be applied to the image.
|
||||
scaling = False # If true, scaling will be applied to the image. The amount of scaling is defined with "scales" in config_params.json.
|
||||
degrading = False # If true, degrading will be applied to the image. The amount of degrading is defined with "degrade_scales" in config_params.json.
|
||||
brightening = False # If true, brightening will be applied to the image. The amount of brightening is defined with "brightness" in config_params.json.
|
||||
binarization = False # If true, Otsu thresholding will be applied to augment the input with binarized images.
|
||||
dir_train = None # Directory of training dataset with subdirectories having the names "images" and "labels".
|
||||
dir_eval = None # Directory of validation dataset with subdirectories having the names "images" and "labels".
|
||||
|
@ -52,13 +57,18 @@ def config_params():
|
|||
pretraining = False # Set to true to load pretrained weights of ResNet50 encoder.
|
||||
scaling_bluring = False # If true, a combination of scaling and blurring will be applied to the image.
|
||||
scaling_binarization = False # If true, a combination of scaling and binarization will be applied to the image.
|
||||
scaling_brightness = False # If true, a combination of scaling and brightening will be applied to the image.
|
||||
scaling_flip = False # If true, a combination of scaling and flipping will be applied to the image.
|
||||
thetha = [10, -10] # Rotate image by these angles for augmentation.
|
||||
blur_k = ['blur', 'gauss', 'median'] # Blur image for augmentation.
|
||||
scales = [0.5, 2] # Scale patches for augmentation.
|
||||
flip_index = [0, 1, -1] # Flip image for augmentation.
|
||||
thetha = None # Rotate image by these angles for augmentation.
|
||||
blur_k = None # Blur image for augmentation.
|
||||
scales = None # Scale patches for augmentation.
|
||||
degrade_scales = None # Degrade image for augmentation.
|
||||
brightness = None # Brighten image for augmentation.
|
||||
flip_index = None # Flip image for augmentation.
|
||||
continue_training = False # Set to true if you would like to continue training an already trained a model.
|
||||
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
|
||||
transformer_patchsize = None # Patch size of vision transformer patches.
|
||||
num_patches_xy = None # Number of patches for vision transformer.
|
||||
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
|
||||
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
|
||||
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
||||
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
|
||||
|
@ -66,15 +76,19 @@ def config_params():
|
|||
|
||||
|
||||
@ex.automain
|
||||
def run(n_classes, n_epochs, input_height,
|
||||
def run(_config, n_classes, n_epochs, input_height,
|
||||
input_width, weight_decay, weighted_loss,
|
||||
index_start, dir_of_start_model, is_loss_soft_dice,
|
||||
n_batch, patches, augmentation, flip_aug,
|
||||
blur_aug, scaling, binarization,
|
||||
blur_k, scales, dir_train, data_is_provided,
|
||||
scaling_bluring, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip, continue_training,
|
||||
flip_index, dir_eval, dir_output, pretraining, learning_rate):
|
||||
blur_aug, padding_white, padding_black, scaling, degrading,
|
||||
brightening, binarization, blur_k, scales, degrade_scales,
|
||||
brightness, dir_train, data_is_provided, scaling_bluring,
|
||||
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
|
||||
thetha, scaling_flip, continue_training, transformer_patchsize,
|
||||
num_patches_xy, model_name, flip_index, dir_eval, dir_output,
|
||||
pretraining, learning_rate):
|
||||
|
||||
num_patches = num_patches_xy[0]*num_patches_xy[1]
|
||||
if data_is_provided:
|
||||
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||
|
@ -121,23 +135,28 @@ def run(n_classes, n_epochs, input_height,
|
|||
|
||||
# set the gpu configuration
|
||||
configuration()
|
||||
|
||||
imgs_list=np.array(os.listdir(dir_img))
|
||||
segs_list=np.array(os.listdir(dir_seg))
|
||||
|
||||
imgs_list_test=np.array(os.listdir(dir_img_val))
|
||||
segs_list_test=np.array(os.listdir(dir_seg_val))
|
||||
|
||||
# writing patches into a sub-folder in order to be flowed from directory.
|
||||
provide_patches(dir_img, dir_seg, dir_flow_train_imgs,
|
||||
dir_flow_train_labels,
|
||||
input_height, input_width, blur_k, blur_aug,
|
||||
flip_aug, binarization, scaling, scales, flip_index,
|
||||
scaling_bluring, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip,
|
||||
augmentation=augmentation, patches=patches)
|
||||
|
||||
provide_patches(dir_img_val, dir_seg_val, dir_flow_eval_imgs,
|
||||
dir_flow_eval_labels,
|
||||
input_height, input_width, blur_k, blur_aug,
|
||||
flip_aug, binarization, scaling, scales, flip_index,
|
||||
scaling_bluring, scaling_binarization, rotation,
|
||||
rotation_not_90, thetha, scaling_flip,
|
||||
augmentation=False, patches=patches)
|
||||
provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs,
|
||||
dir_flow_train_labels, input_height, input_width, blur_k,
|
||||
blur_aug, padding_white, padding_black, flip_aug, binarization,
|
||||
scaling, degrading, brightening, scales, degrade_scales, brightness,
|
||||
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||
rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation,
|
||||
patches=patches)
|
||||
|
||||
provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,
|
||||
dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width,
|
||||
blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization,
|
||||
scaling, degrading, brightening, scales, degrade_scales, brightness,
|
||||
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||
rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches)
|
||||
|
||||
if weighted_loss:
|
||||
weights = np.zeros(n_classes)
|
||||
|
@ -166,38 +185,50 @@ def run(n_classes, n_epochs, input_height,
|
|||
weights = weights / float(np.sum(weights))
|
||||
|
||||
if continue_training:
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True,
|
||||
custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True)
|
||||
if model_name=='resnet50_unet':
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model , compile=True)
|
||||
elif model_name=='hybrid_transformer_cnn':
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
else:
|
||||
# get our model.
|
||||
index_start = 0
|
||||
model = resnet50_unet(n_classes, input_height, input_width, weight_decay, pretraining)
|
||||
|
||||
# if you want to see the model structure just uncomment model summary.
|
||||
# model.summary()
|
||||
if model_name=='resnet50_unet':
|
||||
model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)
|
||||
elif model_name=='hybrid_transformer_cnn':
|
||||
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining)
|
||||
|
||||
#if you want to see the model structure just uncomment model summary.
|
||||
#model.summary()
|
||||
|
||||
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
if is_loss_soft_dice:
|
||||
if is_loss_soft_dice:
|
||||
model.compile(loss=soft_dice_loss,
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
|
||||
if weighted_loss:
|
||||
model.compile(loss=weighted_categorical_crossentropy(weights),
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
|
||||
|
||||
# generating train and evaluation data
|
||||
train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
|
||||
input_height=input_height, input_width=input_width, n_classes=n_classes)
|
||||
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
|
||||
input_height=input_height, input_width=input_width, n_classes=n_classes)
|
||||
|
||||
|
||||
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
|
||||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
model.fit_generator(
|
||||
train_gen,
|
||||
|
@ -205,9 +236,12 @@ def run(n_classes, n_epochs, input_height,
|
|||
validation_data=val_gen,
|
||||
validation_steps=1,
|
||||
epochs=1)
|
||||
model.save(dir_output + '/' + 'model_' + str(i))
|
||||
model.save(dir_output+'/'+'model_'+str(i))
|
||||
|
||||
with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
# os.system('rm -rf '+dir_train_flowing)
|
||||
# os.system('rm -rf '+dir_eval_flowing)
|
||||
#os.system('rm -rf '+dir_train_flowing)
|
||||
#os.system('rm -rf '+dir_eval_flowing)
|
||||
|
||||
# model.save(dir_output+'/'+'model'+'.h5')
|
||||
#model.save(dir_output+'/'+'model'+'.h5')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue