Mirror of https://github.com/qurator-spk/sbb_pixelwise_segmentation.git (synced 2025-06-09 20:00:05 +02:00)
Update train.py
Avoid ensembling when no model weights met the threshold F1 score in the classification case.
This commit is contained in:
parent ce1108aca0
commit a7e1f255f3
1 changed file with 13 additions and 33 deletions
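In context: during classification training, each epoch's weights are kept only when that epoch's macro F1 score on the evaluation set exceeds `f1_threshold_classification`. If no epoch clears the threshold, the `weights` list stays empty, so averaging it would hand the model an empty weight set (and the old prediction-averaging path would divide by zero). A minimal sketch of the guard this commit adds; `average_weight_sets` is an illustrative helper, not a function from this repository:

```python
import numpy as np

def average_weight_sets(weight_sets):
    # Element-wise mean over a list of Keras-style weight sets
    # (each weight set is a list of numpy arrays, one per layer variable).
    return [np.mean(group, axis=0) for group in zip(*weight_sets)]

weights = []  # stays empty when no epoch beats the F1 threshold
if len(weights) >= 1:
    averaged = average_weight_sets(weights)
else:
    # Without the guard, zip(*[]) yields nothing and the model would be
    # given an empty weight list; the commit skips ensembling instead.
    print("no weight set met the F1 threshold; skipping ensembling")
```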
train.py
@@ -269,35 +269,25 @@ def run(_config, n_classes, n_epochs, input_height,
         list_classes = list(classification_classes_name.values())
         testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes)
         #print(testY.shape, testY)
         y_tot=np.zeros((testX.shape[0],n_classes))
         indexer=0

         score_best=[]
         score_best.append(0)

         num_rows = return_number_of_total_training_data(dir_train)

         weights=[]

         for i in range(n_epochs):
             #history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights)
-            history = model.fit(generate_data_from_folder_training(dir_train, n_batch, input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)
+            history = model.fit(generate_data_from_folder_training(dir_train, n_batch, input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=1)#,class_weight=weights)

             y_pr_class = []
             for jj in range(testY.shape[0]):
                 y_pr = model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), verbose=0)
                 y_pr_ind = np.argmax(y_pr, axis=1)
                 #print(y_pr_ind, 'y_pr_ind')
                 y_pr_class.append(y_pr_ind)

             y_pr_class = np.array(y_pr_class)
             #model.save('./models_save/model_'+str(i)+'.h5')
             #y_pr_class=np.argmax(y_pr,axis=1)
             f1score = f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro')

             print(i, f1score)

             if f1score>score_best[0]:
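The evaluation loop above predicts one sample at a time; the commented-out `#y_pr_class=np.argmax(y_pr,axis=1)` line hints at the batched form. A hedged sketch of the equivalent vectorized evaluation, reusing `model`, `testX`, `testY`, and `n_batch` from the surrounding code:

```python
import numpy as np
from sklearn.metrics import f1_score

# One batched predict call instead of a Python loop over samples;
# testX has shape (N, input_height, input_width, 3).
y_pr = model.predict(testX, batch_size=n_batch, verbose=0)  # (N, n_classes)
y_pr_class = np.argmax(y_pr, axis=1)                        # predicted class per sample
f1score = f1_score(np.argmax(testY, axis=1), y_pr_class, average='macro')
```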
@@ -306,27 +296,17 @@ def run(_config, n_classes, n_epochs, input_height,
             if f1score > f1_threshold_classification:
                 weights.append(model.get_weights())
-                y_tot=y_tot+y_pr
-                indexer+=1
-        y_tot=y_tot/float(indexer)

+        if len(weights) >= 1:
+            new_weights = list()
+            for weights_list_tuple in zip(*weights):
+                new_weights.append([np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)])
+            new_weights = [np.array(x) for x in new_weights]
+            model_weight_averaged = tf.keras.models.clone_model(model)
+            model_weight_averaged.set_weights(new_weights)
-        #y_tot_end=np.argmax(y_tot,axis=1)
-        #print(f1_score(np.argmax(testY,axis=1), y_tot_end, average='macro'))
-        ##best_model.save('model_taza.h5')
+            model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
+            with open(os.path.join(os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
+                json.dump(_config, fp)  # encode dict into JSON
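For reference, the guarded block averages the surviving checkpoints layer by layer: `zip(*weights)` regroups the k-th array of every weight set, and the mean over each group becomes the corresponding array of the ensembled model. A self-contained toy sketch of the same mechanics (the two-layer model is illustrative, not the repository's classifier):

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for the trained classifier.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(2, activation='softmax'),
])

# Pretend three epochs cleared the F1 threshold; in train.py these come
# from model.get_weights() after each qualifying epoch.
weights = [model.get_weights() for _ in range(3)]

# Group the k-th variable of every weight set together and average it
# element-wise across the ensemble members.
new_weights = [np.mean(group, axis=0) for group in zip(*weights)]

# Clone the architecture, then load the averaged weights into the clone.
model_weight_averaged = tf.keras.models.clone_model(model)
model_weight_averaged.set_weights(new_weights)
```

Note that `np.mean` already returns arrays of the right shape per variable, which is why this sketch can skip the extra `[np.array(x) for x in new_weights]` pass seen in train.py.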