mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
inference script is added
This commit is contained in:
parent
38db3e9289
commit
8d1050ec30
4 changed files with 537 additions and 42 deletions
|
@ -1,12 +1,12 @@
|
||||||
{
|
{
|
||||||
"model_name" : "resnet50_unet",
|
"backbone_type" : "nontransformer",
|
||||||
"task": "enhancement",
|
"task": "classification",
|
||||||
"n_classes" : 3,
|
"n_classes" : 2,
|
||||||
"n_epochs" : 3,
|
"n_epochs" : 20,
|
||||||
"input_height" : 448,
|
"input_height" : 448,
|
||||||
"input_width" : 448,
|
"input_width" : 448,
|
||||||
"weight_decay" : 1e-6,
|
"weight_decay" : 1e-6,
|
||||||
"n_batch" : 3,
|
"n_batch" : 6,
|
||||||
"learning_rate": 1e-4,
|
"learning_rate": 1e-4,
|
||||||
"f1_threshold_classification": 0.8,
|
"f1_threshold_classification": 0.8,
|
||||||
"patches" : true,
|
"patches" : true,
|
||||||
|
@ -21,7 +21,7 @@
|
||||||
"scaling_flip" : false,
|
"scaling_flip" : false,
|
||||||
"rotation": false,
|
"rotation": false,
|
||||||
"rotation_not_90": false,
|
"rotation_not_90": false,
|
||||||
"num_patches_xy": [28, 28],
|
"transformer_num_patches_xy": [28, 28],
|
||||||
"transformer_patchsize": 1,
|
"transformer_patchsize": 1,
|
||||||
"blur_k" : ["blur","guass","median"],
|
"blur_k" : ["blur","guass","median"],
|
||||||
"scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
|
"scales" : [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.4],
|
||||||
|
@ -29,13 +29,14 @@
|
||||||
"degrade_scales" : [0.2, 0.4],
|
"degrade_scales" : [0.2, 0.4],
|
||||||
"flip_index" : [0, 1, -1],
|
"flip_index" : [0, 1, -1],
|
||||||
"thetha" : [10, -10],
|
"thetha" : [10, -10],
|
||||||
|
"classification_classes_name" : {"0":"apple", "1":"orange"},
|
||||||
"continue_training": false,
|
"continue_training": false,
|
||||||
"index_start" : 0,
|
"index_start" : 0,
|
||||||
"dir_of_start_model" : " ",
|
"dir_of_start_model" : " ",
|
||||||
"weighted_loss": false,
|
"weighted_loss": false,
|
||||||
"is_loss_soft_dice": false,
|
"is_loss_soft_dice": false,
|
||||||
"data_is_provided": false,
|
"data_is_provided": false,
|
||||||
"dir_train": "./training_data_sample_enhancement",
|
"dir_train": "./train",
|
||||||
"dir_eval": "./eval",
|
"dir_eval": "./eval",
|
||||||
"dir_output": "./out"
|
"dir_output": "./output"
|
||||||
}
|
}
|
||||||
|
|
490
inference.py
Normal file
490
inference.py
Normal file
|
@ -0,0 +1,490 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
|
__version__= '1.0'
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import warnings
|
||||||
|
import xml.etree.ElementTree as et
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
import csv
|
||||||
|
import cv2
|
||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras import backend as K
|
||||||
|
from tensorflow.keras import layers
|
||||||
|
import tensorflow.keras.losses
|
||||||
|
from tensorflow.keras.layers import *
|
||||||
|
import click
|
||||||
|
import json
|
||||||
|
from tensorflow.python.keras import backend as tensorflow_backend
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
|
||||||
|
__doc__=\
|
||||||
|
"""
|
||||||
|
Tool to load model and predict for given image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
projection_dim = 64
|
||||||
|
patch_size = 1
|
||||||
|
num_patches =28*28
|
||||||
|
class Patches(layers.Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(Patches, self).__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
def call(self, images):
|
||||||
|
print(tf.shape(images)[1],'images')
|
||||||
|
print(self.patch_size,'self.patch_size')
|
||||||
|
batch_size = tf.shape(images)[0]
|
||||||
|
patches = tf.image.extract_patches(
|
||||||
|
images=images,
|
||||||
|
sizes=[1, self.patch_size, self.patch_size, 1],
|
||||||
|
strides=[1, self.patch_size, self.patch_size, 1],
|
||||||
|
rates=[1, 1, 1, 1],
|
||||||
|
padding="VALID",
|
||||||
|
)
|
||||||
|
patch_dims = patches.shape[-1]
|
||||||
|
print(patches.shape,patch_dims,'patch_dims')
|
||||||
|
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||||
|
return patches
|
||||||
|
def get_config(self):
|
||||||
|
|
||||||
|
config = super().get_config().copy()
|
||||||
|
config.update({
|
||||||
|
'patch_size': self.patch_size,
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEncoder(layers.Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(PatchEncoder, self).__init__()
|
||||||
|
self.num_patches = num_patches
|
||||||
|
self.projection = layers.Dense(units=projection_dim)
|
||||||
|
self.position_embedding = layers.Embedding(
|
||||||
|
input_dim=num_patches, output_dim=projection_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
def call(self, patch):
|
||||||
|
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||||
|
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||||
|
return encoded
|
||||||
|
def get_config(self):
|
||||||
|
|
||||||
|
config = super().get_config().copy()
|
||||||
|
config.update({
|
||||||
|
'num_patches': self.num_patches,
|
||||||
|
'projection': self.projection,
|
||||||
|
'position_embedding': self.position_embedding,
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class sbb_predict:
|
||||||
|
def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ):
|
||||||
|
self.image=image
|
||||||
|
self.patches=patches
|
||||||
|
self.save=save
|
||||||
|
self.model_dir=model
|
||||||
|
self.ground_truth=ground_truth
|
||||||
|
self.weights_dir=weights_dir
|
||||||
|
self.task=task
|
||||||
|
self.config_params_model=config_params_model
|
||||||
|
|
||||||
|
def resize_image(self,img_in,input_height,input_width):
|
||||||
|
return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
|
||||||
|
|
||||||
|
|
||||||
|
def color_images(self,seg):
|
||||||
|
ann_u=range(self.n_classes)
|
||||||
|
if len(np.shape(seg))==3:
|
||||||
|
seg=seg[:,:,0]
|
||||||
|
|
||||||
|
seg_img=np.zeros((np.shape(seg)[0],np.shape(seg)[1],3)).astype(np.uint8)
|
||||||
|
colors=sns.color_palette("hls", self.n_classes)
|
||||||
|
|
||||||
|
for c in ann_u:
|
||||||
|
c=int(c)
|
||||||
|
segl=(seg==c)
|
||||||
|
seg_img[:,:,0][seg==c]=c
|
||||||
|
seg_img[:,:,1][seg==c]=c
|
||||||
|
seg_img[:,:,2][seg==c]=c
|
||||||
|
return seg_img
|
||||||
|
|
||||||
|
def otsu_copy_binary(self,img):
|
||||||
|
img_r=np.zeros((img.shape[0],img.shape[1],3))
|
||||||
|
img1=img[:,:,0]
|
||||||
|
|
||||||
|
#print(img.min())
|
||||||
|
#print(img[:,:,0].min())
|
||||||
|
#blur = cv2.GaussianBlur(img,(5,5))
|
||||||
|
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
img_r[:,:,0]=threshold1
|
||||||
|
img_r[:,:,1]=threshold1
|
||||||
|
img_r[:,:,2]=threshold1
|
||||||
|
#img_r=img_r/float(np.max(img_r))*255
|
||||||
|
return img_r
|
||||||
|
|
||||||
|
def otsu_copy(self,img):
|
||||||
|
img_r=np.zeros((img.shape[0],img.shape[1],3))
|
||||||
|
#img1=img[:,:,0]
|
||||||
|
|
||||||
|
#print(img.min())
|
||||||
|
#print(img[:,:,0].min())
|
||||||
|
#blur = cv2.GaussianBlur(img,(5,5))
|
||||||
|
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
_, threshold1 = cv2.threshold(img[:,:,0], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
_, threshold2 = cv2.threshold(img[:,:,1], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
_, threshold3 = cv2.threshold(img[:,:,2], 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
img_r[:,:,0]=threshold1
|
||||||
|
img_r[:,:,1]=threshold2
|
||||||
|
img_r[:,:,2]=threshold3
|
||||||
|
###img_r=img_r/float(np.max(img_r))*255
|
||||||
|
return img_r
|
||||||
|
|
||||||
|
def soft_dice_loss(self,y_true, y_pred, epsilon=1e-6):
|
||||||
|
|
||||||
|
axes = tuple(range(1, len(y_pred.shape)-1))
|
||||||
|
|
||||||
|
numerator = 2. * K.sum(y_pred * y_true, axes)
|
||||||
|
|
||||||
|
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
|
||||||
|
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
|
||||||
|
|
||||||
|
def weighted_categorical_crossentropy(self,weights=None):
|
||||||
|
|
||||||
|
def loss(y_true, y_pred):
|
||||||
|
labels_floats = tf.cast(y_true, tf.float32)
|
||||||
|
per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
|
||||||
|
|
||||||
|
if weights is not None:
|
||||||
|
weight_mask = tf.maximum(tf.reduce_max(tf.constant(
|
||||||
|
np.array(weights, dtype=np.float32)[None, None, None])
|
||||||
|
* labels_floats, axis=-1), 1.0)
|
||||||
|
per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
|
||||||
|
return tf.reduce_mean(per_pixel_loss)
|
||||||
|
return self.loss
|
||||||
|
|
||||||
|
|
||||||
|
def IoU(self,Yi,y_predi):
|
||||||
|
## mean Intersection over Union
|
||||||
|
## Mean IoU = TP/(FN + TP + FP)
|
||||||
|
|
||||||
|
IoUs = []
|
||||||
|
Nclass = np.unique(Yi)
|
||||||
|
for c in Nclass:
|
||||||
|
TP = np.sum( (Yi == c)&(y_predi==c) )
|
||||||
|
FP = np.sum( (Yi != c)&(y_predi==c) )
|
||||||
|
FN = np.sum( (Yi == c)&(y_predi != c))
|
||||||
|
IoU = TP/float(TP + FP + FN)
|
||||||
|
if self.n_classes>2:
|
||||||
|
print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,IoU))
|
||||||
|
IoUs.append(IoU)
|
||||||
|
if self.n_classes>2:
|
||||||
|
mIoU = np.mean(IoUs)
|
||||||
|
print("_________________")
|
||||||
|
print("Mean IoU: {:4.3f}".format(mIoU))
|
||||||
|
return mIoU
|
||||||
|
elif self.n_classes==2:
|
||||||
|
mIoU = IoUs[1]
|
||||||
|
print("_________________")
|
||||||
|
print("IoU: {:4.3f}".format(mIoU))
|
||||||
|
return mIoU
|
||||||
|
|
||||||
|
def start_new_session_and_model(self):
|
||||||
|
|
||||||
|
config = tf.compat.v1.ConfigProto()
|
||||||
|
config.gpu_options.allow_growth = True
|
||||||
|
|
||||||
|
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||||
|
tensorflow_backend.set_session(session)
|
||||||
|
#tensorflow.keras.layers.custom_layer = PatchEncoder
|
||||||
|
#tensorflow.keras.layers.custom_layer = Patches
|
||||||
|
self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||||
|
#config = tf.ConfigProto()
|
||||||
|
#config.gpu_options.allow_growth=True
|
||||||
|
|
||||||
|
#self.session = tf.InteractiveSession()
|
||||||
|
#keras.losses.custom_loss = self.weighted_categorical_crossentropy
|
||||||
|
#self.model = load_model(self.model_dir , compile=False)
|
||||||
|
|
||||||
|
|
||||||
|
##if self.weights_dir!=None:
|
||||||
|
##self.model.load_weights(self.weights_dir)
|
||||||
|
|
||||||
|
if self.task != 'classification':
|
||||||
|
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
|
||||||
|
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
|
||||||
|
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
|
||||||
|
|
||||||
|
def visualize_model_output(self, prediction, img, task):
|
||||||
|
if task == "binarization":
|
||||||
|
prediction = prediction * -1
|
||||||
|
prediction = prediction + 1
|
||||||
|
added_image = prediction * 255
|
||||||
|
else:
|
||||||
|
unique_classes = np.unique(prediction[:,:,0])
|
||||||
|
rgb_colors = {'0' : [255, 255, 255],
|
||||||
|
'1' : [255, 0, 0],
|
||||||
|
'2' : [255, 125, 0],
|
||||||
|
'3' : [255, 0, 125],
|
||||||
|
'4' : [125, 125, 125],
|
||||||
|
'5' : [125, 125, 0],
|
||||||
|
'6' : [0, 125, 255],
|
||||||
|
'7' : [0, 125, 0],
|
||||||
|
'8' : [125, 125, 125],
|
||||||
|
'9' : [0, 125, 255],
|
||||||
|
'10' : [125, 0, 125],
|
||||||
|
'11' : [0, 255, 0],
|
||||||
|
'12' : [0, 0, 255],
|
||||||
|
'13' : [0, 255, 255],
|
||||||
|
'14' : [255, 125, 125],
|
||||||
|
'15' : [255, 0, 255]}
|
||||||
|
|
||||||
|
output = np.zeros(prediction.shape)
|
||||||
|
|
||||||
|
for unq_class in unique_classes:
|
||||||
|
rgb_class_unique = rgb_colors[str(int(unq_class))]
|
||||||
|
output[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0]
|
||||||
|
output[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1]
|
||||||
|
output[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
img = self.resize_image(img, output.shape[0], output.shape[1])
|
||||||
|
|
||||||
|
output = output.astype(np.int32)
|
||||||
|
img = img.astype(np.int32)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
added_image = cv2.addWeighted(img,0.5,output,0.1,0)
|
||||||
|
|
||||||
|
return added_image
|
||||||
|
|
||||||
|
def predict(self):
|
||||||
|
self.start_new_session_and_model()
|
||||||
|
if self.task == 'classification':
|
||||||
|
classes_names = self.config_params_model['classification_classes_name']
|
||||||
|
img_1ch = img=cv2.imread(self.image, 0)
|
||||||
|
|
||||||
|
img_1ch = img_1ch / 255.0
|
||||||
|
img_1ch = cv2.resize(img_1ch, (self.config_params_model['input_height'], self.config_params_model['input_width']), interpolation=cv2.INTER_NEAREST)
|
||||||
|
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
|
||||||
|
img_in[0, :, :, 0] = img_1ch[:, :]
|
||||||
|
img_in[0, :, :, 1] = img_1ch[:, :]
|
||||||
|
img_in[0, :, :, 2] = img_1ch[:, :]
|
||||||
|
|
||||||
|
label_p_pred = self.model.predict(img_in, verbose=0)
|
||||||
|
index_class = np.argmax(label_p_pred[0])
|
||||||
|
|
||||||
|
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
|
||||||
|
else:
|
||||||
|
if self.patches:
|
||||||
|
#def textline_contours(img,input_width,input_height,n_classes,model):
|
||||||
|
|
||||||
|
img=cv2.imread(self.image)
|
||||||
|
self.img_org = np.copy(img)
|
||||||
|
|
||||||
|
if img.shape[0] < self.img_height:
|
||||||
|
img = cv2.resize(img, (img.shape[1], self.img_width), interpolation=cv2.INTER_NEAREST)
|
||||||
|
|
||||||
|
if img.shape[1] < self.img_width:
|
||||||
|
img = cv2.resize(img, (self.img_height, img.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||||
|
margin = int(0 * self.img_width)
|
||||||
|
width_mid = self.img_width - 2 * margin
|
||||||
|
height_mid = self.img_height - 2 * margin
|
||||||
|
img = img / float(255.0)
|
||||||
|
|
||||||
|
img_h = img.shape[0]
|
||||||
|
img_w = img.shape[1]
|
||||||
|
|
||||||
|
prediction_true = np.zeros((img_h, img_w, 3))
|
||||||
|
nxf = img_w / float(width_mid)
|
||||||
|
nyf = img_h / float(height_mid)
|
||||||
|
|
||||||
|
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
|
||||||
|
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
|
||||||
|
|
||||||
|
for i in range(nxf):
|
||||||
|
for j in range(nyf):
|
||||||
|
if i == 0:
|
||||||
|
index_x_d = i * width_mid
|
||||||
|
index_x_u = index_x_d + self.img_width
|
||||||
|
else:
|
||||||
|
index_x_d = i * width_mid
|
||||||
|
index_x_u = index_x_d + self.img_width
|
||||||
|
if j == 0:
|
||||||
|
index_y_d = j * height_mid
|
||||||
|
index_y_u = index_y_d + self.img_height
|
||||||
|
else:
|
||||||
|
index_y_d = j * height_mid
|
||||||
|
index_y_u = index_y_d + self.img_height
|
||||||
|
|
||||||
|
if index_x_u > img_w:
|
||||||
|
index_x_u = img_w
|
||||||
|
index_x_d = img_w - self.img_width
|
||||||
|
if index_y_u > img_h:
|
||||||
|
index_y_u = img_h
|
||||||
|
index_y_d = img_h - self.img_height
|
||||||
|
|
||||||
|
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||||
|
label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
|
||||||
|
verbose=0)
|
||||||
|
|
||||||
|
if self.task == 'enhancement':
|
||||||
|
seg = label_p_pred[0, :, :, :]
|
||||||
|
seg = seg * 255
|
||||||
|
elif self.task == 'segmentation' or self.task == 'binarization':
|
||||||
|
seg = np.argmax(label_p_pred, axis=3)[0]
|
||||||
|
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||||
|
|
||||||
|
|
||||||
|
if i == 0 and j == 0:
|
||||||
|
seg = seg[0 : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg
|
||||||
|
elif i == nxf - 1 and j == nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - 0]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0, :] = seg
|
||||||
|
elif i == 0 and j == nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - 0, 0 : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin, :] = seg
|
||||||
|
elif i == nxf - 1 and j == 0:
|
||||||
|
seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - 0]
|
||||||
|
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg
|
||||||
|
elif i == 0 and j != 0 and j != nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg
|
||||||
|
elif i == nxf - 1 and j != 0 and j != nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - 0]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg
|
||||||
|
elif i != 0 and i != nxf - 1 and j == 0:
|
||||||
|
seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg
|
||||||
|
elif i != 0 and i != nxf - 1 and j == nyf - 1:
|
||||||
|
seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin, :] = seg
|
||||||
|
else:
|
||||||
|
seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - margin]
|
||||||
|
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg
|
||||||
|
prediction_true = prediction_true.astype(int)
|
||||||
|
prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||||
|
return prediction_true
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
img=cv2.imread(self.image)
|
||||||
|
self.img_org = np.copy(img)
|
||||||
|
|
||||||
|
width=self.img_width
|
||||||
|
height=self.img_height
|
||||||
|
|
||||||
|
img=img/255.0
|
||||||
|
img=self.resize_image(img,self.img_height,self.img_width)
|
||||||
|
|
||||||
|
|
||||||
|
label_p_pred=self.model.predict(
|
||||||
|
img.reshape(1,img.shape[0],img.shape[1],img.shape[2]))
|
||||||
|
|
||||||
|
if self.task == 'enhancement':
|
||||||
|
seg = label_p_pred[0, :, :, :]
|
||||||
|
seg = seg * 255
|
||||||
|
elif self.task == 'segmentation' or self.task == 'binarization':
|
||||||
|
seg = np.argmax(label_p_pred, axis=3)[0]
|
||||||
|
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||||
|
|
||||||
|
prediction_true = seg.astype(int)
|
||||||
|
|
||||||
|
prediction_true = cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||||
|
return prediction_true
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
res=self.predict()
|
||||||
|
if self.task == 'classification':
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
|
cv2.imwrite('./test.png',img_seg_overlayed)
|
||||||
|
##if self.save!=None:
|
||||||
|
##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2)
|
||||||
|
##cv2.imwrite(self.save,img)
|
||||||
|
|
||||||
|
###if self.ground_truth!=None:
|
||||||
|
###gt_img=cv2.imread(self.ground_truth)
|
||||||
|
###self.IoU(gt_img[:,:,0],res)
|
||||||
|
##plt.imshow(res)
|
||||||
|
##plt.show()
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option(
|
||||||
|
"--image",
|
||||||
|
"-i",
|
||||||
|
help="image filename",
|
||||||
|
type=click.Path(exists=True, dir_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--patches/--no-patches",
|
||||||
|
"-p/-nop",
|
||||||
|
is_flag=True,
|
||||||
|
help="if this parameter set to true, this tool will try to do inference in patches.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--save",
|
||||||
|
"-s",
|
||||||
|
help="save prediction as a png file in current folder.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--model",
|
||||||
|
"-m",
|
||||||
|
help="directory of models",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--ground_truth/--no-ground_truth",
|
||||||
|
"-gt/-nogt",
|
||||||
|
is_flag=True,
|
||||||
|
help="ground truth directory if you want to see the iou of prediction.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--model_weights/--no-model_weights",
|
||||||
|
"-mw/-nomw",
|
||||||
|
is_flag=True,
|
||||||
|
help="previous model weights which are saved.",
|
||||||
|
)
|
||||||
|
def main(image, model, patches, save, ground_truth, model_weights):
|
||||||
|
|
||||||
|
with open(os.path.join(model,'config.json')) as f:
|
||||||
|
config_params_model = json.load(f)
|
||||||
|
task = 'classification'
|
||||||
|
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, model_weights)
|
||||||
|
x.run()
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
42
train.py
42
train.py
|
@ -69,7 +69,7 @@ def config_params():
|
||||||
flip_index = None # Flip image for augmentation.
|
flip_index = None # Flip image for augmentation.
|
||||||
continue_training = False # Set to true if you would like to continue training an already trained a model.
|
continue_training = False # Set to true if you would like to continue training an already trained a model.
|
||||||
transformer_patchsize = None # Patch size of vision transformer patches.
|
transformer_patchsize = None # Patch size of vision transformer patches.
|
||||||
num_patches_xy = None # Number of patches for vision transformer.
|
transformer_num_patches_xy = None # Number of patches for vision transformer.
|
||||||
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
|
index_start = 0 # Index of model to continue training from. E.g. if you trained for 3 epochs and last index is 2, to continue from model_1.h5, set "index_start" to 3 to start naming model with index 3.
|
||||||
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
|
dir_of_start_model = '' # Directory containing pretrained encoder to continue training the model.
|
||||||
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
is_loss_soft_dice = False # Use soft dice as loss function. When set to true, "weighted_loss" must be false.
|
||||||
|
@ -77,6 +77,8 @@ def config_params():
|
||||||
data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".
|
data_is_provided = False # Only set this to true when you have already provided the input data and the train and eval data are in "dir_output".
|
||||||
task = "segmentation" # This parameter defines task of model which can be segmentation, enhancement or classification.
|
task = "segmentation" # This parameter defines task of model which can be segmentation, enhancement or classification.
|
||||||
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
|
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
|
||||||
|
classification_classes_name = None # Dictionary of classification classes names.
|
||||||
|
backbone_type = None # As backbone we have 2 types of backbones. A vision transformer alongside a CNN and we call it "transformer" and only CNN called "nontransformer"
|
||||||
|
|
||||||
|
|
||||||
@ex.automain
|
@ex.automain
|
||||||
|
@ -89,12 +91,12 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
brightness, dir_train, data_is_provided, scaling_bluring,
|
brightness, dir_train, data_is_provided, scaling_bluring,
|
||||||
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
|
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
|
||||||
thetha, scaling_flip, continue_training, transformer_patchsize,
|
thetha, scaling_flip, continue_training, transformer_patchsize,
|
||||||
num_patches_xy, model_name, flip_index, dir_eval, dir_output,
|
transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
|
||||||
pretraining, learning_rate, task, f1_threshold_classification):
|
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name):
|
||||||
|
|
||||||
if task == "segmentation" or "enhancement":
|
if task == "segmentation" or task == "enhancement":
|
||||||
|
|
||||||
num_patches = num_patches_xy[0]*num_patches_xy[1]
|
num_patches = transformer_num_patches_xy[0]*transformer_num_patches_xy[1]
|
||||||
if data_is_provided:
|
if data_is_provided:
|
||||||
dir_train_flowing = os.path.join(dir_output, 'train')
|
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||||
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
dir_eval_flowing = os.path.join(dir_output, 'eval')
|
||||||
|
@ -191,14 +193,14 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
weights = weights / float(np.sum(weights))
|
weights = weights / float(np.sum(weights))
|
||||||
|
|
||||||
if continue_training:
|
if continue_training:
|
||||||
if model_name=='resnet50_unet':
|
if backbone_type=='nontransformer':
|
||||||
if is_loss_soft_dice and task == "segmentation":
|
if is_loss_soft_dice and task == "segmentation":
|
||||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
model = load_model(dir_of_start_model, compile=True, custom_objects={'soft_dice_loss': soft_dice_loss})
|
||||||
if weighted_loss and task == "segmentation":
|
if weighted_loss and task == "segmentation":
|
||||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||||
if not is_loss_soft_dice and not weighted_loss:
|
if not is_loss_soft_dice and not weighted_loss:
|
||||||
model = load_model(dir_of_start_model , compile=True)
|
model = load_model(dir_of_start_model , compile=True)
|
||||||
elif model_name=='hybrid_transformer_cnn':
|
elif backbone_type=='transformer':
|
||||||
if is_loss_soft_dice and task == "segmentation":
|
if is_loss_soft_dice and task == "segmentation":
|
||||||
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
|
model = load_model(dir_of_start_model, compile=True, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches,'soft_dice_loss': soft_dice_loss})
|
||||||
if weighted_loss and task == "segmentation":
|
if weighted_loss and task == "segmentation":
|
||||||
|
@ -207,9 +209,9 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
if model_name=='resnet50_unet':
|
if backbone_type=='nontransformer':
|
||||||
model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
|
model = resnet50_unet(n_classes, input_height, input_width, task, weight_decay, pretraining)
|
||||||
elif model_name=='hybrid_transformer_cnn':
|
elif backbone_type=='nontransformer':
|
||||||
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
|
model = vit_resnet50_unet(n_classes, transformer_patchsize, num_patches, input_height, input_width, task, weight_decay, pretraining)
|
||||||
|
|
||||||
#if you want to see the model structure just uncomment model summary.
|
#if you want to see the model structure just uncomment model summary.
|
||||||
|
@ -246,9 +248,9 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
validation_data=val_gen,
|
validation_data=val_gen,
|
||||||
validation_steps=1,
|
validation_steps=1,
|
||||||
epochs=1)
|
epochs=1)
|
||||||
model.save(dir_output+'/'+'model_'+str(i))
|
model.save(os.path.join(dir_output,'model_'+str(i)))
|
||||||
|
|
||||||
with open(dir_output+'/'+'model_'+str(i)+'/'+"config.json", "w") as fp:
|
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||||
json.dump(_config, fp) # encode dict into JSON
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
#os.system('rm -rf '+dir_train_flowing)
|
#os.system('rm -rf '+dir_train_flowing)
|
||||||
|
@ -257,14 +259,15 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
#model.save(dir_output+'/'+'model'+'.h5')
|
#model.save(dir_output+'/'+'model'+'.h5')
|
||||||
elif task=='classification':
|
elif task=='classification':
|
||||||
configuration()
|
configuration()
|
||||||
model = resnet50_classifier(n_classes, input_height, input_width,weight_decay,pretraining)
|
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
|
||||||
|
|
||||||
opt_adam = Adam(learning_rate=0.001)
|
opt_adam = Adam(learning_rate=0.001)
|
||||||
model.compile(loss='categorical_crossentropy',
|
model.compile(loss='categorical_crossentropy',
|
||||||
optimizer = opt_adam,metrics=['accuracy'])
|
optimizer = opt_adam,metrics=['accuracy'])
|
||||||
|
|
||||||
|
|
||||||
testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes)
|
list_classes = list(classification_classes_name.values())
|
||||||
|
testX, testY = generate_data_from_folder_evaluation(dir_eval, input_height, input_width, n_classes, list_classes)
|
||||||
|
|
||||||
#print(testY.shape, testY)
|
#print(testY.shape, testY)
|
||||||
|
|
||||||
|
@ -280,7 +283,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
|
|
||||||
for i in range(n_epochs):
|
for i in range(n_epochs):
|
||||||
#history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights)
|
#history = model.fit(trainX, trainY, epochs=1, batch_size=n_batch, validation_data=(testX, testY), verbose=2)#,class_weight=weights)
|
||||||
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)
|
history = model.fit( generate_data_from_folder_training(dir_train, n_batch , input_height, input_width, n_classes, list_classes), steps_per_epoch=num_rows / n_batch, verbose=0)#,class_weight=weights)
|
||||||
|
|
||||||
y_pr_class = []
|
y_pr_class = []
|
||||||
for jj in range(testY.shape[0]):
|
for jj in range(testY.shape[0]):
|
||||||
|
@ -301,10 +304,6 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
score_best[0]=f1score
|
score_best[0]=f1score
|
||||||
model.save(os.path.join(dir_output,'model_best'))
|
model.save(os.path.join(dir_output,'model_best'))
|
||||||
|
|
||||||
|
|
||||||
##best_model=keras.models.clone_model(model)
|
|
||||||
##best_model.build()
|
|
||||||
##best_model.set_weights(model.get_weights())
|
|
||||||
if f1score > f1_threshold_classification:
|
if f1score > f1_threshold_classification:
|
||||||
weights.append(model.get_weights() )
|
weights.append(model.get_weights() )
|
||||||
y_tot=y_tot+y_pr
|
y_tot=y_tot+y_pr
|
||||||
|
@ -329,4 +328,9 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
|
|
||||||
##best_model.save('model_taza.h5')
|
##best_model.save('model_taza.h5')
|
||||||
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
model_weight_averaged.save(os.path.join(dir_output,'model_ens_avg'))
|
||||||
|
with open(os.path.join( os.path.join(dir_output,'model_ens_avg'), "config.json"), "w") as fp:
|
||||||
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
|
with open(os.path.join( os.path.join(dir_output,'model_best'), "config.json"), "w") as fp:
|
||||||
|
json.dump(_config, fp) # encode dict into JSON
|
||||||
|
|
||||||
|
|
30
utils.py
30
utils.py
|
@ -21,14 +21,14 @@ def return_number_of_total_training_data(path_classes):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_data_from_folder_evaluation(path_classes, height, width, n_classes):
|
def generate_data_from_folder_evaluation(path_classes, height, width, n_classes, list_classes):
|
||||||
sub_classes = os.listdir(path_classes)
|
#sub_classes = os.listdir(path_classes)
|
||||||
#n_classes = len(sub_classes)
|
#n_classes = len(sub_classes)
|
||||||
all_imgs = []
|
all_imgs = []
|
||||||
labels = []
|
labels = []
|
||||||
dicts =dict()
|
#dicts =dict()
|
||||||
indexer= 0
|
#indexer= 0
|
||||||
for sub_c in sub_classes:
|
for indexer, sub_c in enumerate(list_classes):
|
||||||
sub_files = os.listdir(os.path.join(path_classes,sub_c ))
|
sub_files = os.listdir(os.path.join(path_classes,sub_c ))
|
||||||
sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files]
|
sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files]
|
||||||
#print( os.listdir(os.path.join(path_classes,sub_c )) )
|
#print( os.listdir(os.path.join(path_classes,sub_c )) )
|
||||||
|
@ -37,8 +37,8 @@ def generate_data_from_folder_evaluation(path_classes, height, width, n_classes)
|
||||||
|
|
||||||
#print( len(sub_labels) )
|
#print( len(sub_labels) )
|
||||||
labels = labels + sub_labels
|
labels = labels + sub_labels
|
||||||
dicts[sub_c] = indexer
|
#dicts[sub_c] = indexer
|
||||||
indexer +=1
|
#indexer +=1
|
||||||
|
|
||||||
|
|
||||||
categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ]
|
categories = to_categorical(range(n_classes)).astype(np.int16)#[ [1 , 0, 0 , 0 , 0 , 0] , [0 , 1, 0 , 0 , 0 , 0] , [0 , 0, 1 , 0 , 0 , 0] , [0 , 0, 0 , 1 , 0 , 0] , [0 , 0, 0 , 0 , 1 , 0] , [0 , 0, 0 , 0 , 0 , 1] ]
|
||||||
|
@ -64,15 +64,15 @@ def generate_data_from_folder_evaluation(path_classes, height, width, n_classes)
|
||||||
|
|
||||||
return ret_x/255., ret_y
|
return ret_x/255., ret_y
|
||||||
|
|
||||||
def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes):
|
def generate_data_from_folder_training(path_classes, batchsize, height, width, n_classes, list_classes):
|
||||||
sub_classes = os.listdir(path_classes)
|
#sub_classes = os.listdir(path_classes)
|
||||||
n_classes = len(sub_classes)
|
#n_classes = len(sub_classes)
|
||||||
|
|
||||||
all_imgs = []
|
all_imgs = []
|
||||||
labels = []
|
labels = []
|
||||||
dicts =dict()
|
#dicts =dict()
|
||||||
indexer= 0
|
#indexer= 0
|
||||||
for sub_c in sub_classes:
|
for indexer, sub_c in enumerate(list_classes):
|
||||||
sub_files = os.listdir(os.path.join(path_classes,sub_c ))
|
sub_files = os.listdir(os.path.join(path_classes,sub_c ))
|
||||||
sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files]
|
sub_files = [os.path.join(path_classes,sub_c )+'/' + x for x in sub_files]
|
||||||
#print( os.listdir(os.path.join(path_classes,sub_c )) )
|
#print( os.listdir(os.path.join(path_classes,sub_c )) )
|
||||||
|
@ -81,8 +81,8 @@ def generate_data_from_folder_training(path_classes, batchsize, height, width, n
|
||||||
|
|
||||||
#print( len(sub_labels) )
|
#print( len(sub_labels) )
|
||||||
labels = labels + sub_labels
|
labels = labels + sub_labels
|
||||||
dicts[sub_c] = indexer
|
#dicts[sub_c] = indexer
|
||||||
indexer +=1
|
#indexer +=1
|
||||||
|
|
||||||
ids = np.array(range(len(labels)))
|
ids = np.array(range(len(labels)))
|
||||||
random.shuffle(ids)
|
random.shuffle(ids)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue