mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-06-09 04:09:54 +02:00
move keras-specific classes to utils.keras, clean up imports
This commit is contained in:
parent
d7a774ebd2
commit
2c9727f9c9
2 changed files with 69 additions and 64 deletions
|
@ -15,27 +15,27 @@ from typing import Optional
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from multiprocessing import Process, Queue, cpu_count
|
from multiprocessing import Process, Queue, cpu_count
|
||||||
import gc
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from ocrd import OcrdPage
|
from ocrd import OcrdPage
|
||||||
from ocrd_utils import getLogger
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy.signal import find_peaks
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from scipy.ndimage import gaussian_filter1d
|
||||||
|
|
||||||
|
from qurator.eynollah.utils.keras import PatchEncoder, Patches
|
||||||
|
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
stderr = sys.stderr
|
stderr = sys.stderr
|
||||||
sys.stderr = open(os.devnull, "w")
|
sys.stderr = open(os.devnull, "w")
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.python.keras import backend as K
|
|
||||||
from tensorflow.keras.models import load_model
|
|
||||||
sys.stderr = stderr
|
sys.stderr = stderr
|
||||||
tf.get_logger().setLevel("ERROR")
|
tf.get_logger().setLevel("ERROR")
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
from scipy.signal import find_peaks
|
|
||||||
import matplotlib.pyplot as plt
|
load_model = tf.keras.models.load_model
|
||||||
from scipy.ndimage import gaussian_filter1d
|
|
||||||
# use tf1 compatibility for keras backend
|
# use tf1 compatibility for keras backend
|
||||||
from tensorflow.compat.v1.keras.backend import set_session
|
set_session = tf.compat.v1.keras.backend.set_session
|
||||||
from tensorflow.keras import layers
|
|
||||||
|
|
||||||
from .utils.contour import (
|
from .utils.contour import (
|
||||||
filter_contours_area_of_image,
|
filter_contours_area_of_image,
|
||||||
|
@ -89,61 +89,8 @@ DPI_THRESHOLD = 298
|
||||||
MAX_SLOPE = 999
|
MAX_SLOPE = 999
|
||||||
KERNEL = np.ones((5, 5), np.uint8)
|
KERNEL = np.ones((5, 5), np.uint8)
|
||||||
|
|
||||||
projection_dim = 64
|
class Eynollah():
|
||||||
patch_size = 1
|
|
||||||
num_patches =21*21#14*14#28*28#14*14#28*28
|
|
||||||
|
|
||||||
|
|
||||||
class Patches(layers.Layer):
    """Keras layer that cuts a batch of images into flattened square patches.

    The patch edge length defaults to the module-level ``patch_size`` and can
    be overridden with a ``patch_size`` keyword argument, which is also the
    key written by :meth:`get_config`.  All remaining keyword arguments
    (e.g. ``name``, ``dtype``, ``trainable``) are forwarded to the base
    ``Layer`` constructor instead of being silently discarded, so a layer
    rebuilt from its own config keeps its identity.
    """

    def __init__(self, **kwargs):
        # Consume our own config key first: the base Layer constructor is
        # expected to reject keyword arguments it does not know about.
        own_patch_size = kwargs.pop('patch_size', patch_size)
        super(Patches, self).__init__(**kwargs)
        self.patch_size = own_patch_size

    def call(self, images):
        """Return the patches of ``images`` as one flat vector per patch.

        Patches are taken with a stride equal to their size and "VALID"
        padding, then reshaped to ``[batch, -1, patch_dims]``.
        Assumes ``images`` is a 4-D tensor (batch, height, width, channels),
        as the 4-element ``sizes``/``strides`` lists suggest -- TODO confirm.
        """
        batch_size = tf.shape(images)[0]
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Static size of the last axis: the number of values in one patch.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the base layer config extended with ``patch_size``."""
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEncoder(layers.Layer):
    """Keras layer that projects patches and adds a position embedding.

    ``num_patches`` and ``projection_dim`` default to the module-level
    values of the same name and may be overridden by keyword argument.
    Remaining keyword arguments (``name``, ``dtype``, ...) are forwarded to
    the base ``Layer`` constructor.
    """

    def __init__(self, **kwargs):
        # Earlier versions of get_config() stored the sub-layers themselves
        # under these keys; drop them so such configs can still be loaded.
        kwargs.pop('projection', None)
        kwargs.pop('position_embedding', None)
        own_num_patches = kwargs.pop('num_patches', num_patches)
        own_projection_dim = kwargs.pop('projection_dim', projection_dim)
        super(PatchEncoder, self).__init__(**kwargs)
        self.num_patches = own_num_patches
        self.projection_dim = own_projection_dim
        self.projection = layers.Dense(units=own_projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=own_num_patches, output_dim=own_projection_dim
        )

    def call(self, patch):
        """Return ``projection(patch)`` plus the embedding of each position.

        Positions are the integers ``0 .. num_patches - 1``.
        Presumably ``patch`` is (batch, num_patches, patch_dims) so that the
        sum broadcasts over the batch axis -- verify against the caller.
        """
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the base layer config plus this layer's hyper-parameters.

        Only plain values are stored; the ``Dense`` and ``Embedding``
        sub-layers are rebuilt from them in ``__init__`` rather than being
        placed in the config as live objects.
        """
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection_dim': self.projection_dim,
        })
        return config
|
|
||||||
|
|
||||||
class Eynollah:
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dir_models : str,
|
dir_models : str,
|
||||||
|
@ -673,7 +620,7 @@ class Eynollah:
|
||||||
model = load_model(model_dir, compile=False)
|
model = load_model(model_dir, compile=False)
|
||||||
self.models[model_dir] = model
|
self.models[model_dir] = model
|
||||||
except:
|
except:
|
||||||
model = load_model(model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||||
self.models[model_dir] = model
|
self.models[model_dir] = model
|
||||||
|
|
||||||
|
|
||||||
|
|
58
qurator/eynollah/utils/keras.py
Normal file
58
qurator/eynollah/utils/keras.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
import tensorflow as tf
|
||||||
|
layers = tf.keras.layers
|
||||||
|
|
||||||
|
PROJECTION_DIM = 64
|
||||||
|
PATCH_SIZE = 1
|
||||||
|
NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28
|
||||||
|
|
||||||
|
class Patches(layers.Layer):
    """Keras layer that cuts a batch of images into flattened square patches.

    Args:
        *args: Positional arguments passed through to ``Layer.__init__``.
        patch_size: Edge length of each square patch; defaults to the
            module-level ``PATCH_SIZE``.  This is the key written by
            :meth:`get_config`, so accepting it here lets the layer be
            rebuilt from its own config.
        **kwargs: Remaining keyword arguments (``name``, ``dtype``, ...)
            passed through to ``Layer.__init__``.
    """

    def __init__(self, *args, patch_size=PATCH_SIZE, **kwargs):
        # ``patch_size`` is taken out of the keyword arguments before they
        # reach the base class, which is expected to reject unknown keys.
        super().__init__(*args, **kwargs)
        self.patch_size = patch_size

    def call(self, inputs, *args, **kwargs):
        """Return the patches of ``inputs`` as one flat vector per patch.

        Patches are taken with a stride equal to their size and "VALID"
        padding, then reshaped to ``[batch, -1, patch_dims]``.
        Assumes ``inputs`` is a 4-D tensor (batch, height, width, channels),
        as the 4-element ``sizes``/``strides`` lists suggest -- TODO confirm.
        """
        batch_size = tf.shape(inputs)[0]
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=inputs,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Static size of the last axis: the number of values in one patch.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the base layer config extended with ``patch_size``."""
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size,
        })
        return config
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEncoder(layers.Layer):
    """Keras layer that projects patches and adds a position embedding.

    Args:
        *args: Positional arguments passed through to ``Layer.__init__``.
        num_patches: Number of positions in the embedding table; defaults
            to the module-level ``NUM_PATCHES``.
        projection_dim: Output width of the dense projection and of the
            position embedding; defaults to ``PROJECTION_DIM``.
        **kwargs: Remaining keyword arguments (``name``, ``dtype``, ...)
            passed through to ``Layer.__init__``.
    """

    def __init__(self, *args, num_patches=NUM_PATCHES,
                 projection_dim=PROJECTION_DIM, **kwargs):
        # Earlier versions of get_config() stored the sub-layers themselves
        # under these keys; drop them so such configs can still be loaded
        # instead of being forwarded to the base class as unknown keywords.
        kwargs.pop('projection', None)
        kwargs.pop('position_embedding', None)
        super().__init__(*args, **kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, inputs, *args, **kwargs):
        """Return ``projection(inputs)`` plus the embedding of each position.

        Positions are the integers ``0 .. num_patches - 1``.
        Presumably ``inputs`` is (batch, num_patches, patch_dims) so that
        the sum broadcasts over the batch axis -- verify against the caller.
        """
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # XXX: pyright thinks self.projection(inputs) is None
        encoded = self.projection(inputs) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the base layer config plus this layer's hyper-parameters.

        Only plain values are stored; the ``Dense`` and ``Embedding``
        sub-layers are rebuilt from them in ``__init__`` rather than being
        placed in the config as live objects.
        """
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection_dim': self.projection_dim,
        })
        return config
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue