|
|
@ -15,25 +15,30 @@ import tensorflow as tf
|
|
|
|
with catch_warnings():
|
|
|
|
with catch_warnings():
|
|
|
|
simplefilter("ignore")
|
|
|
|
simplefilter("ignore")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resize_image(img_in, input_height, input_width):
|
|
|
|
|
|
|
|
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
|
|
|
|
|
|
|
|
|
|
|
|
class SbbBinarizer:
|
|
|
|
class SbbBinarizer:
|
|
|
|
|
|
|
|
|
|
|
|
# TODO use True/False for patches
|
|
|
|
# TODO use True/False for patches
|
|
|
|
def __init__(self, image, model, patches='false', save=None):
|
|
|
|
def __init__(self, model, image=None, image_path=None, patches='false', save=None):
|
|
|
|
self.image = image
|
|
|
|
if not(image or image_path) or (image and image_path):
|
|
|
|
|
|
|
|
raise ValueError("Must pass either a PIL image or an image_path")
|
|
|
|
|
|
|
|
if image:
|
|
|
|
|
|
|
|
self.image = image
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.image = cv2.imread(self.image)
|
|
|
|
self.patches = patches
|
|
|
|
self.patches = patches
|
|
|
|
self.save = save
|
|
|
|
self.save = save
|
|
|
|
self.model_dir = model
|
|
|
|
self.model_dir = model
|
|
|
|
|
|
|
|
|
|
|
|
def resize_image(self,img_in,input_height,input_width):
|
|
|
|
|
|
|
|
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def start_new_session_and_model(self):
|
|
|
|
def start_new_session_and_model(self):
|
|
|
|
config = tf.ConfigProto()
|
|
|
|
config = tf.ConfigProto()
|
|
|
|
config.gpu_options.allow_growth = True
|
|
|
|
config.gpu_options.allow_growth = True
|
|
|
|
|
|
|
|
|
|
|
|
self.session = tf.Session(config=config) # tf.InteractiveSession()
|
|
|
|
self.session = tf.Session(config=config) # tf.InteractiveSession()
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(self,model_name):
|
|
|
|
def load_model(self, model_name):
|
|
|
|
|
|
|
|
|
|
|
|
self.model = load_model(join(self.model_dir, model_name), compile=False)
|
|
|
|
self.model = load_model(join(self.model_dir, model_name), compile=False)
|
|
|
|
|
|
|
|
|
|
|
@ -48,11 +53,11 @@ class SbbBinarizer:
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self,model_name):
|
|
|
|
def predict(self,model_name):
|
|
|
|
self.load_model(model_name)
|
|
|
|
self.load_model(model_name)
|
|
|
|
img = cv2.imread(self.image)
|
|
|
|
img = self.image
|
|
|
|
img_width_model = self.img_width
|
|
|
|
img_width_model = self.img_width
|
|
|
|
img_height_model = self.img_height
|
|
|
|
img_height_model = self.img_height
|
|
|
|
|
|
|
|
|
|
|
|
if self.patches=='true' or self.patches=='True':
|
|
|
|
if self.patches in ('true', 'True'):
|
|
|
|
|
|
|
|
|
|
|
|
margin = int(0.1 * img_width_model)
|
|
|
|
margin = int(0.1 * img_width_model)
|
|
|
|
|
|
|
|
|
|
|
@ -181,14 +186,14 @@ class SbbBinarizer:
|
|
|
|
img_h_page = img.shape[0]
|
|
|
|
img_h_page = img.shape[0]
|
|
|
|
img_w_page = img.shape[1]
|
|
|
|
img_w_page = img.shape[1]
|
|
|
|
img = img / float(255.0)
|
|
|
|
img = img / float(255.0)
|
|
|
|
img = self.resize_image(img, img_height_model, img_width_model)
|
|
|
|
img = resize_image(img, img_height_model, img_width_model)
|
|
|
|
|
|
|
|
|
|
|
|
label_p_pred = self.model.predict(
|
|
|
|
label_p_pred = self.model.predict(
|
|
|
|
img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
|
|
|
|
img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
|
|
|
|
|
|
|
|
|
|
|
|
seg = np.argmax(label_p_pred, axis=3)[0]
|
|
|
|
seg = np.argmax(label_p_pred, axis=3)[0]
|
|
|
|
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
|
|
|
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
|
|
|
prediction_true = self.resize_image(seg_color, img_h_page, img_w_page)
|
|
|
|
prediction_true = resize_image(seg_color, img_h_page, img_w_page)
|
|
|
|
prediction_true = prediction_true.astype(np.uint8)
|
|
|
|
prediction_true = prediction_true.astype(np.uint8)
|
|
|
|
return prediction_true[:,:,0]
|
|
|
|
return prediction_true[:,:,0]
|
|
|
|
|
|
|
|
|
|
|
@ -217,3 +222,4 @@ class SbbBinarizer:
|
|
|
|
img_last = (img_last[:, :] == 0)*255
|
|
|
|
img_last = (img_last[:, :] == 0)*255
|
|
|
|
if self.save:
|
|
|
|
if self.save:
|
|
|
|
cv2.imwrite(self.save, img_last)
|
|
|
|
cv2.imwrite(self.save, img_last)
|
|
|
|
|
|
|
|
return img_last
|
|
|
|