From 1337461d478c349ec0909c9772419f991bbd4cce Mon Sep 17 00:00:00 2001 From: kba Date: Tue, 21 Oct 2025 19:24:55 +0200 Subject: [PATCH] adopt image_enhancer to the zoo --- src/eynollah/image_enhancer.py | 44 ++++++++++------------------------ 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/src/eynollah/image_enhancer.py b/src/eynollah/image_enhancer.py index 93b5daa..cec8877 100644 --- a/src/eynollah/image_enhancer.py +++ b/src/eynollah/image_enhancer.py @@ -5,17 +5,18 @@ Image enhancer. The output can be written as same scale of input or in new predi from logging import Logger import os import time -from typing import Optional +from typing import Dict, Optional from pathlib import Path import gc import cv2 +from keras.models import Model import numpy as np from ocrd_utils import getLogger, tf_disable_interactive_logs import tensorflow as tf from skimage.morphology import skeletonize -from tensorflow.keras.models import load_model +from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.pil_cv2 import pil2cv from .utils import ( @@ -50,11 +51,9 @@ class Enhancer: self.num_col_lower = num_col_lower self.logger = logger if logger else getLogger('enhancement') - self.dir_models = dir_models - self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425" - self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425" - self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425" - self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915" + self.model_zoo = EynollahModelZoo(basedir=dir_models) + for v in ['binarization', 'enhancement', 'col_classifier', 'page']: + self.model_zoo.load_model(v) try: for device in tf.config.list_physical_devices('GPU'): @@ -62,11 +61,6 @@ class Enhancer: except: self.logger.warning("no GPU device available") - self.model_page = self.our_load_model(self.model_page_dir) - self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier) - self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) - self.model_bin = self.our_load_model(self.model_dir_of_binarization) - def cache_images(self, image_filename=None, image_pil=None, dpi=None): ret = {} if image_filename: @@ -102,24 +96,12 @@ class Enhancer: def isNaN(self, num): return num != num - - @staticmethod - def our_load_model(model_file): - if model_file.endswith('.h5') and Path(model_file[:-3]).exists(): - # prefer SavedModel over HDF5 format if it exists - model_file = model_file[:-3] - try: - model = load_model(model_file, compile=False) - except: - model = load_model(model_file, compile=False, custom_objects={ - "PatchEncoder": PatchEncoder, "Patches": Patches}) - return model def predict_enhancement(self, img): self.logger.debug("enter predict_enhancement") - img_height_model = self.model_enhancement.layers[-1].output_shape[1] - img_width_model = self.model_enhancement.layers[-1].output_shape[2] + img_height_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[1] + img_width_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[2] if img.shape[0] < img_height_model: img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) if img.shape[1] < img_width_model: @@ -160,7 +142,7 @@ class Enhancer: index_y_d = img_h - img_height_model img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :] - label_p_pred = self.model_enhancement.predict(img_patch, verbose=0) + label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose=0) seg = label_p_pred[0, :, :, :] * 255 if i == 0 and j == 0: @@ -246,7 +228,7 @@ class Enhancer: else: img = self.imread() img = cv2.GaussianBlur(img, (5, 5), 0) - img_page_prediction = self.do_prediction(False, img, self.model_page) + img_page_prediction = self.do_prediction(False, img, self.model_zoo.get('page')) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) @@ -291,7 +273,7 @@ class Enhancer: self.logger.info("Detected %s DPI", dpi) if self.input_binary: img = self.imread() - prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5) + prediction_bin = self.do_prediction(True, img, self.model_zoo.get('binarization'), n_batch_inference=5) prediction_bin = 255 * (prediction_bin[:,:,0]==0) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8) img= np.copy(prediction_bin) @@ -332,7 +314,7 @@ class Enhancer: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.model_classifier.predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): if self.input_binary: @@ -352,7 +334,7 @@ class Enhancer: img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :] - label_p_pred = self.model_classifier.predict(img_in, verbose=0) + label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0) num_col = np.argmax(label_p_pred[0]) + 1 if num_col > self.num_col_upper: