adapt binarizer to the zoo

This commit is contained in:
kba 2025-10-21 17:53:24 +02:00
parent de34a15809
commit bcffa2e503
3 changed files with 62 additions and 30 deletions

View file

@ -120,18 +120,39 @@ def machine_based_reading_order(input, dir_in, out, model, log_level):
type=click.Path(file_okay=True, dir_okay=True), type=click.Path(file_okay=True, dir_okay=True),
required=True, required=True,
) )
@click.option(
'-M',
'--mode',
type=click.Choice(['single', 'multi']),
default='single',
help="Whether to use the (faster) single-model binarization or the (slightly better) multi-model binarization"
)
@click.option( @click.option(
"--log_level", "--log_level",
"-l", "-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this", help="Override log level globally to this",
) )
def binarization(
    patches,
    model_dir,
    input_image,
    mode,
    dir_in,
    output,
    log_level,
):
    # CLI entry point: binarize either one image (-i) or a directory (-di).
    #
    # patches:     forwarded to SbbBinarizer.run as ``use_patches``
    # model_dir:   base directory handed to SbbBinarizer
    # input_image: path of a single input image (exclusive with dir_in)
    # mode:        'single' or 'multi', selects which binarization models load
    # dir_in:      directory of input images (exclusive with input_image)
    # output:      forwarded to SbbBinarizer.run as ``output``
    # log_level:   optional level name applied to the binarizer's logger
    #
    # NOTE: documented with comments rather than a docstring on purpose,
    # because click would show a docstring as the command's --help text.
    assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
    # The binarizer decides which models to load in its constructor, so the
    # selected mode must be passed there; otherwise '--mode multi' would still
    # load only the default single model.
    binarizer = SbbBinarizer(model_dir, mode=mode)
    if log_level:
        binarizer.log.setLevel(getLevelName(log_level))
    # NOTE(review): ``mode`` is still forwarded to run() as in the original
    # call; run()'s signature is not visible here, so confirm it accepts it.
    binarizer.run(
        image_path=input_image,
        use_patches=patches,
        mode=mode,
        output=output,
        dir_in=dir_in,
    )
@main.command() @main.command()

View file

@ -25,6 +25,19 @@ DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = {
'': "eynollah-binarization_20210425" '': "eynollah-binarization_20210425"
}, },
"binarization_multi_1": {
'': "saved_model_2020_01_16/model_bin1",
},
"binarization_multi_2": {
'': "saved_model_2020_01_16/model_bin2",
},
"binarization_multi_3": {
'': "saved_model_2020_01_16/model_bin3",
},
"binarization_multi_4": {
'': "saved_model_2020_01_16/model_bin4",
},
"col_classifier": { "col_classifier": {
'': "eynollah-column-classifier_20210425", '': "eynollah-column-classifier_20210425",
}, },

View file

@ -2,18 +2,19 @@
Tool to load model and binarize a given image. Tool to load model and binarize a given image.
""" """
import sys
from glob import glob
import os import os
import logging import logging
from pathlib import Path
from typing import Dict, List
from keras.models import Model
import numpy as np import numpy as np
from PIL import Image
import cv2 import cv2
from ocrd_utils import tf_disable_interactive_logs from ocrd_utils import tf_disable_interactive_logs
from eynollah.model_zoo import EynollahModelZoo
tf_disable_interactive_logs() tf_disable_interactive_logs()
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend from tensorflow.python.keras import backend as tensorflow_backend
from .utils import is_image_filename from .utils import is_image_filename
@ -23,40 +24,37 @@ def resize_image(img_in, input_height, input_width):
class SbbBinarizer: class SbbBinarizer:
def __init__(self, model_dir, mode='single', logger=None):
    """
    Set up the binarizer: logger, model zoo, models and TensorFlow session.

    model_dir is used as the base directory of the model zoo, mode must be
    'single' or 'multi' (anything else raises ValueError), and logger is an
    optional logger that replaces the default 'SbbBinarizer' one.
    """
    supported_modes = ('single', 'multi')
    if mode not in supported_modes:
        raise ValueError(f"'mode' must be either 'multi' or 'single', not {mode}")
    self.log = logger or logging.getLogger('SbbBinarizer')
    self.model_zoo = EynollahModelZoo(basedir=model_dir)
    # Models are loaded first, then the session is created and stored.
    self.models = self.setup_models(mode)
    self.session = self.start_new_session()
def start_new_session(self):
    """
    Create a TF1-compat session with GPU memory growth enabled, register it
    with the Keras backend and return it.
    """
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.allow_growth = True
    new_session = tf.compat.v1.Session(config=session_config)  # tf.InteractiveSession()
    tensorflow_backend.set_session(new_session)
    return new_session
def setup_models(self, mode: str) -> Dict[Path, Model]:
    """
    Load the binarization model(s) for the given mode from the model zoo.

    For 'single' only the 'binarization' model is loaded; for any other
    value the four 'binarization_multi_1' .. 'binarization_multi_4' models
    are loaded. The result maps each model's zoo path to the loaded model.
    """
    if mode == 'single':
        model_names = ['binarization']
    else:
        model_names = [f'binarization_multi_{idx}' for idx in range(1, 5)]
    loaded = {}
    for name in model_names:
        # Resolve the path before loading, keeping the zoo call order.
        path = self.model_zoo.model_path(name)
        loaded[path] = self.model_zoo.load_model(name)
    return loaded
def end_session(self):
    """
    Clear the Keras backend session, close this binarizer's TensorFlow
    session and remove the ``session`` attribute from the instance.
    """
    tensorflow_backend.clear_session()
    active_session = self.session
    active_session.close()
    delattr(self, 'session')
def load_model(self, model_name): def predict(self, img, use_patches, n_batch_inference=5):
model = load_model(os.path.join(self.model_dir, model_name), compile=False) model = self.model_zoo.get('binarization', Model)
model_height = model.layers[len(model.layers)-1].output_shape[1] model_height = model.layers[len(model.layers)-1].output_shape[1]
model_width = model.layers[len(model.layers)-1].output_shape[2] model_width = model.layers[len(model.layers)-1].output_shape[2]
n_classes = model.layers[len(model.layers)-1].output_shape[3]
return model, model_height, model_width, n_classes
def predict(self, model_in, img, use_patches, n_batch_inference=5):
tensorflow_backend.set_session(self.session)
model, model_height, model_width, n_classes = model_in
img_org_h = img.shape[0] img_org_h = img.shape[0]
img_org_w = img.shape[1] img_org_w = img.shape[1]
@ -324,8 +322,8 @@ class SbbBinarizer:
if image_path is not None: if image_path is not None:
image = cv2.imread(image_path) image = cv2.imread(image_path)
img_last = 0 img_last = 0
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): for n, (model_file, model) in enumerate(self.models.items()):
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))
res = self.predict(model, image, use_patches) res = self.predict(model, image, use_patches)
@ -354,8 +352,8 @@ class SbbBinarizer:
print(image_name,'image_name') print(image_name,'image_name')
image = cv2.imread(os.path.join(dir_in,image_name) ) image = cv2.imread(os.path.join(dir_in,image_name) )
img_last = 0 img_last = 0
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): for n, (model_file, model) in enumerate(self.models.items()):
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys())))
res = self.predict(model, image, use_patches) res = self.predict(model, image, use_patches)