Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git (synced 2025-06-09 20:00:05 +02:00)
inference script is added
This commit is contained in:
parent 38db3e9289
commit 8d1050ec30

4 changed files with 537 additions and 42 deletions
train.py: 42 changed lines

--- a/train.py
+++ b/train.py
@@ -69,7 +69,7 @@ def config_params():
     flip_index = None # Flip image for augmentation.
     continue_training = False # Set to true if you would like to continue training an already trained model.
     transformer_patchsize = None # Patch size of vision transformer patches.
-    num_patches_xy = None # Number of patches for vision transformer.
+    transformer_num_patches_xy = None # Number of patches for vision transformer.
     index_start = 0 # Index of the model to continue training from. E.g. if you trained for 3 epochs and the last index is 2, set "index_start" to 3 to continue from model_1.h5 and start naming new models at index 3.
     dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
     is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
@@ -77,6 +77,8 @@ def config_params():
     data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".
     task = "segmentation" # This parameter defines the task of the model, which can be segmentation, enhancement or classification.
     f1_threshold_classification = None # Models with an evaluation f1 score above this threshold are selected; their weights are ensembled, and the averaged ensemble model is written to output.
+    classification_classes_name = None # Dictionary of classification class names.
+    backbone_type = None # We support 2 types of backbones: a vision transformer alongside a CNN, called "transformer", and a CNN only, called "nontransformer".


 @ex.automain
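The renamed and newly added options only take effect through the Sacred configuration. A hypothetical excerpt of what a run configuration might look like after this commit (names from the diff; every value is illustrative, not from the repo):

# Hypothetical Sacred config excerpt, e.g. supplied as a JSON file on the
# command line; values below are examples only.
config_params = {
    "backbone_type": "transformer",            # or "nontransformer" for the plain CNN
    "transformer_patchsize": 1,                # patch size of the ViT patches
    "transformer_num_patches_xy": [28, 28],    # renamed from num_patches_xy
    "task": "classification",                  # "segmentation", "enhancement" or "classification"
    "f1_threshold_classification": 0.8,        # models above this f1 enter the ensemble
    "classification_classes_name": {"0": "machine-printed", "1": "handwritten"},
}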
@@ -89,12 +91,12 @@ def run(_config, n_classes, n_epochs, input_height,
         brightness, dir_train, data_is_provided, scaling_bluring,
         scaling_brightness, scaling_binarization, rotation, rotation_not_90,
         thetha, scaling_flip, continue_training, transformer_patchsize,
-        num_patches_xy, model_name, flip_index, dir_eval, dir_output,
-        pretraining, learning_rate, task, f1_threshold_classification):
+        transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
+        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):

-    if task == "segmentation" or "enhancement":
+    if task == "segmentation" or task == "enhancement":

-        num_patches = num_patches_xy[0]*num_patches_xy[1]
+        num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
         if data_is_provided:
             dir_train_flowing = os.path.join(dir_output, 'train')
             dir_eval_flowing = os.path.join(dir_output, 'eval')
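The condition change in this hunk fixes a genuine bug, not just style: in Python, `task == "segmentation" or "enhancement"` parses as `(task == "segmentation") or "enhancement"`, and a non-empty string literal is always truthy, so the old branch ran for every task, including classification. A minimal demonstration:

task = "classification"

# Old form: the right operand is the bare string "enhancement", which is
# truthy, so the expression is truthy regardless of task.
print(bool(task == "segmentation" or "enhancement"))    # True (the bug)

# Fixed form from the commit: both operands are real comparisons.
print(task == "segmentation" or task == "enhancement")  # False

# Equivalent compact spelling:
print(task in ("segmentation", "enhancement"))          # False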
@@ -191,14 +193,14 @@ def run(_config, n_classes, n_epochs, input_height,
             weights = weights / float(np.sum(weights))

         if continue_training:
-            if model_name=='resnet50_unet':
+            if backbone_type=='nontransformer':
                 if is_loss_soft_dice and task == "segmentation":
                     model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
                 if weighted_loss and task == "segmentation":
                     model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
                 if not is_loss_soft_dice and not weighted_loss:
                     model = load_model(dir_of_start_model , compile=True)
-            elif model_name=='hybrid_transformer_cnn':
+            elif backbone_type=='transformer':
                 if is_loss_soft_dice and task == "segmentation":
                     model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
                 if weighted_loss and task == "segmentation":
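For the continue_training path, note that Keras can only reload a model containing non-standard losses or layers when they are supplied via custom_objects, which is why each branch passes soft_dice_loss, the weighted-loss closure, or the PatchEncoder/Patches layers. A sketch of the pattern with one common soft dice formulation (an assumption here, not necessarily the repo's exact definition):

import tensorflow.keras.backend as K
from tensorflow.keras.models import load_model

def soft_dice_loss(y_true, y_pred, eps=1e-7):
    # One common soft dice: 1 - 2*|X n Y| / (|X| + |Y|); the repo may
    # define it differently.
    intersection = K.sum(y_true * y_pred)
    return 1.0 - (2.0 * intersection + eps) / (K.sum(y_true) + K.sum(y_pred) + eps)

# Without custom_objects, load_model raises an unknown-object error for any
# custom loss or layer serialized into the model.
model = load_model("dir_of_start_model",  # placeholder path
                   compile=True,
                   custom_objects={"soft_dice_loss": soft_dice_loss})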
@@ -207,9 +209,9 @@ def run(_config, n_classes, n_epochs, input_height,
                     model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
         else:
             index_start = 0
-            if model_name=='resnet50_unet':
+            if backbone_type=='nontransformer':
                 model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
-            elif model_name=='hybrid_transformer_cnn':
+            elif backbone_type=='nontransformer':
                 model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)

         #if you want to see the model structure just uncomment model summary.
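One apparent slip in this hunk: both branches of the fresh-training dispatch now test backbone_type=='nontransformer', so the vit_resnet50_unet branch is unreachable. Presumably the second comparison was meant to be 'transformer', mirroring the continue_training block above. A sketch of the presumed intent (identifiers from the diff, not a verified fix):

# As committed, the elif repeats 'nontransformer' and the transformer model
# can never be built on a fresh run; this is the presumably intended dispatch.
if backbone_type == 'nontransformer':
    model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
elif backbone_type == 'transformer':
    model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches,
                              input_height, input_width, task, weight_decay, pretraining)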
@@ -246,9 +248,9 @@ def run(_config, n_classes, n_epochs, input_height,
                                 validation_data=val_gen,
                                 validation_steps=1,
                                 epochs=1)
-            model.save(dir_output+'/'+'model_'+str(i))
+            model.save(os.path.join(dir_output,'model_'+str(i)))

-            with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp:
+            with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
                 json.dump(_config, fp) # encode dict into JSON

             #os.system('rm -rf '+dir_train_flowing)
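Replacing '+'-concatenated paths with os.path.join makes the save paths portable across path separators. Since os.path.join accepts any number of components, the nested call could be flattened further; a small runnable sketch (example values):

import os

dir_output, i = "output", 3  # example values
# Nested form, as in the commit:
nested = os.path.join(os.path.join(dir_output, 'model_' + str(i)), "config.json")
# Flattened equivalent:
flat = os.path.join(dir_output, 'model_' + str(i), "config.json")
assert nested == flat
print(flat)  # output/model_3/config.json on POSIX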
@@ -257,14 +259,15 @@ def run(_config, n_classes, n_epochs, input_height,
         #model.save(dir_output+'/'+'model'+'.h5')
     elif task=='classification':
         configuration()
-        model = resnet50_classifier(n_classes, input_height, input_width,weight_decay,pretraining)
+        model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)

         opt_adam = Adam(learning_rate=0.001)
         model.compile(loss='categorical_crossentropy',
                       optimizer = opt_adam,metrics=['accuracy'])


-        testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes)
+        list_classes = list(classification_classes_name.values())
+        testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes)

         #print(testY.shape, testY)

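Threading list_classes (the values of the classification_classes_name dict) into the data generators presumably pins one class ordering for both training and evaluation, instead of relying on, say, directory listing order. A hypothetical illustration of why a fixed ordering matters for one-hot labels (helper invented here, not the repo's code):

import numpy as np

def one_hot_for(class_name, list_classes):
    # The position of class_name in list_classes decides which output unit
    # it maps to; two generators using different orderings would silently
    # train and evaluate against permuted labels.
    y = np.zeros(len(list_classes), dtype=np.float32)
    y[list_classes.index(class_name)] = 1.0
    return y

print(one_hot_for("handwritten", ["machine-printed", "handwritten"]))  # [0. 1.]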
@@ -280,7 +283,7 @@ def run(_config, n_classes, n_epochs, input_height,

         for i in range(n_epochs):
             #history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights)
-            history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)
+            history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)

             y_pr_class = []
             for jj in range(testY.shape[0]):
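Unrelated to this commit's change but worth flagging while here: steps_per_epoch=num_rows / n_batch is a float, while Keras expects an integer step count; depending on the TF version the float is truncated or rejected. An explicit ceil is the safer spelling (sketch with example values):

import math

num_rows, n_batch = 1000, 32  # example values
steps_per_epoch = math.ceil(num_rows / n_batch)  # 32, so every sample is seen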
@@ -301,10 +304,6 @@ def run(_config, n_classes, n_epochs, input_height,
                 score_best[0]=f1score
                 model.save(os.path.join(dir_output,'model_best'))

-
-            ##best_model=keras.models.clone_model(model)
-            ##best_model.build()
-            ##best_model.set_weights(model.get_weights())
             if f1score > f1_threshold_classification:
                 weights.append(model.get_weights() )
                 y_tot=y_tot+y_pr
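The surviving logic keeps the best single model and additionally collects model.get_weights() for every epoch whose f1 exceeds f1_threshold_classification; those snapshots are later averaged into model_weight_averaged. A hedged reconstruction of per-layer weight averaging, which may differ from the repo's implementation:

import numpy as np

def average_weights(weights_list):
    # weights_list: one get_weights() list per selected model, all identical
    # in structure; average each layer's arrays elementwise across models.
    return [np.mean(np.stack(layer_ws, axis=0), axis=0)
            for layer_ws in zip(*weights_list)]

# usage sketch: model_weight_averaged.set_weights(average_weights(weights))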
@@ -329,4 +328,9 @@ def run(_config, n_classes, n_epochs, input_height,

         ##best_model.save('model_taza.h5')
         model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
+        with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
+            json.dump(_config, fp) # encode dict into JSON
+
+        with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
+            json.dump(_config, fp) # encode dict into JSON

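Writing config.json next to model_ens_avg and model_best presumably lets the newly added inference script recover input size, task, and backbone from the saved run config instead of guessing. A minimal loading sketch (paths illustrative):

import json
import os

model_dir = os.path.join("dir_output", "model_ens_avg")  # illustrative path
with open(os.path.join(model_dir, "config.json")) as fp:
    config = json.load(fp)
print(config.get("task"), config.get("backbone_type"), config.get("input_height"))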