allow passing image directly, return image on binarize

pull/5/head
Konstantin Baierer 4 years ago
parent 389ef088d0
commit ca03844c2b

@ -16,7 +16,12 @@ def main():
options = parser.parse_args() options = parser.parse_args()
binarizer = SbbBinarizer(options.image, options.model, options.patches, options.save) binarizer = SbbBinarizer(
image_path=options.image,
model=options.model,
patches=options.patches,
save=options.save
)
binarizer.run() binarizer.run()
if __name__ == "__main__": if __name__ == "__main__":

@ -15,18 +15,23 @@ import tensorflow as tf
with catch_warnings(): with catch_warnings():
simplefilter("ignore") simplefilter("ignore")
def resize_image(img_in, input_height, input_width):
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
class SbbBinarizer: class SbbBinarizer:
# TODO use True/False for patches # TODO use True/False for patches
def __init__(self, image, model, patches='false', save=None): def __init__(self, model, image=None, image_path=None, patches='false', save=None):
if not(image or image_path) or (image and image_path):
raise ValueError("Must pass either a PIL image or an image_path")
if image:
self.image = image self.image = image
else:
self.image = cv2.imread(self.image)
self.patches = patches self.patches = patches
self.save = save self.save = save
self.model_dir = model self.model_dir = model
def resize_image(self,img_in,input_height,input_width):
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
def start_new_session_and_model(self): def start_new_session_and_model(self):
config = tf.ConfigProto() config = tf.ConfigProto()
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
@ -48,11 +53,11 @@ class SbbBinarizer:
def predict(self,model_name): def predict(self,model_name):
self.load_model(model_name) self.load_model(model_name)
img = cv2.imread(self.image) img = self.image
img_width_model = self.img_width img_width_model = self.img_width
img_height_model = self.img_height img_height_model = self.img_height
if self.patches=='true' or self.patches=='True': if self.patches in ('true', 'True'):
margin = int(0.1 * img_width_model) margin = int(0.1 * img_width_model)
@ -181,14 +186,14 @@ class SbbBinarizer:
img_h_page = img.shape[0] img_h_page = img.shape[0]
img_w_page = img.shape[1] img_w_page = img.shape[1]
img = img / float(255.0) img = img / float(255.0)
img = self.resize_image(img, img_height_model, img_width_model) img = resize_image(img, img_height_model, img_width_model)
label_p_pred = self.model.predict( label_p_pred = self.model.predict(
img.reshape(1, img.shape[0], img.shape[1], img.shape[2])) img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
seg = np.argmax(label_p_pred, axis=3)[0] seg = np.argmax(label_p_pred, axis=3)[0]
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = self.resize_image(seg_color, img_h_page, img_w_page) prediction_true = resize_image(seg_color, img_h_page, img_w_page)
prediction_true = prediction_true.astype(np.uint8) prediction_true = prediction_true.astype(np.uint8)
return prediction_true[:,:,0] return prediction_true[:,:,0]
@ -217,3 +222,4 @@ class SbbBinarizer:
img_last = (img_last[:, :] == 0)*255 img_last = (img_last[:, :] == 0)*255
if self.save: if self.save:
cv2.imwrite(self.save, img_last) cv2.imwrite(self.save, img_last)
return img_last

Loading…
Cancel
Save