diff --git a/sbb_binarize/ocrd_cli.py b/sbb_binarize/ocrd_cli.py index 57438d3..9737bad 100644 --- a/sbb_binarize/ocrd_cli.py +++ b/sbb_binarize/ocrd_cli.py @@ -30,6 +30,10 @@ def cv2pil(img): def pil2cv(img): # from ocrd/workspace.py + if img.mode in ('LA', 'RGBA'): + newimg = Image.new(img.mode[:-1], img.size, 'white') + newimg.paste(img, mask=img.getchannel('A')) + img = newimg color_conversion = cv2.COLOR_GRAY2BGR if img.mode in ('1', 'L') else cv2.COLOR_RGB2BGR pil_as_np_array = np.array(img).astype('uint8') if img.mode == '1' else np.array(img) return cv2.cvtColor(pil_as_np_array, color_conversion) diff --git a/sbb_binarize/sbb_binarize.py b/sbb_binarize/sbb_binarize.py index 8960354..247d54b 100644 --- a/sbb_binarize/sbb_binarize.py +++ b/sbb_binarize/sbb_binarize.py @@ -34,6 +34,8 @@ class SbbBinarizer: self.start_new_session() self.model_files = glob('%s/*.h5' % self.model_dir) + if not self.model_files: + self.model_files = glob('%s/*/' % self.model_dir) if not self.model_files: raise ValueError(f"No models found in {self.model_dir}")