modifications

pull/18/head
vahidrezanezhad 8 months ago
parent 8d1050ec30
commit ce1108aca0

@ -1,25 +1,16 @@
#! /usr/bin/env python3
__version__= '1.0'
import argparse
import sys import sys
import os import os
import numpy as np import numpy as np
import warnings import warnings
import xml.etree.ElementTree as et
import pandas as pd
from tqdm import tqdm
import csv
import cv2 import cv2
import seaborn as sns import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import backend as K from tensorflow.keras import backend as K
from tensorflow.keras import layers from tensorflow.keras import layers
import tensorflow.keras.losses import tensorflow.keras.losses
from tensorflow.keras.layers import * from tensorflow.keras.layers import *
from models import *
import click import click
import json import json
from tensorflow.python.keras import backend as tensorflow_backend from tensorflow.python.keras import backend as tensorflow_backend
@ -37,70 +28,13 @@ __doc__=\
Tool to load model and predict for given image. Tool to load model and predict for given image.
""" """
projection_dim = 64
patch_size = 1
num_patches =28*28
class Patches(layers.Layer):
def __init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
print(tf.shape(images)[1],'images')
print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
print(patches.shape,patch_dims,'patch_dims')
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, **kwargs):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
class sbb_predict: class sbb_predict:
def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ): def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
self.image=image self.image=image
self.patches=patches self.patches=patches
self.save=save self.save=save
self.model_dir=model self.model_dir=model
self.ground_truth=ground_truth self.ground_truth=ground_truth
self.weights_dir=weights_dir
self.task=task self.task=task
self.config_params_model=config_params_model self.config_params_model=config_params_model
@ -426,16 +360,12 @@ class sbb_predict:
pass pass
else: else:
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task) img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
cv2.imwrite('./test.png',img_seg_overlayed) if self.save:
##if self.save!=None: cv2.imwrite(self.save,img_seg_overlayed)
##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2)
##cv2.imwrite(self.save,img)
###if self.ground_truth!=None: if self.ground_truth:
###gt_img=cv2.imread(self.ground_truth) gt_img=cv2.imread(self.ground_truth)
###self.IoU(gt_img[:,:,0],res) self.IoU(gt_img[:,:,0],res[:,:,0])
##plt.imshow(res)
##plt.show()
@click.command() @click.command()
@click.option( @click.option(
@ -463,23 +393,19 @@ class sbb_predict:
required=True, required=True,
) )
@click.option( @click.option(
"--ground_truth/--no-ground_truth", "--ground_truth",
"-gt/-nogt", "-gt",
is_flag=True,
help="ground truth directory if you want to see the iou of prediction.", help="ground truth directory if you want to see the iou of prediction.",
) )
@click.option( def main(image, model, patches, save, ground_truth):
"--model_weights/--no-model_weights",
"-mw/-nomw",
is_flag=True,
help="previous model weights which are saved.",
)
def main(image, model, patches, save, ground_truth, model_weights):
with open(os.path.join(model,'config.json')) as f: with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f) config_params_model = json.load(f)
task = 'classification' task = config_params_model['task']
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, model_weights) if task != 'classification':
if not save:
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
sys.exit(1)
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
x.run() x.run()
if __name__=="__main__": if __name__=="__main__":

Loading…
Cancel
Save