adopt image_enhancer to the zoo

This commit is contained in:
kba 2025-10-21 19:24:55 +02:00
parent f0c86672f8
commit 1337461d47

View file

@ -5,17 +5,18 @@ Image enhancer. The output can be written as same scale of input or in new predi
from logging import Logger
import os
import time
from typing import Optional
from typing import Dict, Optional
from pathlib import Path
import gc
import cv2
from keras.models import Model
import numpy as np
from ocrd_utils import getLogger, tf_disable_interactive_logs
import tensorflow as tf
from skimage.morphology import skeletonize
from tensorflow.keras.models import load_model
from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image
from .utils.pil_cv2 import pil2cv
from .utils import (
@ -50,11 +51,9 @@ class Enhancer:
self.num_col_lower = num_col_lower
self.logger = logger if logger else getLogger('enhancement')
self.dir_models = dir_models
self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425"
self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915"
self.model_zoo = EynollahModelZoo(basedir=dir_models)
for v in ['binarization', 'enhancement', 'col_classifier', 'page']:
self.model_zoo.load_model(v)
try:
for device in tf.config.list_physical_devices('GPU'):
@ -62,11 +61,6 @@ class Enhancer:
except:
self.logger.warning("no GPU device available")
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
ret = {}
if image_filename:
@ -102,24 +96,12 @@ class Enhancer:
def isNaN(self, num):
return num != num
@staticmethod
def our_load_model(model_file):
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except:
model = load_model(model_file, compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
return model
def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement")
img_height_model = self.model_enhancement.layers[-1].output_shape[1]
img_width_model = self.model_enhancement.layers[-1].output_shape[2]
img_height_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[1]
img_width_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[2]
if img.shape[0] < img_height_model:
img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
if img.shape[1] < img_width_model:
@ -160,7 +142,7 @@ class Enhancer:
index_y_d = img_h - img_height_model
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose=0)
seg = label_p_pred[0, :, :, :] * 255
if i == 0 and j == 0:
@ -246,7 +228,7 @@ class Enhancer:
else:
img = self.imread()
img = cv2.GaussianBlur(img, (5, 5), 0)
img_page_prediction = self.do_prediction(False, img, self.model_page)
img_page_prediction = self.do_prediction(False, img, self.model_zoo.get('page'))
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
@ -291,7 +273,7 @@ class Enhancer:
self.logger.info("Detected %s DPI", dpi)
if self.input_binary:
img = self.imread()
prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
prediction_bin = self.do_prediction(True, img, self.model_zoo.get('binarization'), n_batch_inference=5)
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
img= np.copy(prediction_bin)
@ -332,7 +314,7 @@ class Enhancer:
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
if self.input_binary:
@ -352,7 +334,7 @@ class Enhancer:
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
if num_col > self.num_col_upper: