mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-12-23 03:24:12 +01:00
wip
This commit is contained in:
parent
4651000191
commit
9ccc495b4a
4 changed files with 505 additions and 755 deletions
|
|
@ -147,6 +147,7 @@ def generalized_dice_loss(y_true, y_pred):
|
|||
return 1 - generalized_dice_coeff2(y_true, y_pred)
|
||||
|
||||
|
||||
# TODO: document where this is from
|
||||
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
||||
"""
|
||||
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
|
||||
|
|
@ -175,6 +176,7 @@ def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
|
|||
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
|
||||
|
||||
|
||||
# TODO: document where this is from
|
||||
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False,
|
||||
verbose=False):
|
||||
"""
|
||||
|
|
@ -267,6 +269,8 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=T
|
|||
return K.mean(non_zero_sum / non_zero_count)
|
||||
|
||||
|
||||
# TODO: document where this is from
|
||||
# TODO: Why a different implementation than IoU from utils?
|
||||
def mean_iou(y_true, y_pred, **kwargs):
|
||||
"""
|
||||
Compute mean Intersection over Union of two segmentation masks, via Keras.
|
||||
|
|
@ -311,6 +315,7 @@ def iou_vahid(y_true, y_pred):
|
|||
return K.mean(iou)
|
||||
|
||||
|
||||
# TODO: copy from utils?
|
||||
def IoU_metric(Yi, y_predi):
|
||||
# mean Intersection over Union
|
||||
# Mean IoU = TP/(FN + TP + FP)
|
||||
|
|
@ -337,6 +342,7 @@ def IoU_metric_keras(y_true, y_pred):
|
|||
return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))
|
||||
|
||||
|
||||
# TODO: unused, remove?
|
||||
def jaccard_distance_loss(y_true, y_pred, smooth=100):
|
||||
"""
|
||||
Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from tensorflow.keras.layers import *
|
|||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.regularizers import l2
|
||||
|
||||
from eynollah.patch_encoder import Patches, PatchEncoder
|
||||
|
||||
##mlp_head_units = [512, 256]#[2048, 1024]
|
||||
###projection_dim = 64
|
||||
##transformer_layers = 2#8
|
||||
|
|
@ -38,87 +40,6 @@ def mlp(x, hidden_units, dropout_rate):
|
|||
x = layers.Dropout(dropout_rate)(x)
|
||||
return x
|
||||
|
||||
class Patches(layers.Layer):
|
||||
def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
|
||||
super(Patches, self).__init__()
|
||||
self.patch_size_x = patch_size_x
|
||||
self.patch_size_y = patch_size_y
|
||||
|
||||
def call(self, images):
|
||||
#print(tf.shape(images)[1],'images')
|
||||
#print(self.patch_size,'self.patch_size')
|
||||
batch_size = tf.shape(images)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=images,
|
||||
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||
strides=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
)
|
||||
#patch_dims = patches.shape[-1]
|
||||
patch_dims = tf.shape(patches)[-1]
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'patch_size_x': self.patch_size_x,
|
||||
'patch_size_y': self.patch_size_y,
|
||||
})
|
||||
return config
|
||||
|
||||
class Patches_old(layers.Layer):
|
||||
def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
|
||||
super(Patches, self).__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
def call(self, images):
|
||||
#print(tf.shape(images)[1],'images')
|
||||
#print(self.patch_size,'self.patch_size')
|
||||
batch_size = tf.shape(images)[0]
|
||||
patches = tf.image.extract_patches(
|
||||
images=images,
|
||||
sizes=[1, self.patch_size, self.patch_size, 1],
|
||||
strides=[1, self.patch_size, self.patch_size, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
)
|
||||
patch_dims = patches.shape[-1]
|
||||
#print(patches.shape,patch_dims,'patch_dims')
|
||||
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||
return patches
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'patch_size': self.patch_size,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
class PatchEncoder(layers.Layer):
|
||||
def __init__(self, num_patches, projection_dim):
|
||||
super(PatchEncoder, self).__init__()
|
||||
self.num_patches = num_patches
|
||||
self.projection = layers.Dense(units=projection_dim)
|
||||
self.position_embedding = layers.Embedding(
|
||||
input_dim=num_patches, output_dim=projection_dim
|
||||
)
|
||||
|
||||
def call(self, patch):
|
||||
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||
return encoded
|
||||
def get_config(self):
|
||||
|
||||
config = super().get_config().copy()
|
||||
config.update({
|
||||
'num_patches': self.num_patches,
|
||||
'projection': self.projection,
|
||||
'position_embedding': self.position_embedding,
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
def one_side_pad(x):
|
||||
|
|
|
|||
|
|
@ -175,22 +175,94 @@ def config_params():
|
|||
characters_txt_file = None # Directory of characters text file needed for cnn_rnn_ocr model training. The file ends with .txt
|
||||
|
||||
@ex.automain
|
||||
def run(_config, n_classes, n_epochs, input_height,
|
||||
input_width, weight_decay, weighted_loss,
|
||||
index_start, dir_of_start_model, is_loss_soft_dice,
|
||||
n_batch, patches, augmentation, flip_aug,
|
||||
blur_aug, padding_white, padding_black, scaling, shifting, degrading,channels_shuffling,
|
||||
brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes,
|
||||
brightness, dir_train, data_is_provided, scaling_bluring,
|
||||
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
|
||||
thetha, thetha_padd, scaling_flip, continue_training, transformer_projection_dim,
|
||||
transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
|
||||
transformer_patchsize_x, transformer_patchsize_y,
|
||||
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
|
||||
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds,
|
||||
dir_rgb_foregrounds, characters_txt_file, color_padding_rotation, bin_deg, image_inversion, white_noise_strap, textline_skewing, textline_skewing_bin,
|
||||
textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin,
|
||||
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors, pepper_indexes, white_padds, skewing_amplitudes, max_len):
|
||||
def run(
|
||||
_config,
|
||||
n_classes,
|
||||
n_epochs,
|
||||
input_height,
|
||||
input_width,
|
||||
weight_decay,
|
||||
weighted_loss,
|
||||
index_start,
|
||||
dir_of_start_model,
|
||||
is_loss_soft_dice,
|
||||
n_batch,
|
||||
patches,
|
||||
augmentation,
|
||||
flip_aug,
|
||||
blur_aug,
|
||||
padding_white,
|
||||
padding_black,
|
||||
scaling,
|
||||
shifting,
|
||||
degrading,
|
||||
channels_shuffling,
|
||||
brightening,
|
||||
binarization,
|
||||
adding_rgb_background,
|
||||
adding_rgb_foreground,
|
||||
add_red_textlines,
|
||||
blur_k,
|
||||
scales,
|
||||
degrade_scales,
|
||||
shuffle_indexes,
|
||||
brightness,
|
||||
dir_train,
|
||||
data_is_provided,
|
||||
scaling_bluring,
|
||||
scaling_brightness,
|
||||
scaling_binarization,
|
||||
rotation,
|
||||
rotation_not_90,
|
||||
thetha,
|
||||
thetha_padd,
|
||||
scaling_flip,
|
||||
continue_training,
|
||||
transformer_projection_dim,
|
||||
transformer_mlp_head_units,
|
||||
transformer_layers,
|
||||
transformer_num_heads,
|
||||
transformer_cnn_first,
|
||||
transformer_patchsize_x,
|
||||
transformer_patchsize_y,
|
||||
transformer_num_patches_xy,
|
||||
backbone_type,
|
||||
save_interval,
|
||||
flip_index,
|
||||
dir_eval,
|
||||
dir_output,
|
||||
pretraining,
|
||||
learning_rate,
|
||||
task,
|
||||
f1_threshold_classification,
|
||||
classification_classes_name,
|
||||
dir_img_bin,
|
||||
number_of_backgrounds_per_image,
|
||||
dir_rgb_backgrounds,
|
||||
dir_rgb_foregrounds,
|
||||
characters_txt_file,
|
||||
color_padding_rotation,
|
||||
bin_deg,
|
||||
image_inversion,
|
||||
white_noise_strap,
|
||||
textline_skewing,
|
||||
textline_skewing_bin,
|
||||
textline_left_in_depth,
|
||||
textline_left_in_depth_bin,
|
||||
textline_right_in_depth,
|
||||
textline_right_in_depth_bin,
|
||||
textline_up_in_depth,
|
||||
textline_up_in_depth_bin,
|
||||
textline_down_in_depth,
|
||||
textline_down_in_depth_bin,
|
||||
pepper_bin_aug,
|
||||
pepper_aug,
|
||||
padd_colors,
|
||||
pepper_indexes,
|
||||
white_padds,
|
||||
skewing_amplitudes,
|
||||
max_len,
|
||||
):
|
||||
|
||||
if dir_rgb_backgrounds:
|
||||
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
|
||||
|
|
@ -202,6 +274,10 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
else:
|
||||
list_all_possible_foreground_rgbs = None
|
||||
|
||||
dir_seg = None
|
||||
weights = None
|
||||
model = None
|
||||
|
||||
if task == "segmentation" or task == "enhancement" or task == "binarization":
|
||||
if data_is_provided:
|
||||
dir_train_flowing = os.path.join(dir_output, 'train')
|
||||
|
|
@ -285,6 +361,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
pass
|
||||
else:
|
||||
|
||||
assert dir_seg is not None
|
||||
for obj in os.listdir(dir_seg):
|
||||
try:
|
||||
label_obj = cv2.imread(dir_seg + '/' + obj)
|
||||
|
|
@ -314,6 +391,8 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
|
||||
if not is_loss_soft_dice and not weighted_loss:
|
||||
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
else:
|
||||
raise ValueError("backbone_type must be 'nontransformer' or 'transformer'")
|
||||
else:
|
||||
index_start = 0
|
||||
if backbone_type=='nontransformer':
|
||||
|
|
@ -348,6 +427,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
sys.exit(1)
|
||||
model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
|
||||
|
||||
assert model is not None
|
||||
#if you want to see the model structure just uncomment model summary.
|
||||
model.summary()
|
||||
|
||||
|
|
@ -377,9 +457,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
|
||||
if save_interval:
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
||||
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
|
||||
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
if save_interval:
|
||||
|
|
@ -459,8 +537,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
|
||||
model.compile(optimizer=opt)
|
||||
|
||||
if save_interval:
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
|
||||
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
if save_interval:
|
||||
|
|
@ -559,8 +636,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
model.compile(loss="binary_crossentropy",
|
||||
optimizer = opt_adam,metrics=['accuracy'])
|
||||
|
||||
if save_interval:
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
|
||||
|
||||
for i in range(n_epochs):
|
||||
if save_interval:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue