Update train.py

avoid ensembling if no model weights met the threshold f1 score in the case of classification
pull/18/head
vahidrezanezhad 8 months ago committed by GitHub
parent ce1108aca0
commit a7e1f255f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -269,35 +269,25 @@ def run(_config, n_classes, n_epochs, input_height,
list_classes = list(classification_classes_name.values())
testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes)
#print(testY.shape, testY)
y_tot=np.zeros((testX.shape[0],n_classes))
indexer=0
score_best=[]
score_best.append(0)
num_rows = return_number_of_total_training_data(dir_train)
weights=[]
for i in range(n_epochs):
#history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights)
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=1)#,class_weight=weights)
y_pr_class = []
for jj in range(testY.shape[0]):
y_pr=model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), verbose=0)
y_pr_ind= np.argmax(y_pr,axis=1)
#print(y_pr_ind, 'y_pr_ind')
y_pr_class.append(y_pr_ind)
y_pr_class = np.array(y_pr_class)
#model.save('./models_save/model_'+str(i)+'.h5')
#y_pr_class=np.argmax(y_pr,axis=1)
f1score=f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro')
print(i,f1score)
if f1score>score_best[0]:
@ -306,30 +296,20 @@ def run(_config, n_classes, n_epochs, input_height,
if f1score > f1_threshold_classification:
weights.append(model.get_weights() )
y_tot=y_tot+y_pr
indexer+=1
y_tot=y_tot/float(indexer)
new_weights=list()
for weights_list_tuple in zip(*weights):
new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] )
if len(weights) >= 1:
new_weights=list()
for weights_list_tuple in zip(*weights):
new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] )
new_weights = [np.array(x) for x in new_weights]
new_weights = [np.array(x) for x in new_weights]
model_weight_averaged=tf.keras.models.clone_model(model)
model_weight_averaged.set_weights(new_weights)
model_weight_averaged=tf.keras.models.clone_model(model)
model_weight_averaged.set_weights(new_weights)
#y_tot_end=np.argmax(y_tot,axis=1)
#print(f1_score(np.argmax(testY,axis=1), y_tot_end, average='macro'))
##best_model.save('model_taza.h5')
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON

Loading…
Cancel
Save