modifications

pull/18/head
vahidrezanezhad 8 months ago
parent 8d1050ec30
commit ce1108aca0

@ -1,25 +1,16 @@
#! /usr/bin/env python3
__version__= '1.0'
import argparse
import sys
import os
import numpy as np
import warnings
import xml.etree.ElementTree as et
import pandas as pd
from tqdm import tqdm
import csv
import cv2
import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers
import tensorflow.keras.losses
from tensorflow.keras.layers import *
from models import *
import click
import json
from tensorflow.python.keras import backend as tensorflow_backend
@ -37,70 +28,13 @@ __doc__=\
Tool to load model and predict for given image.
"""
projection_dim = 64
patch_size = 1
num_patches =28*28
class Patches(layers.Layer):
def __init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
print(tf.shape(images)[1],'images')
print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
print(patches.shape,patch_dims,'patch_dims')
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, **kwargs):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
class sbb_predict:
def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ):
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
self.image=image
self.patches=patches
self.save=save
self.model_dir=model
self.ground_truth=ground_truth
self.weights_dir=weights_dir
self.task=task
self.config_params_model=config_params_model
@ -426,16 +360,12 @@ class sbb_predict:
pass
else:
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
cv2.imwrite('./test.png',img_seg_overlayed)
##if self.save!=None:
##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2)
##cv2.imwrite(self.save,img)
if self.save:
cv2.imwrite(self.save,img_seg_overlayed)
###if self.ground_truth!=None:
###gt_img=cv2.imread(self.ground_truth)
###self.IoU(gt_img[:,:,0],res)
##plt.imshow(res)
##plt.show()
if self.ground_truth:
gt_img=cv2.imread(self.ground_truth)
self.IoU(gt_img[:,:,0],res[:,:,0])
@click.command()
@click.option(
@ -463,23 +393,19 @@ class sbb_predict:
required=True,
)
@click.option(
"--ground_truth/--no-ground_truth",
"-gt/-nogt",
is_flag=True,
"--ground_truth",
"-gt",
help="ground truth directory if you want to see the iou of prediction.",
)
@click.option(
"--model_weights/--no-model_weights",
"-mw/-nomw",
is_flag=True,
help="previous model weights which are saved.",
)
def main(image, model, patches, save, ground_truth, model_weights):
def main(image, model, patches, save, ground_truth):
with open(os.path.join(model,'config.json')) as f:
config_params_model = json.load(f)
task = 'classification'
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, model_weights)
task = config_params_model['task']
if task != 'classification':
if not save:
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
sys.exit(1)
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
x.run()
if __name__=="__main__":

Loading…
Cancel
Save