adopt image_enhancer to the zoo

This commit is contained in:
kba 2025-10-21 19:24:55 +02:00
parent f0c86672f8
commit 1337461d47

View file

@ -5,17 +5,18 @@ Image enhancer. The output can be written as same scale of input or in new predi
from logging import Logger from logging import Logger
import os import os
import time import time
from typing import Optional from typing import Dict, Optional
from pathlib import Path from pathlib import Path
import gc import gc
import cv2 import cv2
from keras.models import Model
import numpy as np import numpy as np
from ocrd_utils import getLogger, tf_disable_interactive_logs from ocrd_utils import getLogger, tf_disable_interactive_logs
import tensorflow as tf import tensorflow as tf
from skimage.morphology import skeletonize from skimage.morphology import skeletonize
from tensorflow.keras.models import load_model
from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image from .utils.resize import resize_image
from .utils.pil_cv2 import pil2cv from .utils.pil_cv2 import pil2cv
from .utils import ( from .utils import (
@ -50,11 +51,9 @@ class Enhancer:
self.num_col_lower = num_col_lower self.num_col_lower = num_col_lower
self.logger = logger if logger else getLogger('enhancement') self.logger = logger if logger else getLogger('enhancement')
self.dir_models = dir_models self.model_zoo = EynollahModelZoo(basedir=dir_models)
self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425" for v in ['binarization', 'enhancement', 'col_classifier', 'page']:
self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425" self.model_zoo.load_model(v)
self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915"
try: try:
for device in tf.config.list_physical_devices('GPU'): for device in tf.config.list_physical_devices('GPU'):
@ -62,11 +61,6 @@ class Enhancer:
except: except:
self.logger.warning("no GPU device available") self.logger.warning("no GPU device available")
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
def cache_images(self, image_filename=None, image_pil=None, dpi=None): def cache_images(self, image_filename=None, image_pil=None, dpi=None):
ret = {} ret = {}
if image_filename: if image_filename:
@ -103,23 +97,11 @@ class Enhancer:
def isNaN(self, num): def isNaN(self, num):
return num != num return num != num
@staticmethod
def our_load_model(model_file):
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except:
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
def predict_enhancement(self, img): def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement") self.logger.debug("enter predict_enhancement")
img_height_model = self.model_enhancement.layers[-1].output_shape[1] img_height_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[1]
img_width_model = self.model_enhancement.layers[-1].output_shape[2] img_width_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[2]
if img.shape[0] < img_height_model: if img.shape[0] < img_height_model:
img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
if img.shape[1] < img_width_model: if img.shape[1] < img_width_model:
@ -160,7 +142,7 @@ class Enhancer:
index_y_d = img_h - img_height_model index_y_d = img_h - img_height_model
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :] img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model_enhancement.predict(img_patch, verbose=0) label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose=0)
seg = label_p_pred[0, :, :, :] * 255 seg = label_p_pred[0, :, :, :] * 255
if i == 0 and j == 0: if i == 0 and j == 0:
@ -246,7 +228,7 @@ class Enhancer:
else: else:
img = self.imread() img = self.imread()
img = cv2.GaussianBlur(img, (5, 5), 0) img = cv2.GaussianBlur(img, (5, 5), 0)
img_page_prediction = self.do_prediction(False, img, self.model_page) img_page_prediction = self.do_prediction(False, img, self.model_zoo.get('page'))
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0) _, thresh = cv2.threshold(imgray, 0, 255, 0)
@ -291,7 +273,7 @@ class Enhancer:
self.logger.info("Detected %s DPI", dpi) self.logger.info("Detected %s DPI", dpi)
if self.input_binary: if self.input_binary:
img = self.imread() img = self.imread()
prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5) prediction_bin = self.do_prediction(True, img, self.model_zoo.get('binarization'), n_batch_inference=5)
prediction_bin = 255 * (prediction_bin[:,:,0]==0) prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8) prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
img= np.copy(prediction_bin) img= np.copy(prediction_bin)
@ -332,7 +314,7 @@ class Enhancer:
img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0) label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1 num_col = np.argmax(label_p_pred[0]) + 1
elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
if self.input_binary: if self.input_binary:
@ -352,7 +334,7 @@ class Enhancer:
img_in[0, :, :, 1] = img_1ch[:, :] img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :] img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0) label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1 num_col = np.argmax(label_p_pred[0]) + 1
if num_col > self.num_col_upper: if num_col > self.num_col_upper: