mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
modifications
This commit is contained in:
parent
8d1050ec30
commit
ce1108aca0
1 changed files with 17 additions and 91 deletions
108
inference.py
108
inference.py
|
@ -1,25 +1,16 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
__version__= '1.0'
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
import warnings
|
||||
import xml.etree.ElementTree as et
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import csv
|
||||
import cv2
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from tensorflow.keras.models import load_model
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import backend as K
|
||||
from tensorflow.keras import layers
|
||||
import tensorflow.keras.losses
|
||||
from tensorflow.keras.layers import *
|
||||
from models import *
|
||||
import click
|
||||
import json
|
||||
from tensorflow.python.keras import backend as tensorflow_backend
|
||||
|
@ -37,70 +28,13 @@ __doc__=\
|
|||
Tool to load model and predict for given image.
|
||||
"""
|
||||
|
||||
projection_dim = 64
|
||||
patch_size = 1
|
||||
num_patches =28*28
|
||||
class Patches(layers.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(Patches, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
def call(self, images):
|
||||
print(tf.shape(images)[1],'images')
|
||||
print(self.patch_size,'self.patch_size')
|
||||
batch_size = tf.shape(images)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=images,
|
||||
sizes=[1, self.patch_size, self.patch_size, 1],
|
||||
strides=[1, self.patch_size, self.patch_size, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
)
|
||||
patch_dims = patches.shape[-1]
|
||||
print(patches.shape,patch_dims,'patch_dims')
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'patch_size': self.patch_size,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
class PatchEncoder(layers.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(PatchEncoder, self).__init__()
|
||||
self.num_patches = num_patches
|
||||
self.projection = layers.Dense(units=projection_dim)
|
||||
self.position_embedding = layers.Embedding(
|
||||
input_dim=num_patches, output_dim=projection_dim
|
||||
)
|
||||
|
||||
def call(self, patch):
|
||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||
return encoded
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'num_patches': self.num_patches,
|
||||
'projection': self.projection,
|
||||
'position_embedding': self.position_embedding,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
class sbb_predict:
|
||||
def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ):
|
||||
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
|
||||
self.image=image
|
||||
self.patches=patches
|
||||
self.save=save
|
||||
self.model_dir=model
|
||||
self.ground_truth=ground_truth
|
||||
self.weights_dir=weights_dir
|
||||
self.task=task
|
||||
self.config_params_model=config_params_model
|
||||
|
||||
|
@ -426,16 +360,12 @@ class sbb_predict:
|
|||
pass
|
||||
else:
|
||||
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
|
||||
cv2.imwrite('./test.png',img_seg_overlayed)
|
||||
##if self.save!=None:
|
||||
##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2)
|
||||
##cv2.imwrite(self.save,img)
|
||||
|
||||
###if self.ground_truth!=None:
|
||||
###gt_img=cv2.imread(self.ground_truth)
|
||||
###self.IoU(gt_img[:,:,0],res)
|
||||
##plt.imshow(res)
|
||||
##plt.show()
|
||||
if self.save:
|
||||
cv2.imwrite(self.save,img_seg_overlayed)
|
||||
|
||||
if self.ground_truth:
|
||||
gt_img=cv2.imread(self.ground_truth)
|
||||
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
|
@ -463,23 +393,19 @@ class sbb_predict:
|
|||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--ground_truth/--no-ground_truth",
|
||||
"-gt/-nogt",
|
||||
is_flag=True,
|
||||
"--ground_truth",
|
||||
"-gt",
|
||||
help="ground truth directory if you want to see the iou of prediction.",
|
||||
)
|
||||
@click.option(
|
||||
"--model_weights/--no-model_weights",
|
||||
"-mw/-nomw",
|
||||
is_flag=True,
|
||||
help="previous model weights which are saved.",
|
||||
)
|
||||
def main(image, model, patches, save, ground_truth, model_weights):
|
||||
|
||||
def main(image, model, patches, save, ground_truth):
|
||||
with open(os.path.join(model,'config.json')) as f:
|
||||
config_params_model = json.load(f)
|
||||
task = 'classification'
|
||||
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, model_weights)
|
||||
task = config_params_model['task']
|
||||
if task != 'classification':
|
||||
if not save:
|
||||
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
|
||||
sys.exit(1)
|
||||
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
|
||||
x.run()
|
||||
|
||||
if __name__=="__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue