mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-06-09 04:09:54 +02:00
move keras-specific classes to utils.keras, clean up imports
This commit is contained in:
parent
d7a774ebd2
commit
2c9727f9c9
2 changed files with 69 additions and 64 deletions
|
@ -15,27 +15,27 @@ from typing import Optional
|
|||
import warnings
|
||||
from pathlib import Path
|
||||
from multiprocessing import Process, Queue, cpu_count
|
||||
import gc
|
||||
from PIL.Image import Image
|
||||
from ocrd import OcrdPage
|
||||
from ocrd_utils import getLogger
|
||||
import cv2
|
||||
import numpy as np
|
||||
from scipy.signal import find_peaks
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
from qurator.eynollah.utils.keras import PatchEncoder, Patches
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
stderr = sys.stderr
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.keras.models import load_model
|
||||
sys.stderr = stderr
|
||||
tf.get_logger().setLevel("ERROR")
|
||||
warnings.filterwarnings("ignore")
|
||||
from scipy.signal import find_peaks
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
load_model = tf.keras.models.load_model
|
||||
# use tf1 compatibility for keras backend
|
||||
from tensorflow.compat.v1.keras.backend import set_session
|
||||
from tensorflow.keras import layers
|
||||
set_session = tf.compat.v1.keras.backend.set_session
|
||||
|
||||
from .utils.contour import (
|
||||
filter_contours_area_of_image,
|
||||
|
@ -89,61 +89,8 @@ DPI_THRESHOLD = 298
|
|||
MAX_SLOPE = 999
|
||||
KERNEL = np.ones((5, 5), np.uint8)
|
||||
|
||||
projection_dim = 64
|
||||
patch_size = 1
|
||||
num_patches =21*21#14*14#28*28#14*14#28*28
|
||||
class Eynollah():
|
||||
|
||||
|
||||
class Patches(layers.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(Patches, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
def call(self, images):
|
||||
batch_size = tf.shape(images)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=images,
|
||||
sizes=[1, self.patch_size, self.patch_size, 1],
|
||||
strides=[1, self.patch_size, self.patch_size, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
)
|
||||
patch_dims = patches.shape[-1]
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'patch_size': self.patch_size,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
class PatchEncoder(layers.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(PatchEncoder, self).__init__()
|
||||
self.num_patches = num_patches
|
||||
self.projection = layers.Dense(units=projection_dim)
|
||||
self.position_embedding = layers.Embedding(
|
||||
input_dim=num_patches, output_dim=projection_dim
|
||||
)
|
||||
|
||||
def call(self, patch):
|
||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||
return encoded
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'num_patches': self.num_patches,
|
||||
'projection': self.projection,
|
||||
'position_embedding': self.position_embedding,
|
||||
})
|
||||
return config
|
||||
|
||||
class Eynollah:
|
||||
def __init__(
|
||||
self,
|
||||
dir_models : str,
|
||||
|
@ -673,7 +620,7 @@ class Eynollah:
|
|||
model = load_model(model_dir, compile=False)
|
||||
self.models[model_dir] = model
|
||||
except:
|
||||
model = load_model(model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
self.models[model_dir] = model
|
||||
|
||||
|
||||
|
|
58
qurator/eynollah/utils/keras.py
Normal file
58
qurator/eynollah/utils/keras.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
import tensorflow as tf
|
||||
layers = tf.keras.layers
|
||||
|
||||
PROJECTION_DIM = 64
|
||||
PATCH_SIZE = 1
|
||||
NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28
|
||||
|
||||
class Patches(layers.Layer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.patch_size = PATCH_SIZE
|
||||
|
||||
def call(self, inputs, *args, **kwargs):
|
||||
batch_size = tf.shape(inputs)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=inputs,
|
||||
sizes=[1, self.patch_size, self.patch_size, 1],
|
||||
strides=[1, self.patch_size, self.patch_size, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
)
|
||||
patch_dims = patches.shape[-1]
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
|
||||
def get_config(self):
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'patch_size': self.patch_size,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
class PatchEncoder(layers.Layer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_patches = NUM_PATCHES
|
||||
self.projection = layers.Dense(units=PROJECTION_DIM)
|
||||
self.position_embedding = layers.Embedding(input_dim=NUM_PATCHES, output_dim=PROJECTION_DIM)
|
||||
|
||||
def call(self, inputs, *args, **kwargs):
|
||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||
# XXX: pyright thinks self.projection(inputs) is None
|
||||
encoded = self.projection(inputs) + self.position_embedding(positions)
|
||||
return encoded
|
||||
|
||||
def get_config(self):
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'num_patches': self.num_patches,
|
||||
'projection': self.projection,
|
||||
'position_embedding': self.position_embedding,
|
||||
})
|
||||
return config
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue