diff --git a/inference.py b/inference.py
index 6911bea..94e318d 100644
--- a/inference.py
+++ b/inference.py
@@ -1,25 +1,16 @@
-#! /usr/bin/env python3
-
-__version__= '1.0'
-
-import argparse
 import sys
 import os
 import numpy as np
 import warnings
-import xml.etree.ElementTree as et
-import pandas as pd
-from tqdm import tqdm
-import csv
 import cv2
 import seaborn as sns
-import matplotlib.pyplot as plt
 from tensorflow.keras.models import load_model
 import tensorflow as tf
 from tensorflow.keras import backend as K
 from tensorflow.keras import layers
 import tensorflow.keras.losses
 from tensorflow.keras.layers import *
+from models import *
 import click
 import json
 from tensorflow.python.keras import backend as tensorflow_backend
@@ -37,70 +28,13 @@ __doc__=\
 Tool to load model and predict for given image.
 """

-projection_dim = 64
-patch_size = 1
-num_patches =28*28
-class Patches(layers.Layer):
-    def __init__(self, **kwargs):
-        super(Patches, self).__init__()
-        self.patch_size = patch_size
-
-    def call(self, images):
-        print(tf.shape(images)[1],'images')
-        print(self.patch_size,'self.patch_size')
-        batch_size = tf.shape(images)[0]
-        patches = tf.image.extract_patches(
-            images=images,
-            sizes=[1, self.patch_size, self.patch_size, 1],
-            strides=[1, self.patch_size, self.patch_size, 1],
-            rates=[1, 1, 1, 1],
-            padding="VALID",
-        )
-        patch_dims = patches.shape[-1]
-        print(patches.shape,patch_dims,'patch_dims')
-        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
-        return patches
-    def get_config(self):
-
-        config = super().get_config().copy()
-        config.update({
-            'patch_size': self.patch_size,
-        })
-        return config
-
-
-class PatchEncoder(layers.Layer):
-    def __init__(self, **kwargs):
-        super(PatchEncoder, self).__init__()
-        self.num_patches = num_patches
-        self.projection = layers.Dense(units=projection_dim)
-        self.position_embedding = layers.Embedding(
-            input_dim=num_patches, output_dim=projection_dim
-        )
-
-    def call(self, patch):
-        positions = tf.range(start=0, limit=self.num_patches, delta=1)
-        encoded = self.projection(patch) + self.position_embedding(positions)
-        return encoded
-    def get_config(self):
-
-        config = super().get_config().copy()
-        config.update({
-            'num_patches': self.num_patches,
-            'projection': self.projection,
-            'position_embedding': self.position_embedding,
-        })
-        return config
-
-
 class sbb_predict:
-    def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ):
+    def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
         self.image=image
         self.patches=patches
         self.save=save
         self.model_dir=model
         self.ground_truth=ground_truth
-        self.weights_dir=weights_dir
         self.task=task
         self.config_params_model=config_params_model

@@ -426,16 +360,12 @@ class sbb_predict:
             pass
         else:
             img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
-            cv2.imwrite('./test.png',img_seg_overlayed)
-            ##if self.save!=None:
-                ##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2)
-                ##cv2.imwrite(self.save,img)
-
-            ###if self.ground_truth!=None:
-                ###gt_img=cv2.imread(self.ground_truth)
-                ###self.IoU(gt_img[:,:,0],res)
-            ##plt.imshow(res)
-            ##plt.show()
+            if self.save:
+                cv2.imwrite(self.save,img_seg_overlayed)
+
+            if self.ground_truth:
+                gt_img=cv2.imread(self.ground_truth)
+                self.IoU(gt_img[:,:,0],res[:,:,0])

 @click.command()
 @click.option(
@@ -463,23 +393,19 @@
     required=True,
 )
 @click.option(
-    "--ground_truth/--no-ground_truth",
-    "-gt/-nogt",
-    is_flag=True,
+    "--ground_truth",
+    "-gt",
     help="ground truth directory if you want to see the iou of prediction.",
 )
-@click.option(
-    "--model_weights/--no-model_weights",
-    "-mw/-nomw",
-    is_flag=True,
-    help="previous model weights which are saved.",
-)
-def main(image, model, patches, save, ground_truth, model_weights):
-
+def main(image, model, patches, save, ground_truth):
     with open(os.path.join(model,'config.json')) as f:
         config_params_model = json.load(f)
-    task = 'classification'
-    x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, model_weights)
+    task = config_params_model['task']
+    if task != 'classification':
+        if not save:
+            print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
+            sys.exit(1)
+    x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
     x.run()

 if __name__=="__main__":