allow passing image directly, return image on binarize

pull/5/head
Konstantin Baierer 4 years ago
parent 389ef088d0
commit ca03844c2b

@ -16,7 +16,12 @@ def main():
options = parser.parse_args()
binarizer = SbbBinarizer(options.image, options.model, options.patches, options.save)
binarizer = SbbBinarizer(
image_path=options.image,
model=options.model,
patches=options.patches,
save=options.save
)
binarizer.run()
if __name__ == "__main__":

@ -15,25 +15,30 @@ import tensorflow as tf
with catch_warnings():
simplefilter("ignore")
def resize_image(img_in, input_height, input_width):
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
class SbbBinarizer:
# TODO use True/False for patches
def __init__(self, image, model, patches='false', save=None):
self.image = image
def __init__(self, model, image=None, image_path=None, patches='false', save=None):
if not(image or image_path) or (image and image_path):
raise ValueError("Must pass either a PIL image or an image_path")
if image:
self.image = image
else:
self.image = cv2.imread(self.image)
self.patches = patches
self.save = save
self.model_dir = model
def resize_image(self,img_in,input_height,input_width):
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
def start_new_session_and_model(self):
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
self.session = tf.Session(config=config) # tf.InteractiveSession()
def load_model(self,model_name):
def load_model(self, model_name):
self.model = load_model(join(self.model_dir, model_name), compile=False)
@ -48,11 +53,11 @@ class SbbBinarizer:
def predict(self,model_name):
self.load_model(model_name)
img = cv2.imread(self.image)
img = self.image
img_width_model = self.img_width
img_height_model = self.img_height
if self.patches=='true' or self.patches=='True':
if self.patches in ('true', 'True'):
margin = int(0.1 * img_width_model)
@ -181,14 +186,14 @@ class SbbBinarizer:
img_h_page = img.shape[0]
img_w_page = img.shape[1]
img = img / float(255.0)
img = self.resize_image(img, img_height_model, img_width_model)
img = resize_image(img, img_height_model, img_width_model)
label_p_pred = self.model.predict(
img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
seg = np.argmax(label_p_pred, axis=3)[0]
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = self.resize_image(seg_color, img_h_page, img_w_page)
prediction_true = resize_image(seg_color, img_h_page, img_w_page)
prediction_true = prediction_true.astype(np.uint8)
return prediction_true[:,:,0]
@ -217,3 +222,4 @@ class SbbBinarizer:
img_last = (img_last[:, :] == 0)*255
if self.save:
cv2.imwrite(self.save, img_last)
return img_last

Loading…
Cancel
Save