Mirror of https://github.com/qurator-spk/eynollah.git (synced 2025-10-26 23:34:13 +01:00)
training/models.py: make imports explicit
This commit is contained in:
parent 38c028c6b5
commit 3a73ccca2e

2 changed files with 51 additions and 32 deletions
pyproject.toml
@@ -58,8 +58,6 @@ source = ["eynollah"]
 
 [tool.ruff]
 line-length = 120
-# TODO: Reenable and fix after release v0.6.0
-exclude = ['src/eynollah/training']
 
 [tool.ruff.lint]
 ignore = [
src/eynollah/training/models.py
@@ -1,9 +1,29 @@
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras.models import *
-from tensorflow.keras.layers import *
-from tensorflow.keras import layers
-from tensorflow.keras.regularizers import l2
+from keras.layers import (
+    Activation,
+    Add,
+    AveragePooling2D,
+    BatchNormalization,
+    Conv2D,
+    Dense,
+    Dropout,
+    Embedding,
+    Flatten,
+    Input,
+    Lambda,
+    Layer,
+    LayerNormalization,
+    MaxPooling2D,
+    MultiHeadAttention,
+    UpSampling2D,
+    ZeroPadding2D,
+    add,
+    concatenate
+)
+from keras.models import Model
+import tensorflow as tf
+# from keras import layers, models
+from keras.regularizers import l2
 
 ##mlp_head_units = [512, 256]#[2048, 1024]
 ###projection_dim = 64
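If the training exclude really is dropped from [tool.ruff] above, ruff starts flagging the old star imports in this module (pyflakes rules F403/F405), and they also hide where a name such as Dense actually comes from. For illustration only, not part of the commit, a minimal sketch of the explicit style (the helper name tiny_head is made up):

# Illustration: with explicit imports every symbol is traceable by the linter,
# unlike `from tensorflow.keras.layers import *` combined with `from tensorflow.keras import layers`.
import tensorflow as tf
from keras.layers import Dense, Dropout

def tiny_head(x, units=64, rate=0.1):   # hypothetical helper, mirrors the mlp() pattern below
    x = Dense(units, activation=tf.nn.gelu)(x)
    return Dropout(rate)(x)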
@@ -15,13 +35,13 @@ MERGE_AXIS = -1
 
 def mlp(x, hidden_units, dropout_rate):
     for units in hidden_units:
-        x = layers.Dense(units, activation=tf.nn.gelu)(x)
-        x = layers.Dropout(dropout_rate)(x)
+        x = Dense(units, activation=tf.nn.gelu)(x)
+        x = Dropout(dropout_rate)(x)
     return x
 
-class Patches(layers.Layer):
+class Patches(Layer):
     def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
-        super(Patches, self).__init__()
+        super().__init__()
         self.patch_size_x = patch_size_x
         self.patch_size_y = patch_size_y
 
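For reference, a sketch of how the mlp() helper above is typically applied to a token tensor. The import path eynollah.training.models and the shapes are assumptions, not taken from this commit:

import tensorflow as tf
from keras.layers import Input
from keras.models import Model
from eynollah.training.models import mlp   # import path assumed from the src/ layout

tokens = Input(shape=(196, 64))             # hypothetical (num_patches, projection_dim)
out = mlp(tokens, hidden_units=[128, 64], dropout_rate=0.1)
print(Model(tokens, out).output_shape)      # (None, 196, 64)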
@@ -49,9 +69,9 @@ class Patches(layers.Layer):
         })
         return config
 
-class Patches_old(layers.Layer):
+class Patches_old(Layer):
     def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
-        super(Patches, self).__init__()
+        super().__init__()
         self.patch_size = patch_size
 
     def call(self, images):
@@ -69,8 +89,8 @@ class Patches_old(layers.Layer):
         #print(patches.shape,patch_dims,'patch_dims')
         patches = tf.reshape(patches, [batch_size, -1, patch_dims])
         return patches
-    def get_config(self):
 
+    def get_config(self):
         config = super().get_config().copy()
         config.update({
             'patch_size': self.patch_size,
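The get_config() shown above is what lets Keras serialize the custom layer. A small sketch of reading the stored value back; the import path is an assumption:

from eynollah.training.models import Patches_old   # import path assumed

layer = Patches_old(patch_size=14)
print(layer.get_config()["patch_size"])             # 14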
@@ -78,12 +98,12 @@ class Patches_old(layers.Layer):
         return config
 
 
-class PatchEncoder(layers.Layer):
+class PatchEncoder(Layer):
     def __init__(self, num_patches, projection_dim):
         super(PatchEncoder, self).__init__()
         self.num_patches = num_patches
-        self.projection = layers.Dense(units=projection_dim)
-        self.position_embedding = layers.Embedding(
+        self.projection = Dense(units=projection_dim)
+        self.position_embedding = Embedding(
             input_dim=num_patches, output_dim=projection_dim
         )
 
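PatchEncoder's __init__ above pairs a Dense projection with a learned position Embedding. A self-contained sketch of that combination with made-up sizes; the call() logic is assumed to follow the usual ViT recipe of projecting each patch and adding its position embedding:

import tensorflow as tf
from keras.layers import Dense, Embedding

num_patches, projection_dim = 196, 64
patches = tf.random.uniform((2, num_patches, 768))        # hypothetical flattened patches
projection = Dense(units=projection_dim)
position_embedding = Embedding(input_dim=num_patches, output_dim=projection_dim)

positions = tf.range(start=0, limit=num_patches, delta=1)
encoded = projection(patches) + position_embedding(positions)   # broadcasts over the batch
print(encoded.shape)                                             # (2, 196, 64)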
@@ -144,7 +164,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
     x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
     x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
 
-    x = layers.add([x, input_tensor])
+    x = add([x, input_tensor])
     x = Activation('relu')(x)
     return x
 
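The change above swaps layers.add for the directly imported functional add; both construct the same Add layer under the hood. A self-contained sketch of the merge-then-ReLU pattern used by identity_block, with made-up shapes:

from keras.layers import Activation, BatchNormalization, Conv2D, Input, add
from keras.models import Model

inp = Input(shape=(56, 56, 256))                 # hypothetical feature map
y = Conv2D(256, (1, 1), padding='same')(inp)
y = BatchNormalization()(y)
out = Activation('relu')(add([y, inp]))          # residual merge, as in identity_block
print(Model(inp, out).output_shape)              # (None, 56, 56, 256)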
@@ -189,12 +209,12 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
                          name=conv_name_base + '1')(input_tensor)
     shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
 
-    x = layers.add([x, shortcut])
+    x = add([x, shortcut])
     x = Activation('relu')(x)
     return x
 
 
-def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segmentation", weight_decay=1e-6, pretraining=False):
+def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     assert input_height % 32 == 0
     assert input_width % 32 == 0
 
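A sketch of calling the builder after the signature fix (taks becomes task); the import path is an assumption, and pretraining=False avoids the need for pretrained weights. Input sides must be multiples of 32, per the asserts above:

from eynollah.training.models import resnet50_unet_light   # import path assumed

model = resnet50_unet_light(n_classes=3, input_height=448, input_width=448,
                            task="segmentation", weight_decay=1e-6, pretraining=False)
model.summary()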
@@ -397,7 +417,7 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
 def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     if mlp_head_units is None:
         mlp_head_units = [128, 64]
-    inputs = layers.Input(shape=(input_height, input_width, 3))
+    inputs = Input(shape=(input_height, input_width, 3))
 
     #transformer_units = [
     #projection_dim * 2,
@@ -452,20 +472,21 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
 
     for _ in range(transformer_layers):
         # Layer normalization 1.
-        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
         # Create a multi-head attention layer.
-        attention_output = layers.MultiHeadAttention(
+        attention_output = MultiHeadAttention(
             num_heads=num_heads, key_dim=projection_dim, dropout=0.1
         )(x1, x1)
         # Skip connection 1.
-        x2 = layers.Add()([attention_output, encoded_patches])
+        x2 = Add()([attention_output, encoded_patches])
         # Layer normalization 2.
-        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+        x3 = LayerNormalization(epsilon=1e-6)(x2)
         # MLP.
         x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
         # Skip connection 2.
-        encoded_patches = layers.Add()([x3, x2])
+        encoded_patches = Add()([x3, x2])
 
+    assert isinstance(x, Layer)
     encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
 
     v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
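The loop above is a standard pre-norm transformer encoder block: LayerNorm, multi-head self-attention, MLP, and two residual additions. A self-contained sketch of one such block with made-up sizes, using a single Dense/Dropout pair in place of mlp():

import tensorflow as tf
from keras.layers import Add, Dense, Dropout, Input, LayerNormalization, MultiHeadAttention
from keras.models import Model

tokens = Input(shape=(196, 64))                                   # hypothetical (patches, dim)
x1 = LayerNormalization(epsilon=1e-6)(tokens)
att = MultiHeadAttention(num_heads=4, key_dim=64, dropout=0.1)(x1, x1)
x2 = Add()([att, tokens])                                         # skip connection 1
x3 = LayerNormalization(epsilon=1e-6)(x2)
x3 = Dropout(0.1)(Dense(64, activation=tf.nn.gelu)(x3))           # stands in for mlp()
out = Add()([x3, x2])                                             # skip connection 2
print(Model(tokens, out).output_shape)                            # (None, 196, 64)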
@@ -521,7 +542,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
 def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     if mlp_head_units is None:
         mlp_head_units = [128, 64]
-    inputs = layers.Input(shape=(input_height, input_width, 3))
+    inputs = Input(shape=(input_height, input_width, 3))
 
     ##transformer_units = [
     ##projection_dim * 2,
@@ -536,19 +557,19 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size
 
     for _ in range(transformer_layers):
         # Layer normalization 1.
-        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
         # Create a multi-head attention layer.
-        attention_output = layers.MultiHeadAttention(
+        attention_output = MultiHeadAttention(
             num_heads=num_heads, key_dim=projection_dim, dropout=0.1
         )(x1, x1)
         # Skip connection 1.
-        x2 = layers.Add()([attention_output, encoded_patches])
+        x2 = Add()([attention_output, encoded_patches])
         # Layer normalization 2.
-        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+        x3 = LayerNormalization(epsilon=1e-6)(x2)
         # MLP.
         x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
         # Skip connection 2.
-        encoded_patches = layers.Add()([x3, x2])
+        encoded_patches = Add()([x3, x2])
 
     encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
 
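In both variants the final tf.reshape spreads the projection dimension over the patch grid, so projection_dim must be divisible by patch_size_x * patch_size_y. A small numeric sketch with made-up sizes, not taken from the commit:

import tensorflow as tf

projection_dim, patch_size_x, patch_size_y = 64, 2, 2
height, width = 28, 28                                               # hypothetical image size
num_patches = (height // patch_size_y) * (width // patch_size_x)     # 196
encoded = tf.random.uniform((1, num_patches, projection_dim))
channels = projection_dim // (patch_size_x * patch_size_y)           # 16
spatial = tf.reshape(encoded, [-1, height, width, channels])
print(spatial.shape)                                                  # (1, 28, 28, 16)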