mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-26 23:34:13 +01:00
adopt binarizer to the zoo
This commit is contained in:
parent
de34a15809
commit
bcffa2e503
3 changed files with 62 additions and 30 deletions
|
|
@ -120,18 +120,39 @@ def machine_based_reading_order(input, dir_in, out, model, log_level):
|
||||||
type=click.Path(file_okay=True, dir_okay=True),
|
type=click.Path(file_okay=True, dir_okay=True),
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
'-M',
|
||||||
|
'--mode',
|
||||||
|
type=click.Choice(['single', 'multi']),
|
||||||
|
default='single',
|
||||||
|
help="Whether to use the (faster) single-model binarization or the (slightly better) multi-model binarization"
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--log_level",
|
"--log_level",
|
||||||
"-l",
|
"-l",
|
||||||
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
|
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
|
||||||
help="Override log level globally to this",
|
help="Override log level globally to this",
|
||||||
)
|
)
|
||||||
def binarization(patches, model_dir, input_image, dir_in, output, log_level):
|
def binarization(
|
||||||
|
patches,
|
||||||
|
model_dir,
|
||||||
|
input_image,
|
||||||
|
mode,
|
||||||
|
dir_in,
|
||||||
|
output,
|
||||||
|
log_level,
|
||||||
|
):
|
||||||
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||||
binarizer = SbbBinarizer(model_dir)
|
binarizer = SbbBinarizer(model_dir)
|
||||||
if log_level:
|
if log_level:
|
||||||
binarizer.log.setLevel(getLevelName(log_level))
|
binarizer.log.setLevel(getLevelName(log_level))
|
||||||
binarizer.run(image_path=input_image, use_patches=patches, output=output, dir_in=dir_in)
|
binarizer.run(
|
||||||
|
image_path=input_image,
|
||||||
|
use_patches=patches,
|
||||||
|
mode=mode,
|
||||||
|
output=output,
|
||||||
|
dir_in=dir_in
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
@main.command()
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,19 @@ DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
|
||||||
'': "eynollah-binarization_20210425"
|
'': "eynollah-binarization_20210425"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
"binarization_multi_1": {
|
||||||
|
'': "saved_model_2020_01_16/model_bin1",
|
||||||
|
},
|
||||||
|
"binarization_multi_2": {
|
||||||
|
'': "saved_model_2020_01_16/model_bin2",
|
||||||
|
},
|
||||||
|
"binarization_multi_3": {
|
||||||
|
'': "saved_model_2020_01_16/model_bin3",
|
||||||
|
},
|
||||||
|
"binarization_multi_4": {
|
||||||
|
'': "saved_model_2020_01_16/model_bin4",
|
||||||
|
},
|
||||||
|
|
||||||
"col_classifier": {
|
"col_classifier": {
|
||||||
'': "eynollah-column-classifier_20210425",
|
'': "eynollah-column-classifier_20210425",
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -2,18 +2,19 @@
|
||||||
Tool to load model and binarize a given image.
|
Tool to load model and binarize a given image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
|
||||||
from glob import glob
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from keras.models import Model
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
|
||||||
import cv2
|
import cv2
|
||||||
from ocrd_utils import tf_disable_interactive_logs
|
from ocrd_utils import tf_disable_interactive_logs
|
||||||
|
|
||||||
|
from eynollah.model_zoo import EynollahModelZoo
|
||||||
tf_disable_interactive_logs()
|
tf_disable_interactive_logs()
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.models import load_model
|
|
||||||
from tensorflow.python.keras import backend as tensorflow_backend
|
from tensorflow.python.keras import backend as tensorflow_backend
|
||||||
|
|
||||||
from .utils import is_image_filename
|
from .utils import is_image_filename
|
||||||
|
|
@ -23,40 +24,37 @@ def resize_image(img_in, input_height, input_width):
|
||||||
|
|
||||||
class SbbBinarizer:
|
class SbbBinarizer:
|
||||||
|
|
||||||
def __init__(self, model_dir, logger=None):
|
def __init__(self, model_dir, mode='single', logger=None):
|
||||||
self.model_dir = model_dir
|
if mode not in ('single', 'multi'):
|
||||||
|
raise ValueError(f"'mode' must be either 'multi' or 'single', not {mode}")
|
||||||
self.log = logger if logger else logging.getLogger('SbbBinarizer')
|
self.log = logger if logger else logging.getLogger('SbbBinarizer')
|
||||||
|
self.model_zoo = EynollahModelZoo(basedir=model_dir)
|
||||||
self.start_new_session()
|
self.models = self.setup_models(mode)
|
||||||
|
self.session = self.start_new_session()
|
||||||
self.model_files = glob(self.model_dir+"/*/", recursive = True)
|
|
||||||
|
|
||||||
self.models = []
|
|
||||||
for model_file in self.model_files:
|
|
||||||
self.models.append(self.load_model(model_file))
|
|
||||||
|
|
||||||
def start_new_session(self):
|
def start_new_session(self):
|
||||||
config = tf.compat.v1.ConfigProto()
|
config = tf.compat.v1.ConfigProto()
|
||||||
config.gpu_options.allow_growth = True
|
config.gpu_options.allow_growth = True
|
||||||
|
|
||||||
self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||||
tensorflow_backend.set_session(self.session)
|
tensorflow_backend.set_session(session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
def setup_models(self, mode: str) -> Dict[Path, Model]:
|
||||||
|
return {
|
||||||
|
self.model_zoo.model_path(v): self.model_zoo.load_model(v)
|
||||||
|
for v in (['binarization'] if mode == 'single' else [f'binarization_multi_{i}' for i in range(1, 5)])
|
||||||
|
}
|
||||||
|
|
||||||
def end_session(self):
|
def end_session(self):
|
||||||
tensorflow_backend.clear_session()
|
tensorflow_backend.clear_session()
|
||||||
self.session.close()
|
self.session.close()
|
||||||
del self.session
|
del self.session
|
||||||
|
|
||||||
def load_model(self, model_name):
|
def predict(self, img, use_patches, n_batch_inference=5):
|
||||||
model = load_model(os.path.join(self.model_dir, model_name), compile=False)
|
model = self.model_zoo.get('binarization', Model)
|
||||||
model_height = model.layers[len(model.layers)-1].output_shape[1]
|
model_height = model.layers[len(model.layers)-1].output_shape[1]
|
||||||
model_width = model.layers[len(model.layers)-1].output_shape[2]
|
model_width = model.layers[len(model.layers)-1].output_shape[2]
|
||||||
n_classes = model.layers[len(model.layers)-1].output_shape[3]
|
|
||||||
return model, model_height, model_width, n_classes
|
|
||||||
|
|
||||||
def predict(self, model_in, img, use_patches, n_batch_inference=5):
|
|
||||||
tensorflow_backend.set_session(self.session)
|
|
||||||
model, model_height, model_width, n_classes = model_in
|
|
||||||
|
|
||||||
img_org_h = img.shape[0]
|
img_org_h = img.shape[0]
|
||||||
img_org_w = img.shape[1]
|
img_org_w = img.shape[1]
|
||||||
|
|
@ -324,8 +322,8 @@ class SbbBinarizer:
|
||||||
if image_path is not None:
|
if image_path is not None:
|
||||||
image = cv2.imread(image_path)
|
image = cv2.imread(image_path)
|
||||||
img_last = 0
|
img_last = 0
|
||||||
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
|
for n, (model_file, model) in enumerate(self.models.items()):
|
||||||
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
|
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))
|
||||||
|
|
||||||
res = self.predict(model, image, use_patches)
|
res = self.predict(model, image, use_patches)
|
||||||
|
|
||||||
|
|
@ -354,8 +352,8 @@ class SbbBinarizer:
|
||||||
print(image_name,'image_name')
|
print(image_name,'image_name')
|
||||||
image = cv2.imread(os.path.join(dir_in,image_name) )
|
image = cv2.imread(os.path.join(dir_in,image_name) )
|
||||||
img_last = 0
|
img_last = 0
|
||||||
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
|
for n, (model_file, model) in enumerate(self.models.items()):
|
||||||
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
|
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))
|
||||||
|
|
||||||
res = self.predict(model, image, use_patches)
|
res = self.predict(model, image, use_patches)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue