mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-26 23:34:13 +01:00
adopt binarizer to the zoo
This commit is contained in:
parent
de34a15809
commit
bcffa2e503
3 changed files with 62 additions and 30 deletions
|
|
@ -120,18 +120,39 @@ def machine_based_reading_order(input, dir_in, out, model, log_level):
|
|||
type=click.Path(file_okay=True, dir_okay=True),
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
'-M',
|
||||
'--mode',
|
||||
type=click.Choice(['single', 'multi']),
|
||||
default='single',
|
||||
help="Whether to use the (faster) single-model binarization or the (slightly better) multi-model binarization"
|
||||
)
|
||||
@click.option(
|
||||
"--log_level",
|
||||
"-l",
|
||||
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
|
||||
help="Override log level globally to this",
|
||||
)
|
||||
def binarization(
    patches,
    model_dir,
    input_image,
    mode,
    dir_in,
    output,
    log_level,
):
    """Binarize a single image or all images of a directory.

    Args:
        patches: Forwarded to ``SbbBinarizer.run`` as ``use_patches``.
        model_dir: Model base directory handed to ``SbbBinarizer``.
        input_image: Path of a single input image (``-i``); mutually
            exclusive with ``dir_in``.
        mode: ``'single'`` or ``'multi'`` (the ``--mode`` option); selects
            which binarization model set the binarizer loads.
        dir_in: Directory of input images (``-di``); mutually exclusive
            with ``input_image``.
        output: Forwarded to ``SbbBinarizer.run`` as ``output``.
        log_level: Optional level name; when set, overrides the level of
            the binarizer's logger.

    Raises:
        click.UsageError: If neither or both of ``input_image`` and
            ``dir_in`` are given.
    """
    # Raise instead of ``assert``: assertions are stripped under ``python -O``,
    # which would silently disable this argument check.
    if bool(input_image) == bool(dir_in):
        raise click.UsageError("Either -i (single input) or -di (directory) must be provided, but not both.")
    # The constructor decides which models are loaded, so the selected mode
    # has to be passed here; otherwise it falls back to its 'single' default.
    binarizer = SbbBinarizer(model_dir, mode=mode)
    if log_level:
        binarizer.log.setLevel(getLevelName(log_level))
    # NOTE(review): the signature of ``run`` is not visible here -- confirm
    # that it accepts the ``mode`` keyword.
    binarizer.run(
        image_path=input_image,
        use_patches=patches,
        mode=mode,
        output=output,
        dir_in=dir_in,
    )
|
||||
|
||||
|
||||
@main.command()
|
||||
|
|
|
|||
|
|
@ -25,6 +25,19 @@ DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
|
|||
'': "eynollah-binarization_20210425"
|
||||
},
|
||||
|
||||
"binarization_multi_1": {
|
||||
'': "saved_model_2020_01_16/model_bin1",
|
||||
},
|
||||
"binarization_multi_2": {
|
||||
'': "saved_model_2020_01_16/model_bin2",
|
||||
},
|
||||
"binarization_multi_3": {
|
||||
'': "saved_model_2020_01_16/model_bin3",
|
||||
},
|
||||
"binarization_multi_4": {
|
||||
'': "saved_model_2020_01_16/model_bin4",
|
||||
},
|
||||
|
||||
"col_classifier": {
|
||||
'': "eynollah-column-classifier_20210425",
|
||||
},
|
||||
|
|
|
|||
|
|
@ -2,18 +2,19 @@
|
|||
Tool to load model and binarize a given image.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from glob import glob
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from ocrd_utils import tf_disable_interactive_logs
|
||||
|
||||
from eynollah.model_zoo import EynollahModelZoo
|
||||
tf_disable_interactive_logs()
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import load_model
|
||||
from tensorflow.python.keras import backend as tensorflow_backend
|
||||
|
||||
from .utils import is_image_filename
|
||||
|
|
@ -23,40 +24,37 @@ def resize_image(img_in, input_height, input_width):
|
|||
|
||||
class SbbBinarizer:
|
||||
|
||||
def __init__(self, model_dir, mode='single', logger=None):
    """Set up logger, model zoo, models and TensorFlow session.

    Args:
        model_dir: Base directory passed to ``EynollahModelZoo``.
        mode: ``'single'`` or ``'multi'``; forwarded to ``setup_models``.
        logger: Optional logger; defaults to the ``'SbbBinarizer'`` logger.

    Raises:
        ValueError: If ``mode`` is neither ``'single'`` nor ``'multi'``.
    """
    valid_modes = ('single', 'multi')
    if mode not in valid_modes:
        raise ValueError(f"'mode' must be either 'multi' or 'single', not {mode}")
    self.log = logger or logging.getLogger('SbbBinarizer')

    self.model_zoo = EynollahModelZoo(basedir=model_dir)
    # Models are loaded first, the session is opened afterwards.
    self.models = self.setup_models(mode)
    self.session = self.start_new_session()
|
||||
|
||||
def start_new_session(self):
    """Create a TF1-compat session, register it with the Keras backend and return it.

    The session is configured with ``gpu_options.allow_growth = True``.
    """
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.allow_growth = True

    new_session = tf.compat.v1.Session(config=session_config)  # tf.InteractiveSession()
    tensorflow_backend.set_session(new_session)
    return new_session
|
||||
|
||||
def setup_models(self, mode: str) -> Dict[Path, Model]:
    """Load the binarization model(s) for ``mode`` from the model zoo.

    For ``'single'`` the category ``'binarization'`` is loaded; for any other
    value the four categories ``'binarization_multi_1'`` through
    ``'binarization_multi_4'`` are loaded.

    Returns:
        A dict mapping each category's ``model_zoo.model_path(...)`` to the
        corresponding ``model_zoo.load_model(...)`` result.
    """
    if mode == 'single':
        categories = ['binarization']
    else:
        categories = [f'binarization_multi_{idx}' for idx in range(1, 5)]

    loaded = {}
    for category in categories:
        # Resolve the path before loading, as a dict comprehension would.
        path = self.model_zoo.model_path(category)
        loaded[path] = self.model_zoo.load_model(category)
    return loaded
|
||||
|
||||
def end_session(self):
    """Clear the Keras backend session, close ``self.session`` and drop the attribute."""
    tensorflow_backend.clear_session()
    active_session = self.session
    active_session.close()
    del self.session
|
||||
|
||||
def load_model(self, model_name):
|
||||
model = load_model(os.path.join(self.model_dir, model_name), compile=False)
|
||||
def predict(self, img, use_patches, n_batch_inference=5):
|
||||
model = self.model_zoo.get('binarization', Model)
|
||||
model_height = model.layers[len(model.layers)-1].output_shape[1]
|
||||
model_width = model.layers[len(model.layers)-1].output_shape[2]
|
||||
n_classes = model.layers[len(model.layers)-1].output_shape[3]
|
||||
return model, model_height, model_width, n_classes
|
||||
|
||||
def predict(self, model_in, img, use_patches, n_batch_inference=5):
|
||||
tensorflow_backend.set_session(self.session)
|
||||
model, model_height, model_width, n_classes = model_in
|
||||
|
||||
img_org_h = img.shape[0]
|
||||
img_org_w = img.shape[1]
|
||||
|
|
@ -324,8 +322,8 @@ class SbbBinarizer:
|
|||
if image_path is not None:
|
||||
image = cv2.imread(image_path)
|
||||
img_last = 0
|
||||
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
|
||||
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
|
||||
for n, (model_file, model) in enumerate(self.models.items()):
|
||||
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))
|
||||
|
||||
res = self.predict(model, image, use_patches)
|
||||
|
||||
|
|
@ -354,8 +352,8 @@ class SbbBinarizer:
|
|||
print(image_name,'image_name')
|
||||
image = cv2.imread(os.path.join(dir_in,image_name) )
|
||||
img_last = 0
|
||||
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
|
||||
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
|
||||
for n, (model_file, model) in enumerate(self.models.items()):
|
||||
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))
|
||||
|
||||
res = self.predict(model, image, use_patches)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue