mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-26 23:34:13 +01:00
adopt image_enhancer to the zoo
This commit is contained in:
parent
f0c86672f8
commit
1337461d47
1 changed files with 13 additions and 31 deletions
|
|
@ -5,17 +5,18 @@ Image enhancer. The output can be written as same scale of input or in new predi
|
|||
from logging import Logger
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
from pathlib import Path
|
||||
import gc
|
||||
|
||||
import cv2
|
||||
from keras.models import Model
|
||||
import numpy as np
|
||||
from ocrd_utils import getLogger, tf_disable_interactive_logs
|
||||
import tensorflow as tf
|
||||
from skimage.morphology import skeletonize
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
from .model_zoo import EynollahModelZoo
|
||||
from .utils.resize import resize_image
|
||||
from .utils.pil_cv2 import pil2cv
|
||||
from .utils import (
|
||||
|
|
@ -50,11 +51,9 @@ class Enhancer:
|
|||
self.num_col_lower = num_col_lower
|
||||
|
||||
self.logger = logger if logger else getLogger('enhancement')
|
||||
self.dir_models = dir_models
|
||||
self.model_dir_of_binarization = dir_models + "/eynollah-binarization_20210425"
|
||||
self.model_dir_of_enhancement = dir_models + "/eynollah-enhancement_20210425"
|
||||
self.model_dir_of_col_classifier = dir_models + "/eynollah-column-classifier_20210425"
|
||||
self.model_page_dir = dir_models + "/model_eynollah_page_extraction_20250915"
|
||||
self.model_zoo = EynollahModelZoo(basedir=dir_models)
|
||||
for v in ['binarization', 'enhancement', 'col_classifier', 'page']:
|
||||
self.model_zoo.load_model(v)
|
||||
|
||||
try:
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
|
|
@ -62,11 +61,6 @@ class Enhancer:
|
|||
except:
|
||||
self.logger.warning("no GPU device available")
|
||||
|
||||
self.model_page = self.our_load_model(self.model_page_dir)
|
||||
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
|
||||
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
|
||||
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
|
||||
|
||||
def cache_images(self, image_filename=None, image_pil=None, dpi=None):
|
||||
ret = {}
|
||||
if image_filename:
|
||||
|
|
@ -102,24 +96,12 @@ class Enhancer:
|
|||
|
||||
def isNaN(self, num):
|
||||
return num != num
|
||||
|
||||
@staticmethod
|
||||
def our_load_model(model_file):
|
||||
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
|
||||
# prefer SavedModel over HDF5 format if it exists
|
||||
model_file = model_file[:-3]
|
||||
try:
|
||||
model = load_model(model_file, compile=False)
|
||||
except:
|
||||
model = load_model(model_file, compile=False, custom_objects={
|
||||
"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
return model
|
||||
|
||||
def predict_enhancement(self, img):
|
||||
self.logger.debug("enter predict_enhancement")
|
||||
|
||||
img_height_model = self.model_enhancement.layers[-1].output_shape[1]
|
||||
img_width_model = self.model_enhancement.layers[-1].output_shape[2]
|
||||
img_height_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[1]
|
||||
img_width_model = self.model_zoo.get('enhancement', Model).layers[-1].output_shape[2]
|
||||
if img.shape[0] < img_height_model:
|
||||
img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
|
||||
if img.shape[1] < img_width_model:
|
||||
|
|
@ -160,7 +142,7 @@ class Enhancer:
|
|||
index_y_d = img_h - img_height_model
|
||||
|
||||
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
label_p_pred = self.model_enhancement.predict(img_patch, verbose=0)
|
||||
label_p_pred = self.model_zoo.get('enhancement', Model).predict(img_patch, verbose=0)
|
||||
seg = label_p_pred[0, :, :, :] * 255
|
||||
|
||||
if i == 0 and j == 0:
|
||||
|
|
@ -246,7 +228,7 @@ class Enhancer:
|
|||
else:
|
||||
img = self.imread()
|
||||
img = cv2.GaussianBlur(img, (5, 5), 0)
|
||||
img_page_prediction = self.do_prediction(False, img, self.model_page)
|
||||
img_page_prediction = self.do_prediction(False, img, self.model_zoo.get('page'))
|
||||
|
||||
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
|
||||
_, thresh = cv2.threshold(imgray, 0, 255, 0)
|
||||
|
|
@ -291,7 +273,7 @@ class Enhancer:
|
|||
self.logger.info("Detected %s DPI", dpi)
|
||||
if self.input_binary:
|
||||
img = self.imread()
|
||||
prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
|
||||
prediction_bin = self.do_prediction(True, img, self.model_zoo.get('binarization'), n_batch_inference=5)
|
||||
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
|
||||
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
|
||||
img= np.copy(prediction_bin)
|
||||
|
|
@ -332,7 +314,7 @@ class Enhancer:
|
|||
img_in[0, :, :, 1] = img_1ch[:, :]
|
||||
img_in[0, :, :, 2] = img_1ch[:, :]
|
||||
|
||||
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
|
||||
label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
|
||||
num_col = np.argmax(label_p_pred[0]) + 1
|
||||
elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
|
||||
if self.input_binary:
|
||||
|
|
@ -352,7 +334,7 @@ class Enhancer:
|
|||
img_in[0, :, :, 1] = img_1ch[:, :]
|
||||
img_in[0, :, :, 2] = img_1ch[:, :]
|
||||
|
||||
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
|
||||
label_p_pred = self.model_zoo.get('col_classifier').predict(img_in, verbose=0)
|
||||
num_col = np.argmax(label_p_pred[0]) + 1
|
||||
|
||||
if num_col > self.num_col_upper:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue