From 2c9727f9c9d81c097a6c68c26dcda9ee6824ed7b Mon Sep 17 00:00:00 2001 From: kba Date: Fri, 23 Aug 2024 19:53:04 +0200 Subject: [PATCH] move keras-specific classes to utils.keras, clean up imports --- qurator/eynollah/eynollah.py | 77 +++++---------------------------- qurator/eynollah/utils/keras.py | 58 +++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 65 deletions(-) create mode 100644 qurator/eynollah/utils/keras.py diff --git a/qurator/eynollah/eynollah.py b/qurator/eynollah/eynollah.py index bf0cc88..3d08eb4 100644 --- a/qurator/eynollah/eynollah.py +++ b/qurator/eynollah/eynollah.py @@ -15,27 +15,27 @@ from typing import Optional import warnings from pathlib import Path from multiprocessing import Process, Queue, cpu_count -import gc from PIL.Image import Image from ocrd import OcrdPage -from ocrd_utils import getLogger import cv2 import numpy as np +from scipy.signal import find_peaks +import matplotlib.pyplot as plt +from scipy.ndimage import gaussian_filter1d + +from qurator.eynollah.utils.keras import PatchEncoder, Patches + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" stderr = sys.stderr sys.stderr = open(os.devnull, "w") import tensorflow as tf -from tensorflow.python.keras import backend as K -from tensorflow.keras.models import load_model sys.stderr = stderr tf.get_logger().setLevel("ERROR") warnings.filterwarnings("ignore") -from scipy.signal import find_peaks -import matplotlib.pyplot as plt -from scipy.ndimage import gaussian_filter1d + +load_model = tf.keras.models.load_model # use tf1 compatibility for keras backend -from tensorflow.compat.v1.keras.backend import set_session -from tensorflow.keras import layers +set_session = tf.compat.v1.keras.backend.set_session from .utils.contour import ( filter_contours_area_of_image, @@ -89,61 +89,8 @@ DPI_THRESHOLD = 298 MAX_SLOPE = 999 KERNEL = np.ones((5, 5), np.uint8) -projection_dim = 64 -patch_size = 1 -num_patches =21*21#14*14#28*28#14*14#28*28 - - -class Patches(layers.Layer): - def __init__(self, **kwargs): - super(Patches, self).__init__() - self.patch_size = patch_size - - def call(self, images): - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=[1, self.patch_size, self.patch_size, 1], - strides=[1, self.patch_size, self.patch_size, 1], - rates=[1, 1, 1, 1], - padding="VALID", - ) - patch_dims = patches.shape[-1] - patches = tf.reshape(patches, [batch_size, -1, patch_dims]) - return patches - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'patch_size': self.patch_size, - }) - return config - - -class PatchEncoder(layers.Layer): - def __init__(self, **kwargs): - super(PatchEncoder, self).__init__() - self.num_patches = num_patches - self.projection = layers.Dense(units=projection_dim) - self.position_embedding = layers.Embedding( - input_dim=num_patches, output_dim=projection_dim - ) - - def call(self, patch): - positions = tf.range(start=0, limit=self.num_patches, delta=1) - encoded = self.projection(patch) + self.position_embedding(positions) - return encoded - def get_config(self): - - config = super().get_config().copy() - config.update({ - 'num_patches': self.num_patches, - 'projection': self.projection, - 'position_embedding': self.position_embedding, - }) - return config - -class Eynollah: +class Eynollah(): + def __init__( self, dir_models : str, @@ -673,7 +620,7 @@ class Eynollah: model = load_model(model_dir, compile=False) self.models[model_dir] = model except: - model = load_model(model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) + model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) self.models[model_dir] = model diff --git a/qurator/eynollah/utils/keras.py b/qurator/eynollah/utils/keras.py new file mode 100644 index 0000000..f5da4d2 --- /dev/null +++ b/qurator/eynollah/utils/keras.py @@ -0,0 +1,58 @@ +import tensorflow as tf +layers = tf.keras.layers + +PROJECTION_DIM = 64 +PATCH_SIZE = 1 +NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28 + +class Patches(layers.Layer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.patch_size = PATCH_SIZE + + def call(self, inputs, *args, **kwargs): + batch_size = tf.shape(inputs)[0] + patches = tf.image.extract_patches( + images=inputs, + sizes=[1, self.patch_size, self.patch_size, 1], + strides=[1, self.patch_size, self.patch_size, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + + def get_config(self): + config = super().get_config().copy() + config.update({ + 'patch_size': self.patch_size, + }) + return config + + +class PatchEncoder(layers.Layer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_patches = NUM_PATCHES + self.projection = layers.Dense(units=PROJECTION_DIM) + self.position_embedding = layers.Embedding(input_dim=NUM_PATCHES, output_dim=PROJECTION_DIM) + + def call(self, inputs, *args, **kwargs): + positions = tf.range(start=0, limit=self.num_patches, delta=1) + # XXX: pyright thinks self.projection(inputs) is None + encoded = self.projection(inputs) + self.position_embedding(positions) + return encoded + + def get_config(self): + config = super().get_config().copy() + config.update({ + 'num_patches': self.num_patches, + 'projection': self.projection, + 'position_embedding': self.position_embedding, + }) + return config + +