mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
Update train.py
avoid ensembling if no model weights met the threshold f1 score in the case of classification
This commit is contained in:
parent
ce1108aca0
commit
a7e1f255f3
1 changed files with 13 additions and 33 deletions
46
train.py
46
train.py
|
@ -268,36 +268,26 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
|
|
||||||
list_classes = list(classification_classes_name.values())
|
list_classes = list(classification_classes_name.values())
|
||||||
testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes)
|
testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes)
|
||||||
|
|
||||||
#print(testY.shape, testY)
|
|
||||||
|
|
||||||
y_tot=np.zeros((testX.shape[0],n_classes))
|
y_tot=np.zeros((testX.shape[0],n_classes))
|
||||||
indexer=0
|
|
||||||
|
|
||||||
score_best=[]
|
score_best=[]
|
||||||
score_best.append(0)
|
score_best.append(0)
|
||||||
|
|
||||||
num_rows = return_number_of_total_training_data(dir_train)
|
num_rows = return_number_of_total_training_data(dir_train)
|
||||||
|
|
||||||
weights=[]
|
weights=[]
|
||||||
|
|
||||||
for i in range(n_epochs):
|
for i in range(n_epochs):
|
||||||
#history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights)
|
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=1)#,class_weight=weights)
|
||||||
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)
|
|
||||||
|
|
||||||
y_pr_class = []
|
y_pr_class = []
|
||||||
for jj in range(testY.shape[0]):
|
for jj in range(testY.shape[0]):
|
||||||
y_pr=model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), verbose=0)
|
y_pr=model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), verbose=0)
|
||||||
y_pr_ind= np.argmax(y_pr,axis=1)
|
y_pr_ind= np.argmax(y_pr,axis=1)
|
||||||
#print(y_pr_ind, 'y_pr_ind')
|
|
||||||
y_pr_class.append(y_pr_ind)
|
y_pr_class.append(y_pr_ind)
|
||||||
|
|
||||||
|
|
||||||
y_pr_class = np.array(y_pr_class)
|
y_pr_class = np.array(y_pr_class)
|
||||||
#model.save('./models_save/model_'+str(i)+'.h5')
|
|
||||||
#y_pr_class=np.argmax(y_pr,axis=1)
|
|
||||||
f1score=f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro')
|
f1score=f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro')
|
||||||
|
|
||||||
print(i,f1score)
|
print(i,f1score)
|
||||||
|
|
||||||
if f1score>score_best[0]:
|
if f1score>score_best[0]:
|
||||||
|
@ -306,30 +296,20 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
|
|
||||||
if f1score > f1_threshold_classification:
|
if f1score > f1_threshold_classification:
|
||||||
weights.append(model.get_weights() )
|
weights.append(model.get_weights() )
|
||||||
y_tot=y_tot+y_pr
|
|
||||||
|
|
||||||
indexer+=1
|
|
||||||
y_tot=y_tot/float(indexer)
|
|
||||||
|
|
||||||
|
if len(weights) >= 1:
|
||||||
new_weights=list()
|
new_weights=list()
|
||||||
|
for weights_list_tuple in zip(*weights):
|
||||||
for weights_list_tuple in zip(*weights):
|
new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] )
|
||||||
new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] )
|
|
||||||
|
new_weights = [np.array(x) for x in new_weights]
|
||||||
new_weights = [np.array(x) for x in new_weights]
|
model_weight_averaged=tf.keras.models.clone_model(model)
|
||||||
|
model_weight_averaged.set_weights(new_weights)
|
||||||
model_weight_averaged=tf.keras.models.clone_model(model)
|
|
||||||
|
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
||||||
model_weight_averaged.set_weights(new_weights)
|
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
|
||||||
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
#y_tot_end=np.argmax(y_tot,axis=1)
|
|
||||||
#print(f1_score(np.argmax(testY,axis=1), y_tot_end, average='macro'))
|
|
||||||
|
|
||||||
##best_model.save('model_taza.h5')
|
|
||||||
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
|
||||||
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
|
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
|
||||||
|
|
||||||
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
|
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue