mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
integrating first working classification training model
This commit is contained in:
parent
d27647a0f1
commit
dbb84507ed
5 changed files with 419 additions and 158 deletions
|
@ -1,13 +1,15 @@
|
|||
{
|
||||
"model_name" : "hybrid_transformer_cnn",
|
||||
"model_name" : "resnet50_unet",
|
||||
"task": "classification",
|
||||
"n_classes" : 2,
|
||||
"n_epochs" : 2,
|
||||
"input_height" : 448,
|
||||
"input_width" : 448,
|
||||
"n_epochs" : 7,
|
||||
"input_height" : 224,
|
||||
"input_width" : 224,
|
||||
"weight_decay" : 1e-6,
|
||||
"n_batch" : 2,
|
||||
"n_batch" : 6,
|
||||
"learning_rate": 1e-4,
|
||||
"patches" : true,
|
||||
"f1_threshold_classification": 0.8,
|
||||
"patches" : false,
|
||||
"pretraining" : true,
|
||||
"augmentation" : false,
|
||||
"flip_aug" : false,
|
||||
|
@ -33,7 +35,7 @@
|
|||
"weighted_loss": false,
|
||||
"is_loss_soft_dice": false,
|
||||
"data_is_provided": false,
|
||||
"dir_train": "/train",
|
||||
"dir_eval": "/eval",
|
||||
"dir_output": "/out"
|
||||
"dir_train": "/home/vahid/Downloads/image_classification_data/train",
|
||||
"dir_eval": "/home/vahid/Downloads/image_classification_data/eval",
|
||||
"dir_output": "/home/vahid/Downloads/image_classification_data/output"
|
||||
}
|
||||
|
|
69
models.py
69
models.py
|
@ -400,7 +400,7 @@ def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_
|
|||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
model = keras.Model(inputs, x).load_weights(resnet50_Weights_path)
|
||||
model = Model(inputs, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
num_patches = x.shape[1]*x.shape[2]
|
||||
patches = Patches(patch_size)(x)
|
||||
|
@ -468,6 +468,71 @@ def vit_resnet50_unet(n_classes,patch_size, num_patches, input_height=224,input_
|
|||
o = (BatchNormalization(axis=bn_axis))(o)
|
||||
o = (Activation('softmax'))(o)
|
||||
|
||||
model = keras.Model(inputs=inputs, outputs=o)
|
||||
model = Model(inputs=inputs, outputs=o)
|
||||
|
||||
return model
|
||||
|
||||
def resnet50_classifier(n_classes,input_height=224,input_width=224,weight_decay=1e-6,pretraining=False):
|
||||
include_top=True
|
||||
assert input_height%32 == 0
|
||||
assert input_width%32 == 0
|
||||
|
||||
|
||||
img_input = Input(shape=(input_height,input_width , 3 ))
|
||||
|
||||
if IMAGE_ORDERING == 'channels_last':
|
||||
bn_axis = 3
|
||||
else:
|
||||
bn_axis = 1
|
||||
|
||||
x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
|
||||
x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, strides=(2, 2),kernel_regularizer=l2(weight_decay), name='conv1')(x)
|
||||
f1 = x
|
||||
|
||||
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
|
||||
x = Activation('relu')(x)
|
||||
x = MaxPooling2D((3, 3) , data_format=IMAGE_ORDERING , strides=(2, 2))(x)
|
||||
|
||||
|
||||
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
|
||||
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
|
||||
f2 = one_side_pad(x )
|
||||
|
||||
|
||||
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
|
||||
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
|
||||
f3 = x
|
||||
|
||||
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
|
||||
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
|
||||
f4 = x
|
||||
|
||||
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
|
||||
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
|
||||
f5 = x
|
||||
|
||||
if pretraining:
|
||||
Model(img_input, x).load_weights(resnet50_Weights_path)
|
||||
|
||||
x = AveragePooling2D((7, 7), name='avg_pool')(x)
|
||||
x = Flatten()(x)
|
||||
|
||||
##
|
||||
x = Dense(256, activation='relu', name='fc512')(x)
|
||||
x=Dropout(0.2)(x)
|
||||
##
|
||||
x = Dense(n_classes, activation='softmax', name='fc1000')(x)
|
||||
model = Model(img_input, x)
|
||||
|
||||
|
||||
|
||||
|
||||
return model
|
||||
|
|
|
@ -6,3 +6,4 @@ tqdm
|
|||
imutils
|
||||
numpy
|
||||
scipy
|
||||
scikit-learn
|
||||
|
|
374
train.py
374
train.py
|
@ -11,6 +11,7 @@ from metrics import *
|
|||
from tensorflow.keras.models import load_model
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
from sklearn.metrics import f1_score
|
||||
|
||||
|
||||
def configuration():
|
||||
|
@ -73,6 +74,8 @@ def config_params():
|
|||
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
||||
weighted_loss = False # Use weighted categorical cross entropy as loss fucntion. When set to true, "is_loss_soft_dice" must be false.
|
||||
data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".
|
||||
task = "segmentation" # This parameter defines task of model which can be segmentation, enhancement or classification.
|
||||
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
|
||||
|
||||
|
||||
@ex.automain
|
||||
|
@ -86,162 +89,239 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
|
||||
thetha, scaling_flip, continue_training, transformer_patchsize,
|
||||
num_patches_xy, model_name, flip_index, dir_eval, dir_output,
|
||||
pretraining, learning_rate):
|
||||
pretraining, learning_rate, task, f1_threshold_classification):
|
||||
|
||||
num_patches = num_patches_xy[0]*num_patches_xy[1]
|
||||
if data_is_provided:
|
||||
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||
|
||||
dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images')
|
||||
dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels')
|
||||
|
||||
dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images')
|
||||
dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels')
|
||||
|
||||
configuration()
|
||||
|
||||
else:
|
||||
dir_img, dir_seg = get_dirs_or_files(dir_train)
|
||||
dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)
|
||||
|
||||
# make first a directory in output for both training and evaluations in order to flow data from these directories.
|
||||
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||
|
||||
dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images/')
|
||||
dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels/')
|
||||
|
||||
dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images/')
|
||||
dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels/')
|
||||
|
||||
if os.path.isdir(dir_train_flowing):
|
||||
os.system('rm -rf ' + dir_train_flowing)
|
||||
os.makedirs(dir_train_flowing)
|
||||
else:
|
||||
os.makedirs(dir_train_flowing)
|
||||
|
||||
if os.path.isdir(dir_eval_flowing):
|
||||
os.system('rm -rf ' + dir_eval_flowing)
|
||||
os.makedirs(dir_eval_flowing)
|
||||
else:
|
||||
os.makedirs(dir_eval_flowing)
|
||||
|
||||
os.mkdir(dir_flow_train_imgs)
|
||||
os.mkdir(dir_flow_train_labels)
|
||||
|
||||
os.mkdir(dir_flow_eval_imgs)
|
||||
os.mkdir(dir_flow_eval_labels)
|
||||
|
||||
# set the gpu configuration
|
||||
configuration()
|
||||
if task == "segmentation":
|
||||
|
||||
imgs_list=np.array(os.listdir(dir_img))
|
||||
segs_list=np.array(os.listdir(dir_seg))
|
||||
|
||||
imgs_list_test=np.array(os.listdir(dir_img_val))
|
||||
segs_list_test=np.array(os.listdir(dir_seg_val))
|
||||
|
||||
# writing patches into a sub-folder in order to be flowed from directory.
|
||||
provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs,
|
||||
dir_flow_train_labels, input_height, input_width, blur_k,
|
||||
blur_aug, padding_white, padding_black, flip_aug, binarization,
|
||||
scaling, degrading, brightening, scales, degrade_scales, brightness,
|
||||
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||
rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation,
|
||||
patches=patches)
|
||||
|
||||
provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,
|
||||
dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width,
|
||||
blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization,
|
||||
scaling, degrading, brightening, scales, degrade_scales, brightness,
|
||||
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||
rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches)
|
||||
|
||||
if weighted_loss:
|
||||
weights = np.zeros(n_classes)
|
||||
num_patches = num_patches_xy[0]*num_patches_xy[1]
|
||||
if data_is_provided:
|
||||
for obj in os.listdir(dir_flow_train_labels):
|
||||
try:
|
||||
label_obj = cv2.imread(dir_flow_train_labels + '/' + obj)
|
||||
label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes)
|
||||
weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0)
|
||||
except:
|
||||
pass
|
||||
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||
|
||||
dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images')
|
||||
dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels')
|
||||
|
||||
dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images')
|
||||
dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels')
|
||||
|
||||
configuration()
|
||||
|
||||
else:
|
||||
dir_img, dir_seg = get_dirs_or_files(dir_train)
|
||||
dir_img_val, dir_seg_val = get_dirs_or_files(dir_eval)
|
||||
|
||||
for obj in os.listdir(dir_seg):
|
||||
try:
|
||||
label_obj = cv2.imread(dir_seg + '/' + obj)
|
||||
label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes)
|
||||
weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0)
|
||||
except:
|
||||
pass
|
||||
# make first a directory in output for both training and evaluations in order to flow data from these directories.
|
||||
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||
|
||||
weights = 1.00 / weights
|
||||
dir_flow_train_imgs = os.path.join(dir_train_flowing, 'images/')
|
||||
dir_flow_train_labels = os.path.join(dir_train_flowing, 'labels/')
|
||||
|
||||
weights = weights / float(np.sum(weights))
|
||||
weights = weights / float(np.min(weights))
|
||||
weights = weights / float(np.sum(weights))
|
||||
dir_flow_eval_imgs = os.path.join(dir_eval_flowing, 'images/')
|
||||
dir_flow_eval_labels = os.path.join(dir_eval_flowing, 'labels/')
|
||||
|
||||
if continue_training:
|
||||
if model_name=='resnet50_unet':
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model , compile=True)
|
||||
elif model_name=='hybrid_transformer_cnn':
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
else:
|
||||
index_start = 0
|
||||
if model_name=='resnet50_unet':
|
||||
model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)
|
||||
elif model_name=='hybrid_transformer_cnn':
|
||||
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining)
|
||||
|
||||
#if you want to see the model structure just uncomment model summary.
|
||||
#model.summary()
|
||||
|
||||
if os.path.isdir(dir_train_flowing):
|
||||
os.system('rm -rf ' + dir_train_flowing)
|
||||
os.makedirs(dir_train_flowing)
|
||||
else:
|
||||
os.makedirs(dir_train_flowing)
|
||||
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
if os.path.isdir(dir_eval_flowing):
|
||||
os.system('rm -rf ' + dir_eval_flowing)
|
||||
os.makedirs(dir_eval_flowing)
|
||||
else:
|
||||
os.makedirs(dir_eval_flowing)
|
||||
|
||||
os.mkdir(dir_flow_train_imgs)
|
||||
os.mkdir(dir_flow_train_labels)
|
||||
|
||||
os.mkdir(dir_flow_eval_imgs)
|
||||
os.mkdir(dir_flow_eval_labels)
|
||||
|
||||
# set the gpu configuration
|
||||
configuration()
|
||||
|
||||
imgs_list=np.array(os.listdir(dir_img))
|
||||
segs_list=np.array(os.listdir(dir_seg))
|
||||
|
||||
imgs_list_test=np.array(os.listdir(dir_img_val))
|
||||
segs_list_test=np.array(os.listdir(dir_seg_val))
|
||||
|
||||
# writing patches into a sub-folder in order to be flowed from directory.
|
||||
provide_patches(imgs_list, segs_list, dir_img, dir_seg, dir_flow_train_imgs,
|
||||
dir_flow_train_labels, input_height, input_width, blur_k,
|
||||
blur_aug, padding_white, padding_black, flip_aug, binarization,
|
||||
scaling, degrading, brightening, scales, degrade_scales, brightness,
|
||||
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||
rotation, rotation_not_90, thetha, scaling_flip, augmentation=augmentation,
|
||||
patches=patches)
|
||||
|
||||
provide_patches(imgs_list_test, segs_list_test, dir_img_val, dir_seg_val,
|
||||
dir_flow_eval_imgs, dir_flow_eval_labels, input_height, input_width,
|
||||
blur_k, blur_aug, padding_white, padding_black, flip_aug, binarization,
|
||||
scaling, degrading, brightening, scales, degrade_scales, brightness,
|
||||
flip_index, scaling_bluring, scaling_brightness, scaling_binarization,
|
||||
rotation, rotation_not_90, thetha, scaling_flip, augmentation=False, patches=patches)
|
||||
|
||||
if weighted_loss:
|
||||
weights = np.zeros(n_classes)
|
||||
if data_is_provided:
|
||||
for obj in os.listdir(dir_flow_train_labels):
|
||||
try:
|
||||
label_obj = cv2.imread(dir_flow_train_labels + '/' + obj)
|
||||
label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes)
|
||||
weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
|
||||
for obj in os.listdir(dir_seg):
|
||||
try:
|
||||
label_obj = cv2.imread(dir_seg + '/' + obj)
|
||||
label_obj_one_hot = get_one_hot(label_obj, label_obj.shape[0], label_obj.shape[1], n_classes)
|
||||
weights += (label_obj_one_hot.sum(axis=0)).sum(axis=0)
|
||||
except:
|
||||
pass
|
||||
|
||||
weights = 1.00 / weights
|
||||
|
||||
weights = weights / float(np.sum(weights))
|
||||
weights = weights / float(np.min(weights))
|
||||
weights = weights / float(np.sum(weights))
|
||||
|
||||
if continue_training:
|
||||
if model_name=='resnet50_unet':
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model , compile=True)
|
||||
elif model_name=='hybrid_transformer_cnn':
|
||||
if is_loss_soft_dice:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
|
||||
if weighted_loss:
|
||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
else:
|
||||
index_start = 0
|
||||
if model_name=='resnet50_unet':
|
||||
model = resnet50_unet(n_classes, input_height, input_width,weight_decay,pretraining)
|
||||
elif model_name=='hybrid_transformer_cnn':
|
||||
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width,weight_decay,pretraining)
|
||||
|
||||
#if you want to see the model structure just uncomment model summary.
|
||||
#model.summary()
|
||||
|
||||
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
if is_loss_soft_dice:
|
||||
model.compile(loss=soft_dice_loss,
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
if weighted_loss:
|
||||
model.compile(loss=weighted_categorical_crossentropy(weights),
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
|
||||
# generating train and evaluation data
|
||||
train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
|
||||
input_height=input_height, input_width=input_width, n_classes=n_classes)
|
||||
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
|
||||
input_height=input_height, input_width=input_width, n_classes=n_classes)
|
||||
|
||||
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
|
||||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
model.fit_generator(
|
||||
train_gen,
|
||||
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
|
||||
validation_data=val_gen,
|
||||
validation_steps=1,
|
||||
epochs=1)
|
||||
model.save(dir_output+'/'+'model_'+str(i))
|
||||
|
||||
with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
|
||||
#os.system('rm -rf '+dir_train_flowing)
|
||||
#os.system('rm -rf '+dir_eval_flowing)
|
||||
|
||||
#model.save(dir_output+'/'+'model'+'.h5')
|
||||
elif task=='classification':
|
||||
configuration()
|
||||
model = resnet50_classifier(n_classes, input_height, input_width,weight_decay,pretraining)
|
||||
|
||||
opt_adam = Adam(learning_rate=0.001)
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
if is_loss_soft_dice:
|
||||
model.compile(loss=soft_dice_loss,
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
if weighted_loss:
|
||||
model.compile(loss=weighted_categorical_crossentropy(weights),
|
||||
optimizer=Adam(lr=learning_rate), metrics=['accuracy'])
|
||||
|
||||
# generating train and evaluation data
|
||||
train_gen = data_gen(dir_flow_train_imgs, dir_flow_train_labels, batch_size=n_batch,
|
||||
input_height=input_height, input_width=input_width, n_classes=n_classes)
|
||||
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
|
||||
input_height=input_height, input_width=input_width, n_classes=n_classes)
|
||||
|
||||
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
|
||||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
model.fit_generator(
|
||||
train_gen,
|
||||
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
|
||||
validation_data=val_gen,
|
||||
validation_steps=1,
|
||||
epochs=1)
|
||||
model.save(dir_output+'/'+'model_'+str(i))
|
||||
|
||||
with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
optimizer = opt_adam,metrics=['accuracy'])
|
||||
|
||||
#os.system('rm -rf '+dir_train_flowing)
|
||||
#os.system('rm -rf '+dir_eval_flowing)
|
||||
|
||||
#model.save(dir_output+'/'+'model'+'.h5')
|
||||
testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes)
|
||||
|
||||
#print(testY.shape, testY)
|
||||
|
||||
y_tot=np.zeros((testX.shape[0],n_classes))
|
||||
indexer=0
|
||||
|
||||
score_best=[]
|
||||
score_best.append(0)
|
||||
|
||||
num_rows = return_number_of_total_training_data(dir_train)
|
||||
|
||||
weights=[]
|
||||
|
||||
for i in range(n_epochs):
|
||||
#history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights)
|
||||
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)
|
||||
|
||||
y_pr_class = []
|
||||
for jj in range(testY.shape[0]):
|
||||
y_pr=model.predict(testX[jj,:,:,:].reshape(1,input_height,input_width,3), verbose=0)
|
||||
y_pr_ind= np.argmax(y_pr,axis=1)
|
||||
#print(y_pr_ind, 'y_pr_ind')
|
||||
y_pr_class.append(y_pr_ind)
|
||||
|
||||
|
||||
y_pr_class = np.array(y_pr_class)
|
||||
#model.save('./models_save/model_'+str(i)+'.h5')
|
||||
#y_pr_class=np.argmax(y_pr,axis=1)
|
||||
f1score=f1_score(np.argmax(testY,axis=1), y_pr_class, average='macro')
|
||||
|
||||
print(i,f1score)
|
||||
|
||||
if f1score>score_best[0]:
|
||||
score_best[0]=f1score
|
||||
model.save(os.path.join(dir_output,'model_best'))
|
||||
|
||||
|
||||
##best_model=keras.models.clone_model(model)
|
||||
##best_model.build()
|
||||
##best_model.set_weights(model.get_weights())
|
||||
if f1score > f1_threshold_classification:
|
||||
weights.append(model.get_weights() )
|
||||
y_tot=y_tot+y_pr
|
||||
|
||||
indexer+=1
|
||||
y_tot=y_tot/float(indexer)
|
||||
|
||||
|
||||
new_weights=list()
|
||||
|
||||
for weights_list_tuple in zip(*weights):
|
||||
new_weights.append( [np.array(weights_).mean(axis=0) for weights_ in zip(*weights_list_tuple)] )
|
||||
|
||||
new_weights = [np.array(x) for x in new_weights]
|
||||
|
||||
model_weight_averaged=tf.keras.models.clone_model(model)
|
||||
|
||||
model_weight_averaged.set_weights(new_weights)
|
||||
|
||||
#y_tot_end=np.argmax(y_tot,axis=1)
|
||||
#print(f1_score(np.argmax(testY,axis=1), y_tot_end, average='macro'))
|
||||
|
||||
##best_model.save('model_taza.h5')
|
||||
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
||||
|
||||
|
|
113
utils.py
113
utils.py
|
@ -8,6 +8,119 @@ import random
|
|||
from tqdm import tqdm
|
||||
import imutils
|
||||
import math
|
||||
from tensorflow.keras.utils import to_categorical
|
||||
|
||||
|
||||
def return_number_of_total_training_data(path_classes):
|
||||
sub_classes = os.listdir(path_classes)
|
||||
n_tot = 0
|
||||
for sub_c in sub_classes:
|
||||
sub_files = os.listdir(os.path.join(path_classes,sub_c))
|
||||
n_tot = n_tot + len(sub_files)
|
||||
return n_tot
|
||||
|
||||
|
||||
|
||||
def generate_data_from_folder_evaluation(path_classes, height, width, n_classes):
|
||||
sub_classes = os.listdir(path_classes)
|
||||
#n_classes = len(sub_classes)
|
||||
all_imgs = []
|
||||
labels = []
|
||||
dicts =dict()
|
||||
indexer= 0
|
||||
for sub_c in sub_classes:
|
||||
sub_files = os.listdir(os.path.join(path_classes,sub_c ))
|
||||
sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files]
|
||||
#print( os.listdir(os.path.join(path_classes,sub_c )) )
|
||||
all_imgs = all_imgs + sub_files
|
||||
sub_labels = list( np.zeros( len(sub_files) ) +indexer )
|
||||
|
||||
#print( len(sub_labels) )
|
||||
labels = labels + sub_labels
|
||||
dicts[sub_c] = indexer
|
||||
indexer +=1
|
||||
|
||||
|
||||
categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ]
|
||||
ret_x= np.zeros((len(labels), height,width, 3)).astype(np.int16)
|
||||
ret_y= np.zeros((len(labels), n_classes)).astype(np.int16)
|
||||
|
||||
#print(all_imgs)
|
||||
for i in range(len(all_imgs)):
|
||||
row = all_imgs[i]
|
||||
#####img = cv2.imread(row, 0)
|
||||
#####img= resize_image (img, height, width)
|
||||
#####img = img.astype(np.uint16)
|
||||
#####ret_x[i, :,:,0] = img[:,:]
|
||||
#####ret_x[i, :,:,1] = img[:,:]
|
||||
#####ret_x[i, :,:,2] = img[:,:]
|
||||
|
||||
img = cv2.imread(row)
|
||||
img= resize_image (img, height, width)
|
||||
img = img.astype(np.uint16)
|
||||
ret_x[i, :,:] = img[:,:,:]
|
||||
|
||||
ret_y[i, :] = categories[ int( labels[i] ) ][:]
|
||||
|
||||
return ret_x/255., ret_y
|
||||
|
||||
def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes):
|
||||
sub_classes = os.listdir(path_classes)
|
||||
n_classes = len(sub_classes)
|
||||
|
||||
all_imgs = []
|
||||
labels = []
|
||||
dicts =dict()
|
||||
indexer= 0
|
||||
for sub_c in sub_classes:
|
||||
sub_files = os.listdir(os.path.join(path_classes,sub_c ))
|
||||
sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files]
|
||||
#print( os.listdir(os.path.join(path_classes,sub_c )) )
|
||||
all_imgs = all_imgs + sub_files
|
||||
sub_labels = list( np.zeros( len(sub_files) ) +indexer )
|
||||
|
||||
#print( len(sub_labels) )
|
||||
labels = labels + sub_labels
|
||||
dicts[sub_c] = indexer
|
||||
indexer +=1
|
||||
|
||||
ids = np.array(range(len(labels)))
|
||||
random.shuffle(ids)
|
||||
|
||||
shuffled_labels = np.array(labels)[ids]
|
||||
shuffled_files = np.array(all_imgs)[ids]
|
||||
categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ]
|
||||
ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16)
|
||||
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
|
||||
batchcount = 0
|
||||
while True:
|
||||
for i in range(len(shuffled_files)):
|
||||
row = shuffled_files[i]
|
||||
#print(row)
|
||||
###img = cv2.imread(row, 0)
|
||||
###img= resize_image (img, height, width)
|
||||
###img = img.astype(np.uint16)
|
||||
###ret_x[batchcount, :,:,0] = img[:,:]
|
||||
###ret_x[batchcount, :,:,1] = img[:,:]
|
||||
###ret_x[batchcount, :,:,2] = img[:,:]
|
||||
|
||||
img = cv2.imread(row)
|
||||
img= resize_image (img, height, width)
|
||||
img = img.astype(np.uint16)
|
||||
ret_x[batchcount, :,:,:] = img[:,:,:]
|
||||
|
||||
#print(int(shuffled_labels[i]) )
|
||||
#print( categories[int(shuffled_labels[i])] )
|
||||
ret_y[batchcount, :] = categories[ int( shuffled_labels[i] ) ][:]
|
||||
|
||||
batchcount+=1
|
||||
|
||||
if batchcount>=batchsize:
|
||||
ret_x = ret_x/255.
|
||||
yield (ret_x, ret_y)
|
||||
ret_x= np.zeros((batchsize, height,width, 3)).astype(np.int16)
|
||||
ret_y= np.zeros((batchsize, n_classes)).astype(np.int16)
|
||||
batchcount = 0
|
||||
|
||||
def do_brightening(img_in_dir, factor):
|
||||
im = Image.open(img_in_dir)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue