From bcffa2e5035356c638a138f9c09f67f23ec02b06 Mon Sep 17 00:00:00 2001 From: kba Date: Tue, 21 Oct 2025 17:53:24 +0200 Subject: [PATCH] adopt binarizer to the zoo --- src/eynollah/cli.py | 25 +++++++++++++++-- src/eynollah/model_zoo.py | 13 +++++++++ src/eynollah/sbb_binarize.py | 54 +++++++++++++++++------------------- 3 files changed, 62 insertions(+), 30 deletions(-) diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py index 14ae77d..c7d4bd9 100644 --- a/src/eynollah/cli.py +++ b/src/eynollah/cli.py @@ -120,18 +120,39 @@ def machine_based_reading_order(input, dir_in, out, model, log_level): type=click.Path(file_okay=True, dir_okay=True), required=True, ) +@click.option( + '-M', + '--mode', + type=click.Choice(['single', 'multi']), + default='single', + help="Whether to use the (faster) single-model binarization or the (slightly better) multi-model binarization" +) @click.option( "--log_level", "-l", type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), help="Override log level globally to this", ) -def binarization(patches, model_dir, input_image, dir_in, output, log_level): +def binarization( + patches, + model_dir, + input_image, + mode, + dir_in, + output, + log_level, +): assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." binarizer = SbbBinarizer(model_dir) if log_level: binarizer.log.setLevel(getLevelName(log_level)) - binarizer.run(image_path=input_image, use_patches=patches, output=output, dir_in=dir_in) + binarizer.run( + image_path=input_image, + use_patches=patches, + mode=mode, + output=output, + dir_in=dir_in + ) @main.command() diff --git a/src/eynollah/model_zoo.py b/src/eynollah/model_zoo.py index 7f90bc0..100d974 100644 --- a/src/eynollah/model_zoo.py +++ b/src/eynollah/model_zoo.py @@ -25,6 +25,19 @@ DEFAULT_MODEL_VERSIONS: Dict[str, Dict[str, str]] = { '': "eynollah-binarization_20210425" }, + "binarization_multi_1": { + '': "saved_model_2020_01_16/model_bin1", + }, + "binarization_multi_2": { + '': "saved_model_2020_01_16/model_bin2", + }, + "binarization_multi_3": { + '': "saved_model_2020_01_16/model_bin3", + }, + "binarization_multi_4": { + '': "saved_model_2020_01_16/model_bin4", + }, + "col_classifier": { '': "eynollah-column-classifier_20210425", }, diff --git a/src/eynollah/sbb_binarize.py b/src/eynollah/sbb_binarize.py index 3716987..f8898a1 100644 --- a/src/eynollah/sbb_binarize.py +++ b/src/eynollah/sbb_binarize.py @@ -2,18 +2,19 @@ Tool to load model and binarize a given image. """ -import sys -from glob import glob import os import logging +from pathlib import Path +from typing import Dict, List +from keras.models import Model import numpy as np -from PIL import Image import cv2 from ocrd_utils import tf_disable_interactive_logs + +from eynollah.model_zoo import EynollahModelZoo tf_disable_interactive_logs() import tensorflow as tf -from tensorflow.keras.models import load_model from tensorflow.python.keras import backend as tensorflow_backend from .utils import is_image_filename @@ -23,40 +24,37 @@ def resize_image(img_in, input_height, input_width): class SbbBinarizer: - def __init__(self, model_dir, logger=None): - self.model_dir = model_dir + def __init__(self, model_dir, mode='single', logger=None): + if mode not in ('single', 'multi'): + raise ValueError(f"'mode' must be either 'multi' or 'single', not {mode}") self.log = logger if logger else logging.getLogger('SbbBinarizer') - - self.start_new_session() - - self.model_files = glob(self.model_dir+"/*/", recursive = True) - - self.models = [] - for model_file in self.model_files: - self.models.append(self.load_model(model_file)) + self.model_zoo = EynollahModelZoo(basedir=model_dir) + self.models = self.setup_models(mode) + self.session = self.start_new_session() def start_new_session(self): config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True - self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() - tensorflow_backend.set_session(self.session) + session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() + tensorflow_backend.set_session(session) + return session + + def setup_models(self, mode: str) -> Dict[Path, Model]: + return { + self.model_zoo.model_path(v): self.model_zoo.load_model(v) + for v in (['binarization'] if mode == 'single' else [f'binarization_multi_{i}' for i in range(1, 5)]) + } def end_session(self): tensorflow_backend.clear_session() self.session.close() del self.session - def load_model(self, model_name): - model = load_model(os.path.join(self.model_dir, model_name), compile=False) + def predict(self, img, use_patches, n_batch_inference=5): + model = self.model_zoo.get('binarization', Model) model_height = model.layers[len(model.layers)-1].output_shape[1] model_width = model.layers[len(model.layers)-1].output_shape[2] - n_classes = model.layers[len(model.layers)-1].output_shape[3] - return model, model_height, model_width, n_classes - - def predict(self, model_in, img, use_patches, n_batch_inference=5): - tensorflow_backend.set_session(self.session) - model, model_height, model_width, n_classes = model_in img_org_h = img.shape[0] img_org_w = img.shape[1] @@ -324,8 +322,8 @@ class SbbBinarizer: if image_path is not None: image = cv2.imread(image_path) img_last = 0 - for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): - self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) + for n, (model_file, model) in enumerate(self.models.items()): + self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys()))) res = self.predict(model, image, use_patches) @@ -354,8 +352,8 @@ class SbbBinarizer: print(image_name,'image_name') image = cv2.imread(os.path.join(dir_in,image_name) ) img_last = 0 - for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): - self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) + for n, (model_file, model) in enumerate(self.models.items()): + self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.models.keys()))) res = self.predict(model, image, use_patches)