replace patches string comparison with use_patches boolean

pull/5/head
Konstantin Baierer 4 years ago
parent fad7b7aff1
commit 2bc6ccc4c0

@ -50,7 +50,7 @@ class SbbBinarizeProcessor(Processor):
oplevel = self.parameter['operation_level'] oplevel = self.parameter['operation_level']
use_patches = self.parameter['patches'] # pylint: disable=attribute-defined-outside-init use_patches = self.parameter['patches'] # pylint: disable=attribute-defined-outside-init
model_path = self.parameter['model'] # pylint: disable=attribute-defined-outside-init model_path = self.parameter['model'] # pylint: disable=attribute-defined-outside-init
binarizer = SbbBinarizer(model_dir=self.model_path) binarizer = SbbBinarizer(model_dir=model_path)
for n, input_file in enumerate(self.input_files): for n, input_file in enumerate(self.input_files):
file_id = make_file_id(input_file, self.output_file_grp) file_id = make_file_id(input_file, self.output_file_grp)
@ -64,7 +64,7 @@ class SbbBinarizeProcessor(Processor):
if oplevel == 'page': if oplevel == 'page':
LOG.info("Binarizing on 'page' level in page '%s'", page_id) LOG.info("Binarizing on 'page' level in page '%s'", page_id)
page_image, page_xywh, _ = self.workspace.image_from_page(page, page_id, feature_filter='binarized') page_image, page_xywh, _ = self.workspace.image_from_page(page, page_id, feature_filter='binarized')
bin_image = cv2pil(binarizer.run(image=pil2cv(page_image), patches=use_patches)) bin_image = cv2pil(binarizer.run(image=pil2cv(page_image), use_patches=use_patches))
# update METS (add the image file): # update METS (add the image file):
bin_image_path = self.workspace.save_image_file(bin_image, bin_image_path = self.workspace.save_image_file(bin_image,
file_id + '.IMG-BIN', file_id + '.IMG-BIN',
@ -78,7 +78,7 @@ class SbbBinarizeProcessor(Processor):
LOG.warning("Page '%s' contains no text/table regions", page_id) LOG.warning("Page '%s' contains no text/table regions", page_id)
for region in regions: for region in regions:
region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh, feature_filter='binarized') region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh, feature_filter='binarized')
region_image_bin = cv2pil(binarizer.run(image=pil2cv(region_image), patches=use_patches)) region_image_bin = cv2pil(binarizer.run(image=pil2cv(region_image), use_patches=use_patches))
region_image_bin_path = self.workspace.save_image_file( region_image_bin_path = self.workspace.save_image_file(
region_image_bin, region_image_bin,
"%s_%s.IMG-BIN" % (file_id, region.id), "%s_%s.IMG-BIN" % (file_id, region.id),
@ -93,7 +93,7 @@ class SbbBinarizeProcessor(Processor):
LOG.warning("Page '%s' contains no text lines", page_id) LOG.warning("Page '%s' contains no text lines", page_id)
for region_id, line in region_line_tuples: for region_id, line in region_line_tuples:
line_image, line_xywh = self.workspace.image_from_segment(line, page_image, page_xywh, feature_filter='binarized') line_image, line_xywh = self.workspace.image_from_segment(line, page_image, page_xywh, feature_filter='binarized')
line_image_bin = cv2pil(binarizer.run(image=pil2cv(line_image), patches=use_patches)) line_image_bin = cv2pil(binarizer.run(image=pil2cv(line_image), use_patches=use_patches))
line_image_bin_path = self.workspace.save_image_file( line_image_bin_path = self.workspace.save_image_file(
line_image_bin, line_image_bin,
"%s_%s_%s.IMG-BIN" % (file_id, region_id, line.id), "%s_%s_%s.IMG-BIN" % (file_id, region_id, line.id),

@ -42,10 +42,10 @@ class SbbBinarizer:
n_classes = model.layers[len(model.layers)-1].output_shape[3] n_classes = model.layers[len(model.layers)-1].output_shape[3]
return model, model_height, model_width, n_classes return model, model_height, model_width, n_classes
def predict(self, model_name, img, patches): def predict(self, model_name, img, use_patches):
model, model_height, model_width, n_classes = self.load_model(model_name) model, model_height, model_width, n_classes = self.load_model(model_name)
if patches in ('true', 'True'): if use_patches:
margin = int(0.1 * model_width) margin = int(0.1 * model_width)
@ -184,8 +184,7 @@ class SbbBinarizer:
prediction_true = prediction_true.astype(np.uint8) prediction_true = prediction_true.astype(np.uint8)
return prediction_true[:,:,0] return prediction_true[:,:,0]
# TODO use True/False for patches def run(self, image=None, image_path=None, save=None, use_patches=False):
def run(self, image=None, image_path=None, save=None, patches='false'):
if (image is not None and image_path is not None) or \ if (image is not None and image_path is not None) or \
(image is None and image_path is None): (image is None and image_path is None):
raise ValueError("Must pass either a opencv2 image or an image_path") raise ValueError("Must pass either a opencv2 image or an image_path")
@ -196,7 +195,7 @@ class SbbBinarizer:
img_last = 0 img_last = 0
for model_in in list_of_model_files: for model_in in list_of_model_files:
res = self.predict(model_in, image, patches) res = self.predict(model_in, image, use_patches)
img_fin = np.zeros((res.shape[0], res.shape[1], 3)) img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2 res[:, :][res[:, :] == 0] = 2

Loading…
Cancel
Save