|
|
|
@ -1,25 +1,16 @@
|
|
|
|
|
#! /usr/bin/env python3
|
|
|
|
|
|
|
|
|
|
__version__= '1.0'
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
import sys
|
|
|
|
|
import os
|
|
|
|
|
import numpy as np
|
|
|
|
|
import warnings
|
|
|
|
|
import xml.etree.ElementTree as et
|
|
|
|
|
import pandas as pd
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
import csv
|
|
|
|
|
import cv2
|
|
|
|
|
import seaborn as sns
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
from tensorflow.keras.models import load_model
|
|
|
|
|
import tensorflow as tf
|
|
|
|
|
from tensorflow.keras import backend as K
|
|
|
|
|
from tensorflow.keras import layers
|
|
|
|
|
import tensorflow.keras.losses
|
|
|
|
|
from tensorflow.keras.layers import *
|
|
|
|
|
from models import *
|
|
|
|
|
import click
|
|
|
|
|
import json
|
|
|
|
|
from tensorflow.python.keras import backend as tensorflow_backend
|
|
|
|
@ -37,70 +28,13 @@ __doc__=\
|
|
|
|
|
Tool to load model and predict for given image.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
projection_dim = 64
|
|
|
|
|
patch_size = 1
|
|
|
|
|
num_patches =28*28
|
|
|
|
|
class Patches(layers.Layer):
|
|
|
|
|
def __init__(self, **kwargs):
|
|
|
|
|
super(Patches, self).__init__()
|
|
|
|
|
self.patch_size = patch_size
|
|
|
|
|
|
|
|
|
|
def call(self, images):
|
|
|
|
|
print(tf.shape(images)[1],'images')
|
|
|
|
|
print(self.patch_size,'self.patch_size')
|
|
|
|
|
batch_size = tf.shape(images)[0]
|
|
|
|
|
patches = tf.image.extract_patches(
|
|
|
|
|
images=images,
|
|
|
|
|
sizes=[1, self.patch_size, self.patch_size, 1],
|
|
|
|
|
strides=[1, self.patch_size, self.patch_size, 1],
|
|
|
|
|
rates=[1, 1, 1, 1],
|
|
|
|
|
padding="VALID",
|
|
|
|
|
)
|
|
|
|
|
patch_dims = patches.shape[-1]
|
|
|
|
|
print(patches.shape,patch_dims,'patch_dims')
|
|
|
|
|
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
|
|
|
|
return patches
|
|
|
|
|
def get_config(self):
|
|
|
|
|
|
|
|
|
|
config = super().get_config().copy()
|
|
|
|
|
config.update({
|
|
|
|
|
'patch_size': self.patch_size,
|
|
|
|
|
})
|
|
|
|
|
return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PatchEncoder(layers.Layer):
|
|
|
|
|
def __init__(self, **kwargs):
|
|
|
|
|
super(PatchEncoder, self).__init__()
|
|
|
|
|
self.num_patches = num_patches
|
|
|
|
|
self.projection = layers.Dense(units=projection_dim)
|
|
|
|
|
self.position_embedding = layers.Embedding(
|
|
|
|
|
input_dim=num_patches, output_dim=projection_dim
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def call(self, patch):
|
|
|
|
|
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
|
|
|
|
encoded = self.projection(patch) + self.position_embedding(positions)
|
|
|
|
|
return encoded
|
|
|
|
|
def get_config(self):
|
|
|
|
|
|
|
|
|
|
config = super().get_config().copy()
|
|
|
|
|
config.update({
|
|
|
|
|
'num_patches': self.num_patches,
|
|
|
|
|
'projection': self.projection,
|
|
|
|
|
'position_embedding': self.position_embedding,
|
|
|
|
|
})
|
|
|
|
|
return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class sbb_predict:
|
|
|
|
|
def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ):
|
|
|
|
|
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
|
|
|
|
|
self.image=image
|
|
|
|
|
self.patches=patches
|
|
|
|
|
self.save=save
|
|
|
|
|
self.model_dir=model
|
|
|
|
|
self.ground_truth=ground_truth
|
|
|
|
|
self.weights_dir=weights_dir
|
|
|
|
|
self.task=task
|
|
|
|
|
self.config_params_model=config_params_model
|
|
|
|
|
|
|
|
|
@ -426,16 +360,12 @@ class sbb_predict:
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
|
|
|
|
|
cv2.imwrite('./test.png',img_seg_overlayed)
|
|
|
|
|
##if self.save!=None:
|
|
|
|
|
##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2)
|
|
|
|
|
##cv2.imwrite(self.save,img)
|
|
|
|
|
|
|
|
|
|
###if self.ground_truth!=None:
|
|
|
|
|
###gt_img=cv2.imread(self.ground_truth)
|
|
|
|
|
###self.IoU(gt_img[:,:,0],res)
|
|
|
|
|
##plt.imshow(res)
|
|
|
|
|
##plt.show()
|
|
|
|
|
if self.save:
|
|
|
|
|
cv2.imwrite(self.save,img_seg_overlayed)
|
|
|
|
|
|
|
|
|
|
if self.ground_truth:
|
|
|
|
|
gt_img=cv2.imread(self.ground_truth)
|
|
|
|
|
self.IoU(gt_img[:,:,0],res[:,:,0])
|
|
|
|
|
|
|
|
|
|
@click.command()
|
|
|
|
|
@click.option(
|
|
|
|
@ -463,23 +393,19 @@ class sbb_predict:
|
|
|
|
|
required=True,
|
|
|
|
|
)
|
|
|
|
|
@click.option(
|
|
|
|
|
"--ground_truth/--no-ground_truth",
|
|
|
|
|
"-gt/-nogt",
|
|
|
|
|
is_flag=True,
|
|
|
|
|
"--ground_truth",
|
|
|
|
|
"-gt",
|
|
|
|
|
help="ground truth directory if you want to see the iou of prediction.",
|
|
|
|
|
)
|
|
|
|
|
@click.option(
|
|
|
|
|
"--model_weights/--no-model_weights",
|
|
|
|
|
"-mw/-nomw",
|
|
|
|
|
is_flag=True,
|
|
|
|
|
help="previous model weights which are saved.",
|
|
|
|
|
)
|
|
|
|
|
def main(image, model, patches, save, ground_truth, model_weights):
|
|
|
|
|
|
|
|
|
|
def main(image, model, patches, save, ground_truth):
|
|
|
|
|
with open(os.path.join(model,'config.json')) as f:
|
|
|
|
|
config_params_model = json.load(f)
|
|
|
|
|
task = 'classification'
|
|
|
|
|
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, model_weights)
|
|
|
|
|
task = config_params_model['task']
|
|
|
|
|
if task != 'classification':
|
|
|
|
|
if not save:
|
|
|
|
|
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
|
|
|
|
|
sys.exit(1)
|
|
|
|
|
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
|
|
|
|
|
x.run()
|
|
|
|
|
|
|
|
|
|
if __name__=="__main__":
|
|
|
|
|