|
|
|
@ -7,6 +7,7 @@ from glob import glob
|
|
|
|
|
from os import environ, devnull
|
|
|
|
|
from os.path import join
|
|
|
|
|
from warnings import catch_warnings, simplefilter
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from PIL import Image
|
|
|
|
@ -242,7 +243,9 @@ class SbbBinarizer:
|
|
|
|
|
prediction_true = prediction_true.astype(np.uint8)
|
|
|
|
|
return prediction_true[:,:,0]
|
|
|
|
|
|
|
|
|
|
def run(self, image=None, image_path=None, save=None, use_patches=False):
|
|
|
|
|
def run(self, image=None, image_path=None, save=None, use_patches=False, dir_in=None, dir_out=None):
|
|
|
|
|
print(dir_in,'dir_in')
|
|
|
|
|
if not dir_in:
|
|
|
|
|
if (image is not None and image_path is not None) or \
|
|
|
|
|
(image is None and image_path is None):
|
|
|
|
|
raise ValueError("Must pass either a opencv2 image or an image_path")
|
|
|
|
@ -272,3 +275,31 @@ class SbbBinarizer:
|
|
|
|
|
if save:
|
|
|
|
|
cv2.imwrite(save, img_last)
|
|
|
|
|
return img_last
|
|
|
|
|
else:
|
|
|
|
|
ls_imgs = os.listdir(dir_in)
|
|
|
|
|
for image_name in ls_imgs:
|
|
|
|
|
print(image_name,'image_name')
|
|
|
|
|
image = cv2.imread(os.path.join(dir_in,image_name) )
|
|
|
|
|
img_last = 0
|
|
|
|
|
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
|
|
|
|
|
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
|
|
|
|
|
|
|
|
|
|
res = self.predict(model, image, use_patches)
|
|
|
|
|
|
|
|
|
|
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
|
|
|
|
|
res[:, :][res[:, :] == 0] = 2
|
|
|
|
|
res = res - 1
|
|
|
|
|
res = res * 255
|
|
|
|
|
img_fin[:, :, 0] = res
|
|
|
|
|
img_fin[:, :, 1] = res
|
|
|
|
|
img_fin[:, :, 2] = res
|
|
|
|
|
|
|
|
|
|
img_fin = img_fin.astype(np.uint8)
|
|
|
|
|
img_fin = (res[:, :] == 0) * 255
|
|
|
|
|
img_last = img_last + img_fin
|
|
|
|
|
|
|
|
|
|
kernel = np.ones((5, 5), np.uint8)
|
|
|
|
|
img_last[:, :][img_last[:, :] > 0] = 255
|
|
|
|
|
img_last = (img_last[:, :] == 0) * 255
|
|
|
|
|
|
|
|
|
|
cv2.imwrite(os.path.join(dir_out,image_name), img_last)
|
|
|
|
|