diff --git a/sbb_binarize/cli.py b/sbb_binarize/cli.py index 0176e20..ea1b3c4 100644 --- a/sbb_binarize/cli.py +++ b/sbb_binarize/cli.py @@ -4,6 +4,7 @@ sbb_binarize CLI from click import command, option, argument, version_option, types from .sbb_binarize import SbbBinarizer +import click @command() @version_option() @@ -11,5 +12,24 @@ from .sbb_binarize import SbbBinarizer @option('--model-dir', '-m', type=click.Path(exists=True, file_okay=False), required=True, help='directory containing models for prediction') @argument('input_image') @argument('output_image') -def main(patches, model_dir, input_image, output_image): - SbbBinarizer(model_dir).run(image_path=input_image, use_patches=patches, save=output_image) +@click.option( + "--dir_in", + "-di", + help="directory of images", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--dir_out", + "-do", + help="directory where the binarized images will be written", + type=click.Path(exists=True, file_okay=False), +) + +def main(patches, model_dir, input_image, output_image, dir_in, dir_out): + if not dir_out and (dir_in): + print("Error: You used -di but did not set -do") + sys.exit(1) + elif dir_out and not (dir_in): + print("Error: You used -do to write out binarized images but have not set -di") + sys.exit(1) + SbbBinarizer(model_dir).run(image_path=input_image, use_patches=patches, save=output_image, dir_in=dir_in, dir_out=dir_out) diff --git a/sbb_binarize/sbb_binarize.py b/sbb_binarize/sbb_binarize.py index 5424098..e56f3b1 100644 --- a/sbb_binarize/sbb_binarize.py +++ b/sbb_binarize/sbb_binarize.py @@ -7,6 +7,7 @@ from glob import glob from os import environ, devnull from os.path import join from warnings import catch_warnings, simplefilter +import os import numpy as np from PIL import Image @@ -242,33 +243,63 @@ class SbbBinarizer: prediction_true = prediction_true.astype(np.uint8) return prediction_true[:,:,0] - def run(self, image=None, image_path=None, save=None, use_patches=False): - if (image is not None and image_path is not None) or \ - (image is None and image_path is None): - raise ValueError("Must pass either a opencv2 image or an image_path") - if image_path is not None: - image = cv2.imread(image_path) - img_last = 0 - for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): - self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) - - res = self.predict(model, image, use_patches) - - img_fin = np.zeros((res.shape[0], res.shape[1], 3)) - res[:, :][res[:, :] == 0] = 2 - res = res - 1 - res = res * 255 - img_fin[:, :, 0] = res - img_fin[:, :, 1] = res - img_fin[:, :, 2] = res - - img_fin = img_fin.astype(np.uint8) - img_fin = (res[:, :] == 0) * 255 - img_last = img_last + img_fin - - kernel = np.ones((5, 5), np.uint8) - img_last[:, :][img_last[:, :] > 0] = 255 - img_last = (img_last[:, :] == 0) * 255 - if save: - cv2.imwrite(save, img_last) - return img_last + def run(self, image=None, image_path=None, save=None, use_patches=False, dir_in=None, dir_out=None): + print(dir_in,'dir_in') + if not dir_in: + if (image is not None and image_path is not None) or \ + (image is None and image_path is None): + raise ValueError("Must pass either a opencv2 image or an image_path") + if image_path is not None: + image = cv2.imread(image_path) + img_last = 0 + for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): + self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) + + res = self.predict(model, image, use_patches) + + img_fin = np.zeros((res.shape[0], res.shape[1], 3)) + res[:, :][res[:, :] == 0] = 2 + res = res - 1 + res = res * 255 + img_fin[:, :, 0] = res + img_fin[:, :, 1] = res + img_fin[:, :, 2] = res + + img_fin = img_fin.astype(np.uint8) + img_fin = (res[:, :] == 0) * 255 + img_last = img_last + img_fin + + kernel = np.ones((5, 5), np.uint8) + img_last[:, :][img_last[:, :] > 0] = 255 + img_last = (img_last[:, :] == 0) * 255 + if save: + cv2.imwrite(save, img_last) + return img_last + else: + ls_imgs = os.listdir(dir_in) + for image_name in ls_imgs: + print(image_name,'image_name') + image = cv2.imread(os.path.join(dir_in,image_name) ) + img_last = 0 + for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): + self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) + + res = self.predict(model, image, use_patches) + + img_fin = np.zeros((res.shape[0], res.shape[1], 3)) + res[:, :][res[:, :] == 0] = 2 + res = res - 1 + res = res * 255 + img_fin[:, :, 0] = res + img_fin[:, :, 1] = res + img_fin[:, :, 2] = res + + img_fin = img_fin.astype(np.uint8) + img_fin = (res[:, :] == 0) * 255 + img_last = img_last + img_fin + + kernel = np.ones((5, 5), np.uint8) + img_last[:, :][img_last[:, :] > 0] = 255 + img_last = (img_last[:, :] == 0) * 255 + + cv2.imwrite(os.path.join(dir_out,image_name), img_last)