inference script is added

unifying-training-models
vahidrezanezhad 2 weeks ago
parent 38db3e9289
commit 8d1050ec30

@ -1,12 +1,12 @@
{
"model_name" : "resnet50_unet",
"task": "enhancement",
"n_classes" : 3,
"n_epochs" : 3,
"backbone_type" : "nontransformer",
"task": "classification",
"n_classes" : 2,
"n_epochs" : 20,
"input_height" : 448,
"input_width" : 448,
"weight_decay" : 1e-6,
"n_batch" : 3,
"n_batch" : 6,
"learning_rate": 1e-4,
"f1_threshold_classification": 0.8,
"patches" : true,
@ -21,7 +21,7 @@
"scaling_flip" : false,
"rotation": false,
"rotation_not_90": false,
"num_patches_xy": [28, 28],
"transformer_num_patches_xy": [28, 28],
"transformer_patchsize": 1,
"blur_k" : ["blur","guass","median"],
"scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
@ -29,13 +29,14 @@
"degrade_scales" : [0.2, 0.4],
"flip_index" : [0, 1, -1],
"thetha" : [10, -10],
"classification_classes_name" : {"0":"apple", "1":"orange"},
"continue_training": false,
"index_start" : 0,
"dir_of_start_model" : " ",
"weighted_loss": false,
"is_loss_soft_dice": false,
"data_is_provided": false,
"dir_train": "./training_data_sample_enhancement",
"dir_train": "./train",
"dir_eval": "./eval",
"dir_output": "./out"
"dir_output": "./output"
}

@ -0,0 +1,490 @@
#! /usr/bin/env python3
__version__= '1.0'
import argparse
import sys
import os
import numpy as np
import warnings
import xml.etree.ElementTree as et
import pandas as pd
from tqdm import tqdm
import csv
import cv2
import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers
import tensorflow.keras.losses
from tensorflow.keras.layers import *
import click
import json
from tensorflow.python.keras import backend as tensorflow_backend
with warnings.catch_warnings():
warnings.simplefilter("ignore")
__doc__=\
"""
Tool to load model and predict for given image.
"""
projection_dim = 64
patch_size = 1
num_patches =28*28
class Patches(layers.Layer):
def __init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
print(tf.shape(images)[1],'images')
print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
print(patches.shape,patch_dims,'patch_dims')
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, **kwargs):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
class sbb_predict:
def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ):
self.image=image
self.patches=patches
self.save=save
self.model_dir=model
self.ground_truth=ground_truth
self.weights_dir=weights_dir
self.task=task
self.config_params_model=config_params_model
def resize_image(self,img_in,input_height,input_width):
return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
def color_images(self,seg):
ann_u=range(self.n_classes)
if len(np.shape(seg))==3:
seg=seg[:,:,0]
seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(np.uint8)
colors=sns.color_palette("hls", self.n_classes)
for c in ann_u:
c=int(c)
segl=(seg==c)
seg_img[:,:,0][seg==c]=c
seg_img[:,:,1][seg==c]=c
seg_img[:,:,2][seg==c]=c
return seg_img
def otsu_copy_binary(self,img):
img_r=np.zeros((img.shape[0],img.shape[1],3))
img1=img[:,:,0]
#print(img.min())
#print(img[:,:,0].min())
#blur = cv2.GaussianBlur(img,(5,5))
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
img_r[:,:,0]=threshold1
img_r[:,:,1]=threshold1
img_r[:,:,2]=threshold1
#img_r=img_r/float(np.max(img_r))*255
return img_r
def otsu_copy(self,img):
img_r=np.zeros((img.shape[0],img.shape[1],3))
#img1=img[:,:,0]
#print(img.min())
#print(img[:,:,0].min())
#blur = cv2.GaussianBlur(img,(5,5))
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
_, threshold1 = cv2.threshold(img[:,:,0], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
_, threshold2 = cv2.threshold(img[:,:,1], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
_, threshold3 = cv2.threshold(img[:,:,2], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
img_r[:,:,0]=threshold1
img_r[:,:,1]=threshold2
img_r[:,:,2]=threshold3
###img_r=img_r/float(np.max(img_r))*255
return img_r
def soft_dice_loss(self,y_true, y_pred, epsilon=1e-6):
axes = tuple(range(1, len(y_pred.shape)-1))
numerator = 2. * K.sum(y_pred * y_true, axes)
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
def weighted_categorical_crossentropy(self,weights=None):
def loss(y_true, y_pred):
labels_floats = tf.cast(y_true, tf.float32)
per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
if weights is not None:
weight_mask = tf.maximum(tf.reduce_max(tf.constant(
np.array(weights, dtype=np.float32)[None, None, None])
* labels_floats, axis=-1), 1.0)
per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
return tf.reduce_mean(per_pixel_loss)
return self.loss
def IoU(self,Yi,y_predi):
## mean Intersection over Union
## Mean IoU = TP/(FN + TP + FP)
IoUs = []
Nclass = np.unique(Yi)
for c in Nclass:
TP = np.sum( (Yi == c)&(y_predi==c) )
FP = np.sum( (Yi != c)&(y_predi==c) )
FN = np.sum( (Yi == c)&(y_predi != c))
IoU = TP/float(TP + FP + FN)
if self.n_classes>2:
print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,IoU))
IoUs.append(IoU)
if self.n_classes>2:
mIoU = np.mean(IoUs)
print("_________________")
print("Mean IoU: {:4.3f}".format(mIoU))
return mIoU
elif self.n_classes==2:
mIoU = IoUs[1]
print("_________________")
print("IoU: {:4.3f}".format(mIoU))
return mIoU
def start_new_session_and_model(self):
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
tensorflow_backend.set_session(session)
#tensorflow.keras.layers.custom_layer = PatchEncoder
#tensorflow.keras.layers.custom_layer = Patches
self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
#config = tf.ConfigProto()
#config.gpu_options.allow_growth=True
#self.session = tf.InteractiveSession()
#keras.losses.custom_loss = self.weighted_categorical_crossentropy
#self.model = load_model(self.model_dir , compile=False)
##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir)
if self.task != 'classification':
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
def visualize_model_output(self, prediction, img, task):
if task == "binarization":
prediction = prediction * -1
prediction = prediction + 1
added_image = prediction * 255
else:
unique_classes = np.unique(prediction[:,:,0])
rgb_colors = {'0' : [255, 255, 255],
'1' : [255, 0, 0],
'2' : [255, 125, 0],
'3' : [255, 0, 125],
'4' : [125, 125, 125],
'5' : [125, 125, 0],
'6' : [0, 125, 255],
'7' : [0, 125, 0],
'8' : [125, 125, 125],
'9' : [0, 125, 255],
'10' : [125, 0, 125],
'11' : [0, 255, 0],
'12' : [0, 0, 255],
'13' : [0, 255, 255],
'14' : [255, 125, 125],
'15' : [255, 0, 255]}
output = np.zeros(prediction.shape)
for unq_class in unique_classes:
rgb_class_unique = rgb_colors[str(int(unq_class))]
output[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0]
output[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1]
output[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2]
img = self.resize_image(img, output.shape[0], output.shape[1])
output = output.astype(np.int32)
img = img.astype(np.int32)
added_image = cv2.addWeighted(img,0.5,output,0.1,0)
return added_image
def predict(self):
self.start_new_session_and_model()
if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name']
img_1ch = img=cv2.imread(self.image, 0)
img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
img_in[0, :, :, 0] = img_1ch[:, :]
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model.predict(img_in, verbose=0)
index_class = np.argmax(label_p_pred[0])
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
else:
if self.patches:
#def textline_contours(img,input_width,input_height,n_classes,model):
img=cv2.imread(self.image)
self.img_org = np.copy(img)
if img.shape[0] < self.img_height:
img = cv2.resize(img, (img.shape[1], self.img_width), interpolation=cv2.INTER_NEAREST)
if img.shape[1] < self.img_width:
img = cv2.resize(img, (self.img_height, img.shape[0]), interpolation=cv2.INTER_NEAREST)
margin = int(0 * self.img_width)
width_mid = self.img_width - 2 * margin
height_mid = self.img_height - 2 * margin
img = img / float(255.0)
img_h = img.shape[0]
img_w = img.shape[1]
prediction_true = np.zeros((img_h, img_w, 3))
nxf = img_w / float(width_mid)
nyf = img_h / float(height_mid)
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
for i in range(nxf):
for j in range(nyf):
if i == 0:
index_x_d = i * width_mid
index_x_u = index_x_d + self.img_width
else:
index_x_d = i * width_mid
index_x_u = index_x_d + self.img_width
if j == 0:
index_y_d = j * height_mid
index_y_u = index_y_d + self.img_height
else:
index_y_d = j * height_mid
index_y_u = index_y_d + self.img_height
if index_x_u > img_w:
index_x_u = img_w
index_x_d = img_w - self.img_width
if index_y_u > img_h:
index_y_u = img_h
index_y_d = img_h - self.img_height
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
verbose=0)
if self.task == 'enhancement':
seg = label_p_pred[0, :, :, :]
seg = seg * 255
elif self.task == 'segmentation' or self.task == 'binarization':
seg = np.argmax(label_p_pred, axis=3)[0]
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
if i == 0 and j == 0:
seg = seg[0 : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg
elif i == nxf - 1 and j == nyf - 1:
seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - 0]
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0, :] = seg
elif i == 0 and j == nyf - 1:
seg = seg[margin : seg.shape[0] - 0, 0 : seg.shape[1] - margin]
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin, :] = seg
elif i == nxf - 1 and j == 0:
seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - 0]
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg
elif i == 0 and j != 0 and j != nyf - 1:
seg = seg[margin : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg
elif i == nxf - 1 and j != 0 and j != nyf - 1:
seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - 0]
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg
elif i != 0 and i != nxf - 1 and j == 0:
seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - margin]
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg
elif i != 0 and i != nxf - 1 and j == nyf - 1:
seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - margin]
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin, :] = seg
else:
seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - margin]
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg
prediction_true = prediction_true.astype(int)
prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST)
return prediction_true
else:
img=cv2.imread(self.image)
self.img_org = np.copy(img)
width=self.img_width
height=self.img_height
img=img/255.0
img=self.resize_image(img,self.img_height,self.img_width)
label_p_pred=self.model.predict(
img.reshape(1,img.shape[0],img.shape[1],img.shape[2]))
if self.task == 'enhancement':
seg = label_p_pred[0, :, :, :]
seg = seg * 255
elif self.task == 'segmentation' or self.task == 'binarization':
seg = np.argmax(label_p_pred, axis=3)[0]
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = seg.astype(int)
prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST)
return prediction_true
def run(self):
res=self.predict()
if self.task == 'classification':
pass
else:
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
cv2.imwrite('./test.png',img_seg_overlayed)
##if self.save!=None:
##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2)
##cv2.imwrite(self.save,img)
###if self.ground_truth!=None:
###gt_img=cv2.imread(self.ground_truth)
###self.IoU(gt_img[:,:,0],res)
##plt.imshow(res)
##plt.show()
@click.command()
@click.option(
"--image",
"-i",
help="image filename",
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"--patches/--no-patches",
"-p/-nop",
is_flag=True,
help="if this parameter set to true, this tool will try to do inference in patches.",
)
@click.option(
"--save",
"-s",
help="save prediction as a png file in current folder.",
)
@click.option(
"--model",
"-m",
help="directory of models",
type=click.Path(exists=True, file_okay=False),
required=True,
)
@click.option(
"--ground_truth/--no-ground_truth",
"-gt/-nogt",
is_flag=True,
help="ground truth directory if you want to see the iou of prediction.",
)
@click.option(
"--model_weights/--no-model_weights",
"-mw/-nomw",
is_flag=True,
help="previous model weights which are saved.",
)
def main(image, model, patches, save, ground_truth, model_weights):
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
task = 'classification'
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, model_weights)
x.run()
if __name__=="__main__":
main()

@ -69,7 +69,7 @@ def config_params():
flip_index = None # Flip image for augmentation.
continue_training = False # Set to true if you would like to continue training an already trained a model.
transformer_patchsize = None # Patch size of vision transformer patches.
num_patches_xy = None # Number of patches for vision transformer.
transformer_num_patches_xy = None # Number of patches for vision transformer.
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
@ -77,6 +77,8 @@ def config_params():
data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".
task = "segmentation" # This parameter defines task of model which can be segmentation, enhancement or classification.
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
classification_classes_name = None # Dictionary of classification classes names.
backbone_type = None # As backbone we have 2 types of backbones. A vision transformer alongside a CNN and we call it "transformer" and only CNN called "nontransformer"
@ex.automain
@ -89,12 +91,12 @@ def run(_config, n_classes, n_epochs, input_height,
brightness, dir_train, data_is_provided, scaling_bluring,
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
thetha, scaling_flip, continue_training, transformer_patchsize,
num_patches_xy, model_name, flip_index, dir_eval, dir_output,
pretraining, learning_rate, task, f1_threshold_classification):
transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
if task == "segmentation" or "enhancement":
if task == "segmentation" or task == "enhancement":
num_patches = num_patches_xy[0]*num_patches_xy[1]
num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
if data_is_provided:
dir_train_flowing = os.path.join(dir_output, 'train')
dir_eval_flowing = os.path.join(dir_output, 'eval')
@ -191,14 +193,14 @@ def run(_config, n_classes, n_epochs, input_height,
weights = weights / float(np.sum(weights))
if continue_training:
if model_name=='resnet50_unet':
if backbone_type=='nontransformer':
if is_loss_soft_dice and task == "segmentation":
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
if weighted_loss and task == "segmentation":
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss:
model = load_model(dir_of_start_model , compile=True)
elif model_name=='hybrid_transformer_cnn':
elif backbone_type=='transformer':
if is_loss_soft_dice and task == "segmentation":
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
if weighted_loss and task == "segmentation":
@ -207,9 +209,9 @@ def run(_config, n_classes, n_epochs, input_height,
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
else:
index_start = 0
if model_name=='resnet50_unet':
if backbone_type=='nontransformer':
model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
elif model_name=='hybrid_transformer_cnn':
elif backbone_type=='nontransformer':
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
#if you want to see the model structure just uncomment model summary.
@ -246,9 +248,9 @@ def run(_config, n_classes, n_epochs, input_height,
validation_data=val_gen,
validation_steps=1,
epochs=1)
model.save(dir_output+'/'+'model_'+str(i))
model.save(os.path.join(dir_output,'model_'+str(i)))
with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp:
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
#os.system('rm -rf '+dir_train_flowing)
@ -257,14 +259,15 @@ def run(_config, n_classes, n_epochs, input_height,
#model.save(dir_output+'/'+'model'+'.h5')
elif task=='classification':
configuration()
model = resnet50_classifier(n_classes, input_height, input_width,weight_decay,pretraining)
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
opt_adam = Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy',
optimizer = opt_adam,metrics=['accuracy'])
testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes)
list_classes = list(classification_classes_name.values())
testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes)
#print(testY.shape, testY)
@ -280,7 +283,7 @@ def run(_config, n_classes, n_epochs, input_height,
for i in range(n_epochs):
#history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights)
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)
y_pr_class = []
for jj in range(testY.shape[0]):
@ -301,10 +304,6 @@ def run(_config, n_classes, n_epochs, input_height,
score_best[0]=f1score
model.save(os.path.join(dir_output,'model_best'))
##best_model=keras.models.clone_model(model)
##best_model.build()
##best_model.set_weights(model.get_weights())
if f1score > f1_threshold_classification:
weights.append(model.get_weights() )
y_tot=y_tot+y_pr
@ -329,4 +328,9 @@ def run(_config, n_classes, n_epochs, input_height,
##best_model.save('model_taza.h5')
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON

@ -21,14 +21,14 @@ def return_number_of_total_training_data(path_classes):
def generate_data_from_folder_evaluation(path_classes, height, width, n_classes):
sub_classes = os.listdir(path_classes)
def generate_data_from_folder_evaluation(path_classes, height, width, n_classes, list_classes):
#sub_classes = os.listdir(path_classes)
#n_classes = len(sub_classes)
all_imgs = []
labels = []
dicts =dict()
indexer= 0
for sub_c in sub_classes:
#dicts =dict()
#indexer= 0
for indexer, sub_c in enumerate(list_classes):
sub_files = os.listdir(os.path.join(path_classes,sub_c ))
sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files]
#print( os.listdir(os.path.join(path_classes,sub_c )) )
@ -37,8 +37,8 @@ def generate_data_from_folder_evaluation(path_classes, height, width, n_classes)
#print( len(sub_labels) )
labels = labels + sub_labels
dicts[sub_c] = indexer
indexer +=1
#dicts[sub_c] = indexer
#indexer +=1
categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ]
@ -64,15 +64,15 @@ def generate_data_from_folder_evaluation(path_classes, height, width, n_classes)
return ret_x/255., ret_y
def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes):
sub_classes = os.listdir(path_classes)
n_classes = len(sub_classes)
def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes, list_classes):
#sub_classes = os.listdir(path_classes)
#n_classes = len(sub_classes)
all_imgs = []
labels = []
dicts =dict()
indexer= 0
for sub_c in sub_classes:
#dicts =dict()
#indexer= 0
for indexer, sub_c in enumerate(list_classes):
sub_files = os.listdir(os.path.join(path_classes,sub_c ))
sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files]
#print( os.listdir(os.path.join(path_classes,sub_c )) )
@ -81,8 +81,8 @@ def generate_data_from_folder_training(path_classes, batchsize, height, width, n
#print( len(sub_labels) )
labels = labels + sub_labels
dicts[sub_c] = indexer
indexer +=1
#dicts[sub_c] = indexer
#indexer +=1
ids = np.array(range(len(labels)))
random.shuffle(ids)

Loading…
Cancel
Save