mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-12-23 03:24:12 +01:00
wip
This commit is contained in:
parent
4651000191
commit
9ccc495b4a
4 changed files with 505 additions and 755 deletions
|
|
@ -147,6 +147,7 @@ def generalized_dice_loss(y_true, y_pred):
|
||||||
return 1 - generalized_dice_coeff2(y_true, y_pred)
|
return 1 - generalized_dice_coeff2(y_true, y_pred)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: document where this is from
|
||||||
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
||||||
"""
|
"""
|
||||||
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
|
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
|
||||||
|
|
@ -175,6 +176,7 @@ def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
||||||
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
|
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: document where this is from
|
||||||
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False,
|
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False,
|
||||||
verbose=False):
|
verbose=False):
|
||||||
"""
|
"""
|
||||||
|
|
@ -267,6 +269,8 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=T
|
||||||
return K.mean(non_zero_sum / non_zero_count)
|
return K.mean(non_zero_sum / non_zero_count)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: document where this is from
|
||||||
|
# TODO: Why a different implementation than IoU from utils?
|
||||||
def mean_iou(y_true, y_pred, **kwargs):
|
def mean_iou(y_true, y_pred, **kwargs):
|
||||||
"""
|
"""
|
||||||
Compute mean Intersection over Union of two segmentation masks, via Keras.
|
Compute mean Intersection over Union of two segmentation masks, via Keras.
|
||||||
|
|
@ -311,6 +315,7 @@ def iou_vahid(y_true, y_pred):
|
||||||
return K.mean(iou)
|
return K.mean(iou)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: copy from utils?
|
||||||
def IoU_metric(Yi, y_predi):
|
def IoU_metric(Yi, y_predi):
|
||||||
# mean Intersection over Union
|
# mean Intersection over Union
|
||||||
# Mean IoU = TP/(FN + TP + FP)
|
# Mean IoU = TP/(FN + TP + FP)
|
||||||
|
|
@ -337,6 +342,7 @@ def IoU_metric_keras(y_true, y_pred):
|
||||||
return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))
|
return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: unused, remove?
|
||||||
def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
||||||
"""
|
"""
|
||||||
Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
|
Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ from tensorflow.keras.layers import *
|
||||||
from tensorflow.keras import layers
|
from tensorflow.keras import layers
|
||||||
from tensorflow.keras.regularizers import l2
|
from tensorflow.keras.regularizers import l2
|
||||||
|
|
||||||
|
from eynollah.patch_encoder import Patches, PatchEncoder
|
||||||
|
|
||||||
##mlp_head_units = [512, 256]#[2048, 1024]
|
##mlp_head_units = [512, 256]#[2048, 1024]
|
||||||
###projection_dim = 64
|
###projection_dim = 64
|
||||||
##transformer_layers = 2#8
|
##transformer_layers = 2#8
|
||||||
|
|
@ -38,87 +40,6 @@ def mlp(x, hidden_units, dropout_rate):
|
||||||
x = layers.Dropout(dropout_rate)(x)
|
x = layers.Dropout(dropout_rate)(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class Patches(layers.Layer):
|
|
||||||
def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
|
|
||||||
super(Patches, self).__init__()
|
|
||||||
self.patch_size_x = patch_size_x
|
|
||||||
self.patch_size_y = patch_size_y
|
|
||||||
|
|
||||||
def call(self, images):
|
|
||||||
#print(tf.shape(images)[1],'images')
|
|
||||||
#print(self.patch_size,'self.patch_size')
|
|
||||||
batch_size = tf.shape(images)[0]
|
|
||||||
patches = tf.image.extract_patches(
|
|
||||||
images=images,
|
|
||||||
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
|
|
||||||
strides=[1, self.patch_size_y, self.patch_size_x, 1],
|
|
||||||
rates=[1, 1, 1, 1],
|
|
||||||
padding="VALID",
|
|
||||||
)
|
|
||||||
#patch_dims = patches.shape[-1]
|
|
||||||
patch_dims = tf.shape(patches)[-1]
|
|
||||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
|
||||||
return patches
|
|
||||||
def get_config(self):
|
|
||||||
|
|
||||||
config = super().get_config().copy()
|
|
||||||
config.update({
|
|
||||||
'patch_size_x': self.patch_size_x,
|
|
||||||
'patch_size_y': self.patch_size_y,
|
|
||||||
})
|
|
||||||
return config
|
|
||||||
|
|
||||||
class Patches_old(layers.Layer):
|
|
||||||
def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
|
|
||||||
super(Patches, self).__init__()
|
|
||||||
self.patch_size = patch_size
|
|
||||||
|
|
||||||
def call(self, images):
|
|
||||||
#print(tf.shape(images)[1],'images')
|
|
||||||
#print(self.patch_size,'self.patch_size')
|
|
||||||
batch_size = tf.shape(images)[0]
|
|
||||||
patches = tf.image.extract_patches(
|
|
||||||
images=images,
|
|
||||||
sizes=[1, self.patch_size, self.patch_size, 1],
|
|
||||||
strides=[1, self.patch_size, self.patch_size, 1],
|
|
||||||
rates=[1, 1, 1, 1],
|
|
||||||
padding="VALID",
|
|
||||||
)
|
|
||||||
patch_dims = patches.shape[-1]
|
|
||||||
#print(patches.shape,patch_dims,'patch_dims')
|
|
||||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
|
||||||
return patches
|
|
||||||
def get_config(self):
|
|
||||||
|
|
||||||
config = super().get_config().copy()
|
|
||||||
config.update({
|
|
||||||
'patch_size': self.patch_size,
|
|
||||||
})
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEncoder(layers.Layer):
|
|
||||||
def __init__(self, num_patches, projection_dim):
|
|
||||||
super(PatchEncoder, self).__init__()
|
|
||||||
self.num_patches = num_patches
|
|
||||||
self.projection = layers.Dense(units=projection_dim)
|
|
||||||
self.position_embedding = layers.Embedding(
|
|
||||||
input_dim=num_patches, output_dim=projection_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
def call(self, patch):
|
|
||||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
|
||||||
encoded = self.projection(patch) + self.position_embedding(positions)
|
|
||||||
return encoded
|
|
||||||
def get_config(self):
|
|
||||||
|
|
||||||
config = super().get_config().copy()
|
|
||||||
config.update({
|
|
||||||
'num_patches': self.num_patches,
|
|
||||||
'projection': self.projection,
|
|
||||||
'position_embedding': self.position_embedding,
|
|
||||||
})
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def one_side_pad(x):
|
def one_side_pad(x):
|
||||||
|
|
|
||||||
|
|
@ -175,22 +175,94 @@ def config_params():
|
||||||
characters_txt_file = None # Directory of characters text file needed for cnn_rnn_ocr model training. The file ends with .txt
|
characters_txt_file = None # Directory of characters text file needed for cnn_rnn_ocr model training. The file ends with .txt
|
||||||
|
|
||||||
@ex.automain
|
@ex.automain
|
||||||
def run(_config, n_classes, n_epochs, input_height,
|
def run(
|
||||||
input_width, weight_decay, weighted_loss,
|
_config,
|
||||||
index_start, dir_of_start_model, is_loss_soft_dice,
|
n_classes,
|
||||||
n_batch, patches, augmentation, flip_aug,
|
n_epochs,
|
||||||
blur_aug, padding_white, padding_black, scaling, shifting, degrading,channels_shuffling,
|
input_height,
|
||||||
brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes,
|
input_width,
|
||||||
brightness, dir_train, data_is_provided, scaling_bluring,
|
weight_decay,
|
||||||
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
|
weighted_loss,
|
||||||
thetha, thetha_padd, scaling_flip, continue_training, transformer_projection_dim,
|
index_start,
|
||||||
transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
|
dir_of_start_model,
|
||||||
transformer_patchsize_x, transformer_patchsize_y,
|
is_loss_soft_dice,
|
||||||
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
|
n_batch,
|
||||||
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds,
|
patches,
|
||||||
dir_rgb_foregrounds, characters_txt_file, color_padding_rotation, bin_deg, image_inversion, white_noise_strap, textline_skewing, textline_skewing_bin,
|
augmentation,
|
||||||
textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin,
|
flip_aug,
|
||||||
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors, pepper_indexes, white_padds, skewing_amplitudes, max_len):
|
blur_aug,
|
||||||
|
padding_white,
|
||||||
|
padding_black,
|
||||||
|
scaling,
|
||||||
|
shifting,
|
||||||
|
degrading,
|
||||||
|
channels_shuffling,
|
||||||
|
brightening,
|
||||||
|
binarization,
|
||||||
|
adding_rgb_background,
|
||||||
|
adding_rgb_foreground,
|
||||||
|
add_red_textlines,
|
||||||
|
blur_k,
|
||||||
|
scales,
|
||||||
|
degrade_scales,
|
||||||
|
shuffle_indexes,
|
||||||
|
brightness,
|
||||||
|
dir_train,
|
||||||
|
data_is_provided,
|
||||||
|
scaling_bluring,
|
||||||
|
scaling_brightness,
|
||||||
|
scaling_binarization,
|
||||||
|
rotation,
|
||||||
|
rotation_not_90,
|
||||||
|
thetha,
|
||||||
|
thetha_padd,
|
||||||
|
scaling_flip,
|
||||||
|
continue_training,
|
||||||
|
transformer_projection_dim,
|
||||||
|
transformer_mlp_head_units,
|
||||||
|
transformer_layers,
|
||||||
|
transformer_num_heads,
|
||||||
|
transformer_cnn_first,
|
||||||
|
transformer_patchsize_x,
|
||||||
|
transformer_patchsize_y,
|
||||||
|
transformer_num_patches_xy,
|
||||||
|
backbone_type,
|
||||||
|
save_interval,
|
||||||
|
flip_index,
|
||||||
|
dir_eval,
|
||||||
|
dir_output,
|
||||||
|
pretraining,
|
||||||
|
learning_rate,
|
||||||
|
task,
|
||||||
|
f1_threshold_classification,
|
||||||
|
classification_classes_name,
|
||||||
|
dir_img_bin,
|
||||||
|
number_of_backgrounds_per_image,
|
||||||
|
dir_rgb_backgrounds,
|
||||||
|
dir_rgb_foregrounds,
|
||||||
|
characters_txt_file,
|
||||||
|
color_padding_rotation,
|
||||||
|
bin_deg,
|
||||||
|
image_inversion,
|
||||||
|
white_noise_strap,
|
||||||
|
textline_skewing,
|
||||||
|
textline_skewing_bin,
|
||||||
|
textline_left_in_depth,
|
||||||
|
textline_left_in_depth_bin,
|
||||||
|
textline_right_in_depth,
|
||||||
|
textline_right_in_depth_bin,
|
||||||
|
textline_up_in_depth,
|
||||||
|
textline_up_in_depth_bin,
|
||||||
|
textline_down_in_depth,
|
||||||
|
textline_down_in_depth_bin,
|
||||||
|
pepper_bin_aug,
|
||||||
|
pepper_aug,
|
||||||
|
padd_colors,
|
||||||
|
pepper_indexes,
|
||||||
|
white_padds,
|
||||||
|
skewing_amplitudes,
|
||||||
|
max_len,
|
||||||
|
):
|
||||||
|
|
||||||
if dir_rgb_backgrounds:
|
if dir_rgb_backgrounds:
|
||||||
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
|
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
|
||||||
|
|
@ -201,6 +273,10 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
list_all_possible_foreground_rgbs = os.listdir(dir_rgb_foregrounds)
|
list_all_possible_foreground_rgbs = os.listdir(dir_rgb_foregrounds)
|
||||||
else:
|
else:
|
||||||
list_all_possible_foreground_rgbs = None
|
list_all_possible_foreground_rgbs = None
|
||||||
|
|
||||||
|
dir_seg = None
|
||||||
|
weights = None
|
||||||
|
model = None
|
||||||
|
|
||||||
if task == "segmentation" or task == "enhancement" or task == "binarization":
|
if task == "segmentation" or task == "enhancement" or task == "binarization":
|
||||||
if data_is_provided:
|
if data_is_provided:
|
||||||
|
|
@ -285,6 +361,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
assert dir_seg is not None
|
||||||
for obj in os.listdir(dir_seg):
|
for obj in os.listdir(dir_seg):
|
||||||
try:
|
try:
|
||||||
label_obj = cv2.imread(dir_seg + '/' + obj)
|
label_obj = cv2.imread(dir_seg + '/' + obj)
|
||||||
|
|
@ -314,6 +391,8 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||||
if not is_loss_soft_dice and not weighted_loss:
|
if not is_loss_soft_dice and not weighted_loss:
|
||||||
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||||
|
else:
|
||||||
|
raise ValueError("backbone_type must be 'nontransformer' or 'transformer'")
|
||||||
else:
|
else:
|
||||||
index_start = 0
|
index_start = 0
|
||||||
if backbone_type=='nontransformer':
|
if backbone_type=='nontransformer':
|
||||||
|
|
@ -348,6 +427,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
|
model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
|
||||||
|
|
||||||
|
assert model is not None
|
||||||
#if you want to see the model structure just uncomment model summary.
|
#if you want to see the model structure just uncomment model summary.
|
||||||
model.summary()
|
model.summary()
|
||||||
|
|
||||||
|
|
@ -377,9 +457,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
##score_best=[]
|
##score_best=[]
|
||||||
##score_best.append(0)
|
##score_best.append(0)
|
||||||
|
|
||||||
if save_interval:
|
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
|
||||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
|
||||||
|
|
||||||
|
|
||||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||||
if save_interval:
|
if save_interval:
|
||||||
|
|
@ -459,8 +537,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
|
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
|
||||||
model.compile(optimizer=opt)
|
model.compile(optimizer=opt)
|
||||||
|
|
||||||
if save_interval:
|
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
|
||||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
|
||||||
|
|
||||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||||
if save_interval:
|
if save_interval:
|
||||||
|
|
@ -559,8 +636,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
||||||
model.compile(loss="binary_crossentropy",
|
model.compile(loss="binary_crossentropy",
|
||||||
optimizer = opt_adam,metrics=['accuracy'])
|
optimizer = opt_adam,metrics=['accuracy'])
|
||||||
|
|
||||||
if save_interval:
|
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
|
||||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
|
||||||
|
|
||||||
for i in range(n_epochs):
|
for i in range(n_epochs):
|
||||||
if save_interval:
|
if save_interval:
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue