mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
modifications
This commit is contained in:
parent
8d1050ec30
commit
ce1108aca0
1 changed files with 17 additions and 91 deletions
108
inference.py
108
inference.py
|
@ -1,25 +1,16 @@
|
||||||
#! /usr/bin/env python3
|
|
||||||
|
|
||||||
__version__= '1.0'
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import warnings
|
import warnings
|
||||||
import xml.etree.ElementTree as et
|
|
||||||
import pandas as pd
|
|
||||||
from tqdm import tqdm
|
|
||||||
import csv
|
|
||||||
import cv2
|
import cv2
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras import backend as K
|
from tensorflow.keras import backend as K
|
||||||
from tensorflow.keras import layers
|
from tensorflow.keras import layers
|
||||||
import tensorflow.keras.losses
|
import tensorflow.keras.losses
|
||||||
from tensorflow.keras.layers import *
|
from tensorflow.keras.layers import *
|
||||||
|
from models import *
|
||||||
import click
|
import click
|
||||||
import json
|
import json
|
||||||
from tensorflow.python.keras import backend as tensorflow_backend
|
from tensorflow.python.keras import backend as tensorflow_backend
|
||||||
|
@ -37,70 +28,13 @@ __doc__=\
|
||||||
Tool to load model and predict for given image.
|
Tool to load model and predict for given image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
projection_dim = 64
|
|
||||||
patch_size = 1
|
|
||||||
num_patches =28*28
|
|
||||||
class Patches(layers.Layer):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(Patches, self).__init__()
|
|
||||||
self.patch_size = patch_size
|
|
||||||
|
|
||||||
def call(self, images):
|
|
||||||
print(tf.shape(images)[1],'images')
|
|
||||||
print(self.patch_size,'self.patch_size')
|
|
||||||
batch_size = tf.shape(images)[0]
|
|
||||||
patches = tf.image.extract_patches(
|
|
||||||
images=images,
|
|
||||||
sizes=[1, self.patch_size, self.patch_size, 1],
|
|
||||||
strides=[1, self.patch_size, self.patch_size, 1],
|
|
||||||
rates=[1, 1, 1, 1],
|
|
||||||
padding="VALID",
|
|
||||||
)
|
|
||||||
patch_dims = patches.shape[-1]
|
|
||||||
print(patches.shape,patch_dims,'patch_dims')
|
|
||||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
|
||||||
return patches
|
|
||||||
def get_config(self):
|
|
||||||
|
|
||||||
config = super().get_config().copy()
|
|
||||||
config.update({
|
|
||||||
'patch_size': self.patch_size,
|
|
||||||
})
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEncoder(layers.Layer):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(PatchEncoder, self).__init__()
|
|
||||||
self.num_patches = num_patches
|
|
||||||
self.projection = layers.Dense(units=projection_dim)
|
|
||||||
self.position_embedding = layers.Embedding(
|
|
||||||
input_dim=num_patches, output_dim=projection_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
def call(self, patch):
|
|
||||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
|
||||||
encoded = self.projection(patch) + self.position_embedding(positions)
|
|
||||||
return encoded
|
|
||||||
def get_config(self):
|
|
||||||
|
|
||||||
config = super().get_config().copy()
|
|
||||||
config.update({
|
|
||||||
'num_patches': self.num_patches,
|
|
||||||
'projection': self.projection,
|
|
||||||
'position_embedding': self.position_embedding,
|
|
||||||
})
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
class sbb_predict:
|
class sbb_predict:
|
||||||
def __init__(self,image, model, task, config_params_model, patches='false',save='false', ground_truth=None,weights_dir=None ):
|
def __init__(self,image, model, task, config_params_model, patches, save, ground_truth):
|
||||||
self.image=image
|
self.image=image
|
||||||
self.patches=patches
|
self.patches=patches
|
||||||
self.save=save
|
self.save=save
|
||||||
self.model_dir=model
|
self.model_dir=model
|
||||||
self.ground_truth=ground_truth
|
self.ground_truth=ground_truth
|
||||||
self.weights_dir=weights_dir
|
|
||||||
self.task=task
|
self.task=task
|
||||||
self.config_params_model=config_params_model
|
self.config_params_model=config_params_model
|
||||||
|
|
||||||
|
@ -426,16 +360,12 @@ class sbb_predict:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
|
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
cv2.imwrite('./test.png',img_seg_overlayed)
|
if self.save:
|
||||||
##if self.save!=None:
|
cv2.imwrite(self.save,img_seg_overlayed)
|
||||||
##img=np.repeat(res[:, :, np.newaxis]*255, 3, axis=2)
|
|
||||||
##cv2.imwrite(self.save,img)
|
if self.ground_truth:
|
||||||
|
gt_img=cv2.imread(self.ground_truth)
|
||||||
###if self.ground_truth!=None:
|
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||||
###gt_img=cv2.imread(self.ground_truth)
|
|
||||||
###self.IoU(gt_img[:,:,0],res)
|
|
||||||
##plt.imshow(res)
|
|
||||||
##plt.show()
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
|
@ -463,23 +393,19 @@ class sbb_predict:
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--ground_truth/--no-ground_truth",
|
"--ground_truth",
|
||||||
"-gt/-nogt",
|
"-gt",
|
||||||
is_flag=True,
|
|
||||||
help="ground truth directory if you want to see the iou of prediction.",
|
help="ground truth directory if you want to see the iou of prediction.",
|
||||||
)
|
)
|
||||||
@click.option(
|
def main(image, model, patches, save, ground_truth):
|
||||||
"--model_weights/--no-model_weights",
|
|
||||||
"-mw/-nomw",
|
|
||||||
is_flag=True,
|
|
||||||
help="previous model weights which are saved.",
|
|
||||||
)
|
|
||||||
def main(image, model, patches, save, ground_truth, model_weights):
|
|
||||||
|
|
||||||
with open(os.path.join(model,'config.json')) as f:
|
with open(os.path.join(model,'config.json')) as f:
|
||||||
config_params_model = json.load(f)
|
config_params_model = json.load(f)
|
||||||
task = 'classification'
|
task = config_params_model['task']
|
||||||
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, model_weights)
|
if task != 'classification':
|
||||||
|
if not save:
|
||||||
|
print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
|
||||||
|
sys.exit(1)
|
||||||
|
x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth)
|
||||||
x.run()
|
x.run()
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue