diff --git a/sbb_binarize/cli.py b/sbb_binarize/cli.py
index 3de8820..20881b5 100644
--- a/sbb_binarize/cli.py
+++ b/sbb_binarize/cli.py
@@ -16,7 +16,12 @@ def main():
 
     options = parser.parse_args()
 
-    binarizer = SbbBinarizer(options.image, options.model, options.patches, options.save)
+    binarizer = SbbBinarizer(
+        image_path=options.image,
+        model=options.model,
+        patches=options.patches,
+        save=options.save
+    )
     binarizer.run()
 
 if __name__ == "__main__":
diff --git a/sbb_binarize/sbb_binarize.py b/sbb_binarize/sbb_binarize.py
index 6c8f8fd..70a81cf 100644
--- a/sbb_binarize/sbb_binarize.py
+++ b/sbb_binarize/sbb_binarize.py
@@ -15,25 +15,45 @@ import tensorflow as tf
 with catch_warnings():
     simplefilter("ignore")
 
+def resize_image(img_in, input_height, input_width):
+    """Resize img_in to input_height x input_width (nearest-neighbour)."""
+    return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
+
 class SbbBinarizer:
 
     # TODO use True/False for patches
-    def __init__(self, image, model, patches='false', save=None):
-        self.image = image
+    def __init__(self, model, image=None, image_path=None, patches='false', save=None):
+        """Configure the binarizer for exactly one input image.
+
+        Exactly one of image (an array as returned by cv2.imread) and
+        image_path (a file that is read with cv2.imread) must be given.
+        model is the directory the models are loaded from; save is an
+        optional path the result is written to.
+
+        Raises ValueError if neither or both of image and image_path are
+        given, or if image_path cannot be read.
+        """
+        # Compare with None: truth-testing an image array is ambiguous.
+        if (image is None) == (image_path is None):
+            raise ValueError("Must pass either an OpenCV image or an image_path")
+        if image is not None:
+            self.image = image
+        else:
+            self.image = cv2.imread(image_path)
+            # cv2.imread signals an unreadable file by returning None.
+            if self.image is None:
+                raise ValueError("Could not read image from %s" % image_path)
         self.patches = patches
         self.save = save
         self.model_dir = model
 
-    def resize_image(self,img_in,input_height,input_width):
-        return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
-
     def start_new_session_and_model(self):
         config = tf.ConfigProto()
         config.gpu_options.allow_growth = True
 
         self.session = tf.Session(config=config)  # tf.InteractiveSession()
 
-    def load_model(self,model_name):
+    def load_model(self, model_name):
         self.model = load_model(join(self.model_dir, model_name), compile=False)
 
 
@@ -48,11 +68,11 @@ class SbbBinarizer:
 
     def predict(self,model_name):
         self.load_model(model_name)
-        img = cv2.imread(self.image)
+        img = self.image
         img_width_model = self.img_width
         img_height_model = self.img_height
 
-        if self.patches=='true' or self.patches=='True':
+        if self.patches in ('true', 'True'):
 
             margin = int(0.1 * img_width_model)
 
@@ -181,14 +201,14 @@ class SbbBinarizer:
             img_h_page = img.shape[0]
             img_w_page = img.shape[1]
             img = img / float(255.0)
-            img = self.resize_image(img, img_height_model, img_width_model)
+            img = resize_image(img, img_height_model, img_width_model)
 
             label_p_pred = self.model.predict(
                 img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
 
             seg = np.argmax(label_p_pred, axis=3)[0]
             seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
-            prediction_true = self.resize_image(seg_color, img_h_page, img_w_page)
+            prediction_true = resize_image(seg_color, img_h_page, img_w_page)
             prediction_true = prediction_true.astype(np.uint8)
         return prediction_true[:,:,0]
 
@@ -217,3 +237,4 @@ class SbbBinarizer:
         img_last = (img_last[:, :] == 0)*255
         if self.save:
             cv2.imwrite(self.save, img_last)
+        return img_last