From a7e1f255f3468d113ed748d42c338c0a1c7e3a1f Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 8 May 2024 14:47:16 +0200 Subject: [PATCH] Update train.py avoid ensembling if no model weights met the threshold f1 score in the case of classification --- train.py | 46 +++++++++++++--------------------------------- 1 file changed, 13 insertions(+), 33 deletions(-) diff --git a/train.py b/train.py index 28363d2..78974d3 100644 --- a/train.py +++ b/train.py @@ -268,36 +268,26 @@ def run(_config, n_classes, n_epochs, input_height, list_classes = list(classification_classes_name.values()) testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes) - - #print(testY.shape, testY) y_tot=np.zeros((testX.shape[0],n_classes)) - indexer=0 score_best=[] score_best.append(0) num_rows = return_number_of_total_training_data(dir_train) - weights=[] for i in range(n_epochs): - #history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights) - history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights) + history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=1)#,class_weight=weights) y_pr_class = [] for jj in range(testY.shape[0]): y_pr=model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), verbose=0) y_pr_ind= np.argmax(y_pr,axis=1) - #print(y_pr_ind, 'y_pr_ind') y_pr_class.append(y_pr_ind) - y_pr_class = np.array(y_pr_class) - #model.save('./models_save/model_'+str(i)+'.h5') - #y_pr_class=np.argmax(y_pr,axis=1) f1score=f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro') - print(i,f1score) if f1score>score_best[0]: @@ -306,30 +296,20 @@ def run(_config, n_classes, n_epochs, input_height, if f1score > f1_threshold_classification: weights.append(model.get_weights() ) - y_tot=y_tot+y_pr - indexer+=1 - y_tot=y_tot/float(indexer) - - new_weights=list() - - for weights_list_tuple in zip(*weights): - new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] ) - - new_weights = [np.array(x) for x in new_weights] - - model_weight_averaged=tf.keras.models.clone_model(model) - - model_weight_averaged.set_weights(new_weights) - - #y_tot_end=np.argmax(y_tot,axis=1) - #print(f1_score(np.argmax(testY,axis=1), y_tot_end, average='macro')) - - ##best_model.save('model_taza.h5') - model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg')) - with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp: - json.dump(_config, fp) # encode dict into JSON + if len(weights) >= 1: + new_weights=list() + for weights_list_tuple in zip(*weights): + new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] ) + + new_weights = [np.array(x) for x in new_weights] + model_weight_averaged=tf.keras.models.clone_model(model) + model_weight_averaged.set_weights(new_weights) + + model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg')) + with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp: + json.dump(_config, fp) # encode dict into JSON with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp: json.dump(_config, fp) # encode dict into JSON