Mirror of https://github.com/qurator-spk/eynollah.git (synced 2025-10-27 07:44:12 +01:00)

Merge 557fb227f3 into 38c028c6b5
Commit 0aebf3a24d
4 changed files with 97 additions and 68 deletions
@@ -58,8 +58,6 @@ source = ["eynollah"]
 
 [tool.ruff]
 line-length = 120
-# TODO: Reenable and fix after release v0.6.0
-exclude = ['src/eynollah/training']
 
 [tool.ruff.lint]
 ignore = [
@@ -252,6 +252,7 @@ def get_textline_contours_for_visualization(xml_file):
 
 
+    x_len, y_len = 0, 0
     for jj in root1.iter(link+'Page'):
         y_len=int(jj.attrib['imageHeight'])
         x_len=int(jj.attrib['imageWidth'])
 
@@ -293,6 +294,7 @@ def get_textline_contours_and_ocr_text(xml_file):
 
 
+    x_len, y_len = 0, 0
     for jj in root1.iter(link+'Page'):
         y_len=int(jj.attrib['imageHeight'])
         x_len=int(jj.attrib['imageWidth'])
 
@@ -362,7 +364,7 @@ def get_layout_contours_for_visualization(xml_file):
     link=alltags[0].split('}')[0]+'}'
 
-
+    x_len, y_len = 0, 0
     for jj in root1.iter(link+'Page'):
         y_len=int(jj.attrib['imageHeight'])
         x_len=int(jj.attrib['imageWidth'])
 
@@ -637,7 +639,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
     link=alltags[0].split('}')[0]+'}'
 
-
+    x_len, y_len = 0, 0
     for jj in root1.iter(link+'Page'):
         y_len=int(jj.attrib['imageHeight'])
         x_len=int(jj.attrib['imageWidth'])
 
@@ -645,15 +647,12 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
     if 'columns_width' in list(config_params.keys()):
         columns_width_dict = config_params['columns_width']
         metadata_element = root1.find(link+'Metadata')
-        comment_is_sub_element = False
+        num_col = None
         for child in metadata_element:
             tag2 = child.tag
             if tag2.endswith('}Comments') or tag2.endswith('}comments'):
                 text_comments = child.text
                 num_col = int(text_comments.split('num_col')[1])
-                comment_is_sub_element = True
-        if not comment_is_sub_element:
-            num_col = None
 
         if num_col:
             x_new = columns_width_dict[str(num_col)]
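Note on the hunk above: seeding num_col = None before the loop makes the old comment_is_sub_element flag redundant, since num_col only receives a value when a Comments element is actually found. A minimal sketch of the equivalence, with hypothetical stand-in data rather than real Metadata children:

    # Seed the result with None instead of tracking a separate "found" flag.
    comments = ['num_col2']          # stand-in for the Metadata child texts
    num_col = None                   # replaces comment_is_sub_element = False
    for text in comments:
        if 'num_col' in text:
            num_col = int(text.split('num_col')[1])
    if num_col:                      # still falsy when nothing was found
        print(num_col)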
@@ -1739,15 +1738,15 @@ tot_region_ref,x_len, y_len,index_tot_regions, img_poly
 
 
 
-def bounding_box(cnt,color, corr_order_index ):
-    x, y, w, h = cv2.boundingRect(cnt)
-    x = int(x*scale_w)
-    y = int(y*scale_h)
-
-    w = int(w*scale_w)
-    h = int(h*scale_h)
-
-    return [x,y,w,h,int(color), int(corr_order_index)+1]
+# def bounding_box(cnt,color, corr_order_index ):
+#     x, y, w, h = cv2.boundingRect(cnt)
+#     x = int(x*scale_w)
+#     y = int(y*scale_h)
+#
+#     w = int(w*scale_w)
+#     h = int(h*scale_h)
+#
+#     return [x,y,w,h,int(color), int(corr_order_index)+1]
 
 def resize_image(seg_in,input_height,input_width):
     return cv2.resize(seg_in,(input_width,input_height),interpolation=cv2.INTER_NEAREST)

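The bounding_box helper commented out above reads scale_w and scale_h that are not parameters of the function, so it apparently relied on names from an enclosing scope (the kind of thing an undefined-name lint would flag). A self-contained variant would take the scales as arguments; a sketch, not code from the repo:

    import cv2

    def bounding_box(cnt, color, corr_order_index, scale_w=1.0, scale_h=1.0):
        # Scale a contour's bounding rect and tag it with its class color
        # and a 1-based reading-order index.
        x, y, w, h = cv2.boundingRect(cnt)
        return [int(x * scale_w), int(y * scale_h),
                int(w * scale_w), int(h * scale_h),
                int(color), int(corr_order_index) + 1]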
@@ -1,14 +1,15 @@
 import sys
 import os
+from typing import Tuple
 import warnings
 import json
 
 import numpy as np
 import cv2
-from tensorflow.keras.models import load_model
+from numpy._typing import NDArray
 import tensorflow as tf
-from tensorflow.keras import backend as K
-from tensorflow.keras.layers import *
+from keras.models import Model, load_model
+from keras import backend as K
 import click
 from tensorflow.python.keras import backend as tensorflow_backend
 import xml.etree.ElementTree as ET
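The import changes above bring in Model and NDArray so that the isinstance assertions and the Tuple[NDArray, NDArray] return annotation added in later hunks can type-check. Roughly, the pattern is the following sketch, assuming a saved model at a hypothetical path model.h5:

    from keras.models import Model, load_model

    model = load_model("model.h5", compile=False)  # static type is loose here
    assert isinstance(model, Model)                # narrows the type for checkers
    print(model.layers[-1].output_shape)           # attribute access is now verified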
@@ -34,6 +35,7 @@ Tool to load model and predict for given image.
 """
+
 class sbb_predict:
 
     def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
         self.image=image
         self.dir_in=dir_in
@@ -77,7 +79,7 @@ class sbb_predict:
         #print(img[:,:,0].min())
         #blur = cv2.GaussianBlur(img,(5,5))
         #ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
-        retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
+        _, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
 
 
 
@@ -116,19 +118,19 @@ class sbb_predict:
         denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
         return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
 
-    def weighted_categorical_crossentropy(self,weights=None):
-
-        def loss(y_true, y_pred):
-            labels_floats = tf.cast(y_true, tf.float32)
-            per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
-
-            if weights is not None:
-                weight_mask = tf.maximum(tf.reduce_max(tf.constant(
-                    np.array(weights, dtype=np.float32)[None, None, None])
-                    * labels_floats, axis=-1), 1.0)
-                per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
-            return tf.reduce_mean(per_pixel_loss)
-        return self.loss
+    # def weighted_categorical_crossentropy(self,weights=None):
+    #
+    #     def loss(y_true, y_pred):
+    #         labels_floats = tf.cast(y_true, tf.float32)
+    #         per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
+    #
+    #         if weights is not None:
+    #             weight_mask = tf.maximum(tf.reduce_max(tf.constant(
+    #                 np.array(weights, dtype=np.float32)[None, None, None])
+    #                 * labels_floats, axis=-1), 1.0)
+    #             per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
+    #         return tf.reduce_mean(per_pixel_loss)
+    #     return self.loss
 
 
     def IoU(self,Yi,y_predi):
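Worth noting about the method commented out above: it ended with return self.loss rather than returning the inner loss closure, so it could never have worked as written. A working version would end with return loss; a sketch of the presumably intended behavior:

    import tensorflow as tf
    import numpy as np

    def weighted_categorical_crossentropy(weights=None):
        # Per-pixel sigmoid cross-entropy, optionally re-weighted per class.
        def loss(y_true, y_pred):
            labels = tf.cast(y_true, tf.float32)
            per_pixel = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=y_pred)
            if weights is not None:
                w = tf.constant(np.array(weights, dtype=np.float32)[None, None, None])
                mask = tf.maximum(tf.reduce_max(w * labels, axis=-1), 1.0)
                per_pixel = per_pixel * mask[:, :, :, None]
            return tf.reduce_mean(per_pixel)
        return loss  # return the closure, not self.loss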
@@ -177,12 +179,13 @@ class sbb_predict:
         ##if self.weights_dir!=None:
             ##self.model.load_weights(self.weights_dir)
 
+        assert isinstance(self.model, Model)
         if self.task != 'classification' and self.task != 'reading_order':
             self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
             self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
             self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
 
-    def visualize_model_output(self, prediction, img, task):
+    def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
         if task == "binarization":
             prediction = prediction * -1
             prediction = prediction + 1
@@ -226,9 +229,12 @@ class sbb_predict:
 
         added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)
 
+        assert isinstance(added_image, np.ndarray)
+        assert isinstance(layout_only, np.ndarray)
         return added_image, layout_only
 
     def predict(self, image_dir):
+        assert isinstance(self.model, Model)
         if self.task == 'classification':
             classes_names = self.config_params_model['classification_classes_name']
             img_1ch = img=cv2.imread(image_dir, 0)
@@ -240,7 +246,7 @@ class sbb_predict:
             img_in[0, :, :, 1] = img_1ch[:, :]
             img_in[0, :, :, 2] = img_1ch[:, :]
 
-            label_p_pred = self.model.predict(img_in, verbose=0)
+            label_p_pred = self.model.predict(img_in, verbose='0')
             index_class = np.argmax(label_p_pred[0])
 
             print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
@@ -361,7 +367,7 @@ class sbb_predict:
                 #input_1[:,:,1] = img3[:,:,0]/5.
 
                 if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
-                    y_pr = self.model.predict(input_1 , verbose=0)
+                    y_pr = self.model.predict(input_1 , verbose='0')
                     scalibility_num = scalibility_num+1
 
                     if batch_counter==inference_bs:
@@ -395,6 +401,7 @@ class sbb_predict:
             name_space = name_space.split('{')[1]
 
             page_element = root_xml.find(link+'Page')
+            assert isinstance(page_element, ET.Element)
 
             """
             ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
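Context for the assert above: ElementTree's find() returns Element or None, so the assertion both guards against a missing Page element and narrows the optional type for static analysis. A small self-contained illustration with hypothetical XML:

    import xml.etree.ElementTree as ET

    root = ET.fromstring('<PcGts><Page imageWidth="100" imageHeight="200"/></PcGts>')
    page = root.find('Page')
    assert isinstance(page, ET.Element)  # find() may return None
    print(page.attrib['imageWidth'])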
@@ -489,7 +496,7 @@ class sbb_predict:
 
                     img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
                     label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
-                                                      verbose=0)
+                                                      verbose='0')
 
                     if self.task == 'enhancement':
                         seg = label_p_pred[0, :, :, :]
@@ -497,6 +504,8 @@ class sbb_predict:
                     elif self.task == 'segmentation' or self.task == 'binarization':
                         seg = np.argmax(label_p_pred, axis=3)[0]
                         seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
+                    else:
+                        raise ValueError(f"Unhandled task {self.task}")
 
 
                     if i == 0 and j == 0:
@@ -551,6 +560,8 @@ class sbb_predict:
             elif self.task == 'segmentation' or self.task == 'binarization':
                 seg = np.argmax(label_p_pred, axis=3)[0]
                 seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
+            else:
+                raise ValueError(f"Unhandled task {self.task}")
 
             prediction_true = seg.astype(int)
 

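The two else/raise branches added above close a real gap: before this change, seg stayed unbound whenever self.task was neither 'enhancement' nor 'segmentation'/'binarization', and the code after the branch would fail with a confusing NameError. A minimal sketch of the pattern, with a hypothetical task value:

    task = 'enhancement'
    if task == 'enhancement':
        seg = 1
    elif task in ('segmentation', 'binarization'):
        seg = 2
    else:
        # without this branch, `seg` below would be unbound for unexpected tasks
        raise ValueError(f"Unhandled task {task}")
    print(seg)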
@@ -1,9 +1,29 @@
-import tensorflow as tf
 from tensorflow import keras
-from tensorflow.keras.models import *
-from tensorflow.keras.layers import *
-from tensorflow.keras import layers
-from tensorflow.keras.regularizers import l2
+from keras.layers import (
+    Activation,
+    Add,
+    AveragePooling2D,
+    BatchNormalization,
+    Conv2D,
+    Dense,
+    Dropout,
+    Embedding,
+    Flatten,
+    Input,
+    Lambda,
+    Layer,
+    LayerNormalization,
+    MaxPooling2D,
+    MultiHeadAttention,
+    UpSampling2D,
+    ZeroPadding2D,
+    add,
+    concatenate
+)
+from keras.models import Model
+import tensorflow as tf
+# from keras import layers, models
+from keras.regularizers import l2
 
 ##mlp_head_units = [512, 256]#[2048, 1024]
 ###projection_dim = 64
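Replacing the two star imports with an explicit import list is what drives most of the mechanical renames in the hunks below (layers.Dense becomes Dense, layers.add becomes add, and so on), and it lets a linter verify exactly which names the module uses. Illustrative comparison only:

    # before: names appear out of nowhere and can shadow each other
    # from tensorflow.keras.layers import *
    # after: every layer used by the module is named once, at the top
    from keras.layers import Conv2D, Dense

    layer = Dense(64)  # resolves unambiguously to keras.layers.Dense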
@@ -15,13 +35,13 @@ MERGE_AXIS = -1
 
 def mlp(x, hidden_units, dropout_rate):
     for units in hidden_units:
-        x = layers.Dense(units, activation=tf.nn.gelu)(x)
-        x = layers.Dropout(dropout_rate)(x)
+        x = Dense(units, activation=tf.nn.gelu)(x)
+        x = Dropout(dropout_rate)(x)
     return x
 
-class Patches(layers.Layer):
+class Patches(Layer):
     def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
-        super(Patches, self).__init__()
+        super().__init__()
         self.patch_size_x = patch_size_x
         self.patch_size_y = patch_size_y
 
@@ -49,9 +69,9 @@ class Patches(layers.Layer):
         })
         return config
 
-class Patches_old(layers.Layer):
+class Patches_old(Layer):
     def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
-        super(Patches, self).__init__()
+        super().__init__()
         self.patch_size = patch_size
 
     def call(self, images):
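The Patches_old change above is more than style: the old initializer called super(Patches, self).__init__(), naming the sibling class after a copy-paste. Since Patches_old is not a subclass of Patches, instantiating it would raise TypeError. Zero-argument super() sidesteps that whole class of bug:

    class Base:
        def __init__(self):
            self.ready = True

    class A(Base):
        def __init__(self):
            super(A, self).__init__()   # works, but breaks if the body is copied/renamed

    class B(Base):
        def __init__(self):
            super().__init__()          # zero-argument form always binds correctly

    B()  # fine; copying A's body into B with super(A, self) would raise TypeError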
@@ -69,8 +89,8 @@ class Patches_old(layers.Layer):
         #print(patches.shape,patch_dims,'patch_dims')
         patches = tf.reshape(patches, [batch_size, -1, patch_dims])
         return patches
-    def get_config(self):
 
+    def get_config(self):
         config = super().get_config().copy()
         config.update({
             'patch_size': self.patch_size,
@@ -78,12 +98,12 @@ class Patches_old(layers.Layer):
         return config
 
 
-class PatchEncoder(layers.Layer):
+class PatchEncoder(Layer):
     def __init__(self, num_patches, projection_dim):
         super(PatchEncoder, self).__init__()
         self.num_patches = num_patches
-        self.projection = layers.Dense(units=projection_dim)
-        self.position_embedding = layers.Embedding(
+        self.projection = Dense(units=projection_dim)
+        self.position_embedding = Embedding(
             input_dim=num_patches, output_dim=projection_dim
         )
 
@@ -144,7 +164,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
     x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
     x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
 
-    x = layers.add([x, input_tensor])
+    x = add([x, input_tensor])
     x = Activation('relu')(x)
     return x
 
@@ -189,12 +209,12 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
                       name=conv_name_base + '1')(input_tensor)
     shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
 
-    x = layers.add([x, shortcut])
+    x = add([x, shortcut])
     x = Activation('relu')(x)
     return x
 
 
-def resnet50_unet_light(n_classes, input_height=224, input_width=224, taks="segmentation", weight_decay=1e-6, pretraining=False):
+def resnet50_unet_light(n_classes, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     assert input_height % 32 == 0
     assert input_width % 32 == 0
 
@@ -397,7 +417,7 @@ def resnet50_unet(n_classes, input_height=224, input_width=224, task="segmentati
 def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     if mlp_head_units is None:
         mlp_head_units = [128, 64]
-    inputs = layers.Input(shape=(input_height, input_width, 3))
+    inputs = Input(shape=(input_height, input_width, 3))
 
     #transformer_units = [
     #projection_dim * 2,
@@ -452,20 +472,21 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
 
     for _ in range(transformer_layers):
         # Layer normalization 1.
-        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
         # Create a multi-head attention layer.
-        attention_output = layers.MultiHeadAttention(
+        attention_output = MultiHeadAttention(
             num_heads=num_heads, key_dim=projection_dim, dropout=0.1
         )(x1, x1)
         # Skip connection 1.
-        x2 = layers.Add()([attention_output, encoded_patches])
+        x2 = Add()([attention_output, encoded_patches])
         # Layer normalization 2.
-        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+        x3 = LayerNormalization(epsilon=1e-6)(x2)
         # MLP.
         x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
         # Skip connection 2.
-        encoded_patches = layers.Add()([x3, x2])
+        encoded_patches = Add()([x3, x2])
 
+    assert isinstance(x, Layer)
     encoded_patches = tf.reshape(encoded_patches, [-1, x.shape[1], x.shape[2] , int( projection_dim / (patch_size_x * patch_size_y) )])
 
     v1024_2048 = Conv2D( 1024 , (1, 1), padding='same', data_format=IMAGE_ORDERING,kernel_regularizer=l2(weight_decay))(encoded_patches)
@@ -521,7 +542,7 @@ def vit_resnet50_unet(n_classes, patch_size_x, patch_size_y, num_patches, mlp_he
 def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size_y, num_patches, mlp_head_units=None, transformer_layers=8, num_heads =4, projection_dim = 64, input_height=224, input_width=224, task="segmentation", weight_decay=1e-6, pretraining=False):
     if mlp_head_units is None:
         mlp_head_units = [128, 64]
-    inputs = layers.Input(shape=(input_height, input_width, 3))
+    inputs = Input(shape=(input_height, input_width, 3))
 
     ##transformer_units = [
     ##projection_dim * 2,
@@ -536,19 +557,19 @@ def vit_resnet50_unet_transformer_before_cnn(n_classes, patch_size_x, patch_size
 
     for _ in range(transformer_layers):
         # Layer normalization 1.
-        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
         # Create a multi-head attention layer.
-        attention_output = layers.MultiHeadAttention(
+        attention_output = MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
         )(x1, x1)
         # Skip connection 1.
-        x2 = layers.Add()([attention_output, encoded_patches])
+        x2 = Add()([attention_output, encoded_patches])
         # Layer normalization 2.
-        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+        x3 = LayerNormalization(epsilon=1e-6)(x2)
         # MLP.
         x3 = mlp(x3, hidden_units=mlp_head_units, dropout_rate=0.1)
         # Skip connection 2.
-        encoded_patches = layers.Add()([x3, x2])
+        encoded_patches = Add()([x3, x2])
 
     encoded_patches = tf.reshape(encoded_patches, [-1, input_height, input_width , int( projection_dim / (patch_size_x * patch_size_y) )])
 