move keras-specific classes to utils.keras, clean up imports

refactoring-2024-08-merged
kba 5 months ago
parent d7a774ebd2
commit 2c9727f9c9

@ -15,27 +15,27 @@ from typing import Optional
import warnings
from pathlib import Path
from multiprocessing import Process, Queue, cpu_count
import gc
from PIL.Image import Image
from ocrd import OcrdPage
from ocrd_utils import getLogger
import cv2
import numpy as np
from scipy.signal import find_peaks
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from qurator.eynollah.utils.keras import PatchEncoder, Patches
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
stderr = sys.stderr
sys.stderr = open(os.devnull, "w")
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.keras.models import load_model
sys.stderr = stderr
tf.get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore")
from scipy.signal import find_peaks
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
load_model = tf.keras.models.load_model
# use tf1 compatibility for keras backend
from tensorflow.compat.v1.keras.backend import set_session
from tensorflow.keras import layers
set_session = tf.compat.v1.keras.backend.set_session
from .utils.contour import (
filter_contours_area_of_image,
@ -89,61 +89,8 @@ DPI_THRESHOLD = 298
MAX_SLOPE = 999
KERNEL = np.ones((5, 5), np.uint8)
projection_dim = 64
patch_size = 1
num_patches =21*21#14*14#28*28#14*14#28*28
class Patches(layers.Layer):
def __init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, **kwargs):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
class Eynollah:
class Eynollah():
def __init__(
self,
dir_models : str,
@ -673,7 +620,7 @@ class Eynollah:
model = load_model(model_dir, compile=False)
self.models[model_dir] = model
except:
model = load_model(model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
self.models[model_dir] = model

@ -0,0 +1,58 @@
import tensorflow as tf
layers = tf.keras.layers
PROJECTION_DIM = 64
PATCH_SIZE = 1
NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28
class Patches(layers.Layer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.patch_size = PATCH_SIZE
def call(self, inputs, *args, **kwargs):
batch_size = tf.shape(inputs)[0]
patches = tf.image.extract_patches(
images=inputs,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_patches = NUM_PATCHES
self.projection = layers.Dense(units=PROJECTION_DIM)
self.position_embedding = layers.Embedding(input_dim=NUM_PATCHES, output_dim=PROJECTION_DIM)
def call(self, inputs, *args, **kwargs):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
# XXX: pyright thinks self.projection(inputs) is None
encoded = self.projection(inputs) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
Loading…
Cancel
Save