Merge branch 'release-v0.0.2' into release-v0.0.3

pull/15/head
Konstantin Baierer 4 years ago
commit e46b89af85

2
.gitignore vendored

@ -1,2 +1,4 @@
*.egg-info *.egg-info
__pycache__ __pycache__
/build
/dist

@ -0,0 +1,23 @@
Change Log
==========
Versioned according to [Semantic Versioning](http://semver.org/).
## Unreleased
## 0.0.2
Changed:
* The `SBB_BINARIZE_DATA` environment variable can replace the `model` parameter, #6
Fixed:
* AlternativeImage/comments now set on page level, #8, #11
* Only try to load `*.h5` model files, #7, #10
## 0.0.1
Initial release
<!-- link-labels -->
[0.0.2]: ../../compare/v0.0.1...v0.0.2

@ -1,5 +1,5 @@
{ {
"version": "0.0.1", "version": "0.0.2",
"git_url": "https://github.com/qurator-spk/sbb_binarization", "git_url": "https://github.com/qurator-spk/sbb_binarization",
"tools": { "tools": {
"ocrd-sbb-binarize": { "ocrd-sbb-binarize": {
@ -19,7 +19,7 @@
"model": { "model": {
"description": "models directory.", "description": "models directory.",
"type": "string", "type": "string",
"required": true "required": false
} }
} }
} }

@ -37,6 +37,14 @@ class SbbBinarizeProcessor(Processor):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['ocrd_tool'] = OCRD_TOOL['tools'][TOOL] kwargs['ocrd_tool'] = OCRD_TOOL['tools'][TOOL]
kwargs['version'] = OCRD_TOOL['version'] kwargs['version'] = OCRD_TOOL['version']
if not(kwargs.get('show_help', None) or kwargs.get('dump_json', None) or kwargs.get('show_version')):
if not 'parameter' in kwargs:
kwargs['parameter'] = {}
if not 'model' in kwargs['parameter']:
if 'SBB_BINARIZE_DATA' in os.environ:
kwargs['parameter']['model'] = os.environ['SBB_BINARIZE_DATA']
else:
raise ValueError("Must pass 'model' parameter or set SBB_BINARIZE_DATA environment variable")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def process(self): def process(self):
@ -49,7 +57,7 @@ class SbbBinarizeProcessor(Processor):
oplevel = self.parameter['operation_level'] oplevel = self.parameter['operation_level']
model_path = self.parameter['model'] # pylint: disable=attribute-defined-outside-init model_path = self.parameter['model'] # pylint: disable=attribute-defined-outside-init
binarizer = SbbBinarizer(model_dir=model_path) binarizer = SbbBinarizer(model_dir=model_path, logger=LOG)
for n, input_file in enumerate(self.input_files): for n, input_file in enumerate(self.input_files):
file_id = make_file_id(input_file, self.output_file_grp) file_id = make_file_id(input_file, self.output_file_grp)
@ -69,7 +77,7 @@ class SbbBinarizeProcessor(Processor):
file_id + '.IMG-BIN', file_id + '.IMG-BIN',
page_id=input_file.pageId, page_id=input_file.pageId,
file_grp=self.output_file_grp) file_grp=self.output_file_grp)
page.add_AlternativeImage(AlternativeImageType(filename=bin_image_path, comment='%s,binarized' % page_xywh['features'])) page.add_AlternativeImage(AlternativeImageType(filename=bin_image_path, comments='%s,binarized' % page_xywh['features']))
elif oplevel == 'region': elif oplevel == 'region':
regions = page.get_AllRegions(['Text', 'Table'], depth=1) regions = page.get_AllRegions(['Text', 'Table'], depth=1)

@ -3,7 +3,8 @@ Tool to load model and binarize a given image.
""" """
import sys import sys
from os import listdir, environ, devnull from glob import glob
from os import environ, devnull
from os.path import join from os.path import join
from warnings import catch_warnings, simplefilter from warnings import catch_warnings, simplefilter
@ -17,13 +18,16 @@ from keras.models import load_model
sys.stderr = stderr sys.stderr = stderr
import tensorflow as tf import tensorflow as tf
import logging
def resize_image(img_in, input_height, input_width): def resize_image(img_in, input_height, input_width):
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST) return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
class SbbBinarizer: class SbbBinarizer:
def __init__(self, model_dir): def __init__(self, model_dir, logger=None):
self.model_dir = model_dir self.model_dir = model_dir
self.log = logger if logger else logging.getLogger('SbbBinarizer')
def start_new_session(self): def start_new_session(self):
config = tf.ConfigProto() config = tf.ConfigProto()
@ -191,9 +195,10 @@ class SbbBinarizer:
if image_path is not None: if image_path is not None:
image = cv2.imread(image_path) image = cv2.imread(image_path)
self.start_new_session() self.start_new_session()
list_of_model_files = listdir(self.model_dir) list_of_model_files = glob('%s/*.h5' % self.model_dir)
img_last = 0 img_last = 0
for model_in in list_of_model_files: for n, model_in in enumerate(list_of_model_files):
self.log.info('Predicting with model %s [%s/%s]' % (model_in, n + 1, len(list_of_model_files)))
res = self.predict(model_in, image, use_patches) res = self.predict(model_in, image, use_patches)

Loading…
Cancel
Save