Refactor SbbBinarizer: the constructor now takes only model_dir; the per-image
arguments (image / image_path, patches, save) move to run(), so one instance
can be reused across pages, regions and lines in the OCR-D processor.

Review fixes applied to the added lines (hunk line counts are unchanged):
  * ocrd_cli.py: SbbBinarizer(model_dir=self.model_path) -> model_dir=model_path;
    the patch turns model_path into a local, so self.model_path is no longer
    assigned in process() and the old spelling raises AttributeError.
  * ocrd_cli.py: dropped the attribute-defined-outside-init pragmas from the two
    assignments that are now plain locals.
  * sbb_binarize.py: cv2.imread(image) -> cv2.imread(image_path); image is None
    on the branch where image_path is the argument that was supplied.

NOTE(review): run() calls start_new_session() on every invocation and nothing in
this patch calls end_session(); confirm sessions are released elsewhere.
NOTE(review): the index lines still carry the blob hashes of the unfixed patch.

diff --git a/sbb_binarize/cli.py b/sbb_binarize/cli.py index 20881b5..1b3bc7e 100644 --- a/sbb_binarize/cli.py +++ b/sbb_binarize/cli.py @@ -16,13 +16,8 @@ def main(): options = parser.parse_args() - binarizer = SbbBinarizer( - image_path=options.image, - model=options.model, - patches=options.patches, - save=options.save - ) - binarizer.run() + binarizer = SbbBinarizer(model_dir=options.model) + binarizer.run(image_path=options.image, patches=options.patches, save=options.save) if __name__ == "__main__": main() diff --git a/sbb_binarize/ocrd_cli.py b/sbb_binarize/ocrd_cli.py index d846212..854586b 100644 --- a/sbb_binarize/ocrd_cli.py +++ b/sbb_binarize/ocrd_cli.py @@ -39,14 +39,6 @@ class SbbBinarizeProcessor(Processor): kwargs['version'] = OCRD_TOOL['version'] super().__init__(*args, **kwargs) - def _run_binarizer(self, img): - return cv2pil( - SbbBinarizer( - image=pil2cv(img), - model=self.model_path, - patches=self.use_patches, - save=None).run()) - def process(self): """ Binarize with sbb_binarization @@ -56,8 +48,9 @@ class SbbBinarizeProcessor(Processor): assert_file_grp_cardinality(self.output_file_grp, 1) oplevel = self.parameter['operation_level'] - self.use_patches = self.parameter['patches'] # pylint: disable=attribute-defined-outside-init - self.model_path = self.parameter['model'] # pylint: disable=attribute-defined-outside-init + use_patches = self.parameter['patches'] + model_path = self.parameter['model'] + binarizer = SbbBinarizer(model_dir=model_path) for n, input_file in enumerate(self.input_files): file_id = make_file_id(input_file, self.output_file_grp) @@ -71,7 +64,7 @@ class SbbBinarizeProcessor(Processor): if oplevel == 'page': LOG.info("Binarizing on 'page' level in page '%s'", page_id) page_image, page_xywh, _ = self.workspace.image_from_page(page, page_id, feature_filter='binarized') - bin_image = self._run_binarizer(page_image) + 
bin_image = cv2pil(binarizer.run(image=pil2cv(page_image), patches=use_patches)) # update METS (add the image file): bin_image_path = self.workspace.save_image_file(bin_image, file_id + '.IMG-BIN', @@ -85,7 +78,7 @@ class SbbBinarizeProcessor(Processor): LOG.warning("Page '%s' contains no text/table regions", page_id) for region in regions: region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh, feature_filter='binarized') - region_image_bin = self._run_binarizer(region_image) + region_image_bin = cv2pil(binarizer.run(image=pil2cv(region_image), patches=use_patches)) region_image_bin_path = self.workspace.save_image_file( region_image_bin, "%s_%s.IMG-BIN" % (file_id, region.id), @@ -100,7 +93,7 @@ class SbbBinarizeProcessor(Processor): LOG.warning("Page '%s' contains no text lines", page_id) for region_id, line in region_line_tuples: line_image, line_xywh = self.workspace.image_from_segment(line, page_image, page_xywh, feature_filter='binarized') - line_image_bin = self._run_binarizer(line_image) + line_image_bin = cv2pil(binarizer.run(image=pil2cv(line_image), patches=use_patches)) line_image_bin_path = self.workspace.save_image_file( line_image_bin, "%s_%s_%s.IMG-BIN" % (file_id, region_id, line.id), diff --git a/sbb_binarize/sbb_binarize.py b/sbb_binarize/sbb_binarize.py index 9769456..51bccae 100644 --- a/sbb_binarize/sbb_binarize.py +++ b/sbb_binarize/sbb_binarize.py @@ -22,50 +22,35 @@ def resize_image(img_in, input_height, input_width): class SbbBinarizer: - # TODO use True/False for patches - def __init__(self, model, image=None, image_path=None, patches='false', save=None): - if (image is not None and image_path is not None) or \ - (image is None and image_path is None): - raise ValueError("Must pass either a opencv2 image or an image_path") - if image is not None: - self.image = image - else: - self.image = cv2.imread(self.image) - self.patches = patches - self.save = save - self.model_dir = model + def __init__(self, 
model_dir): + self.model_dir = model_dir - def start_new_session_and_model(self): + def start_new_session(self): config = tf.ConfigProto() config.gpu_options.allow_growth = True self.session = tf.Session(config=config) # tf.InteractiveSession() - def load_model(self, model_name): - - self.model = load_model(join(self.model_dir, model_name), compile=False) - - self.img_height = self.model.layers[len(self.model.layers)-1].output_shape[1] - self.img_width = self.model.layers[len(self.model.layers)-1].output_shape[2] - self.n_classes = self.model.layers[len(self.model.layers)-1].output_shape[3] - def end_session(self): self.session.close() - del self.model del self.session - def predict(self,model_name): - self.load_model(model_name) - img = self.image - img_width_model = self.img_width - img_height_model = self.img_height + def load_model(self, model_name): + model = load_model(join(self.model_dir, model_name), compile=False) + model_height = model.layers[len(model.layers)-1].output_shape[1] + model_width = model.layers[len(model.layers)-1].output_shape[2] + n_classes = model.layers[len(model.layers)-1].output_shape[3] + return model, model_height, model_width, n_classes + + def predict(self, model_name, img, patches): + model, model_height, model_width, n_classes = self.load_model(model_name) - if self.patches in ('true', 'True'): + if patches in ('true', 'True'): - margin = int(0.1 * img_width_model) + margin = int(0.1 * model_width) - width_mid = img_width_model - 2 * margin - height_mid = img_height_model - 2 * margin + width_mid = model_width - 2 * margin + height_mid = model_height - 2 * margin img = img / float(255.0) @@ -93,28 +78,28 @@ class SbbBinarizer: if i == 0: index_x_d = i * width_mid - index_x_u = index_x_d + img_width_model + index_x_u = index_x_d + model_width elif i > 0: index_x_d = i * width_mid - index_x_u = index_x_d + img_width_model + index_x_u = index_x_d + model_width if j == 0: index_y_d = j * height_mid - index_y_u = index_y_d + 
img_height_model + index_y_u = index_y_d + model_height elif j > 0: index_y_d = j * height_mid - index_y_u = index_y_d + img_height_model + index_y_u = index_y_d + model_height if index_x_u > img_w: index_x_u = img_w - index_x_d = img_w - img_width_model + index_x_d = img_w - model_width if index_y_u > img_h: index_y_u = img_h - index_y_d = img_h - img_height_model + index_y_d = img_h - model_height img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] - label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2])) + label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2])) seg = np.argmax(label_p_pred, axis=3)[0] @@ -189,10 +174,9 @@ class SbbBinarizer: img_h_page = img.shape[0] img_w_page = img.shape[1] img = img / float(255.0) - img = resize_image(img, img_height_model, img_width_model) + img = resize_image(img, model_height, model_width) - label_p_pred = self.model.predict( - img.reshape(1, img.shape[0], img.shape[1], img.shape[2])) + label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2])) seg = np.argmax(label_p_pred, axis=3)[0] seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) @@ -200,29 +184,35 @@ class SbbBinarizer: prediction_true = prediction_true.astype(np.uint8) return prediction_true[:,:,0] - def run(self): - self.start_new_session_and_model() - models_n = listdir(self.model_dir) + # TODO use True/False for patches + def run(self, image=None, image_path=None, save=None, patches='false'): + if (image is not None and image_path is not None) or \ + (image is None and image_path is None): + raise ValueError("Must pass either a opencv2 image or an image_path") + if image_path is not None: + image = cv2.imread(image_path) + self.start_new_session() + list_of_model_files = listdir(self.model_dir) img_last = 0 - for model_in in models_n: + for model_in in list_of_model_files: - res = self.predict(model_in) + res = 
self.predict(model_in, image, patches) img_fin = np.zeros((res.shape[0], res.shape[1], 3)) res[:, :][res[:, :] == 0] = 2 - res = res-1 - res = res*255 + res = res - 1 + res = res * 255 img_fin[:, :, 0] = res img_fin[:, :, 1] = res img_fin[:, :, 2] = res img_fin = img_fin.astype(np.uint8) - img_fin = (res[:, :] == 0)*255 - img_last = img_last+img_fin + img_fin = (res[:, :] == 0) * 255 + img_last = img_last + img_fin kernel = np.ones((5, 5), np.uint8) img_last[:, :][img_last[:, :] > 0] = 255 - img_last = (img_last[:, :] == 0)*255 - if self.save: - cv2.imwrite(self.save, img_last) + img_last = (img_last[:, :] == 0) * 255 + if save: + cv2.imwrite(save, img_last) return img_last