This commit is contained in:
kba 2025-12-19 14:57:10 +01:00
parent 4651000191
commit 9ccc495b4a
4 changed files with 505 additions and 755 deletions

View file

@ -147,6 +147,7 @@ def generalized_dice_loss(y_true, y_pred):
return 1 - generalized_dice_coeff2(y_true, y_pred)
# TODO: document where this is from
def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
"""
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
@ -175,6 +176,7 @@ def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
# TODO: document where this is from
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False,
verbose=False):
"""
@ -267,6 +269,8 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=T
return K.mean(non_zero_sum / non_zero_count)
# TODO: document where this is from
# TODO: Why a different implementation than IoU from utils?
def mean_iou(y_true, y_pred, **kwargs):
"""
Compute mean Intersection over Union of two segmentation masks, via Keras.
@ -311,6 +315,7 @@ def iou_vahid(y_true, y_pred):
return K.mean(iou)
# TODO: copy from utils?
def IoU_metric(Yi, y_predi):
# mean Intersection over Union
# Mean IoU = TP/(FN + TP + FP)
@ -337,6 +342,7 @@ def IoU_metric_keras(y_true, y_pred):
return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))
# TODO: unused, remove?
def jaccard_distance_loss(y_true, y_pred, smooth=100):
"""
Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)

View file

@ -5,6 +5,8 @@ from tensorflow.keras.layers import *
from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2
from eynollah.patch_encoder import Patches, PatchEncoder
##mlp_head_units = [512, 256]#[2048, 1024]
###projection_dim = 64
##transformer_layers = 2#8
@ -38,87 +40,6 @@ def mlp(x, hidden_units, dropout_rate):
x = layers.Dropout(dropout_rate)(x)
return x
class Patches(layers.Layer):
def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size_x = patch_size_x
self.patch_size_y = patch_size_y
def call(self, images):
#print(tf.shape(images)[1],'images')
#print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
strides=[1, self.patch_size_y, self.patch_size_x, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
#patch_dims = patches.shape[-1]
patch_dims = tf.shape(patches)[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size_x': self.patch_size_x,
'patch_size_y': self.patch_size_y,
})
return config
class Patches_old(layers.Layer):
def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
#print(tf.shape(images)[1],'images')
#print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
#print(patches.shape,patch_dims,'patch_dims')
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
def one_side_pad(x):

View file

@ -175,22 +175,94 @@ def config_params():
characters_txt_file = None # Directory of characters text file needed for cnn_rnn_ocr model training. The file ends with .txt
@ex.automain
def run(_config, n_classes, n_epochs, input_height,
input_width, weight_decay, weighted_loss,
index_start, dir_of_start_model, is_loss_soft_dice,
n_batch, patches, augmentation, flip_aug,
blur_aug, padding_white, padding_black, scaling, shifting, degrading,channels_shuffling,
brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes,
brightness, dir_train, data_is_provided, scaling_bluring,
scaling_brightness, scaling_binarization, rotation, rotation_not_90,
thetha, thetha_padd, scaling_flip, continue_training, transformer_projection_dim,
transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
transformer_patchsize_x, transformer_patchsize_y,
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds,
dir_rgb_foregrounds, characters_txt_file, color_padding_rotation, bin_deg, image_inversion, white_noise_strap, textline_skewing, textline_skewing_bin,
textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin,
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors, pepper_indexes, white_padds, skewing_amplitudes, max_len):
def run(
_config,
n_classes,
n_epochs,
input_height,
input_width,
weight_decay,
weighted_loss,
index_start,
dir_of_start_model,
is_loss_soft_dice,
n_batch,
patches,
augmentation,
flip_aug,
blur_aug,
padding_white,
padding_black,
scaling,
shifting,
degrading,
channels_shuffling,
brightening,
binarization,
adding_rgb_background,
adding_rgb_foreground,
add_red_textlines,
blur_k,
scales,
degrade_scales,
shuffle_indexes,
brightness,
dir_train,
data_is_provided,
scaling_bluring,
scaling_brightness,
scaling_binarization,
rotation,
rotation_not_90,
thetha,
thetha_padd,
scaling_flip,
continue_training,
transformer_projection_dim,
transformer_mlp_head_units,
transformer_layers,
transformer_num_heads,
transformer_cnn_first,
transformer_patchsize_x,
transformer_patchsize_y,
transformer_num_patches_xy,
backbone_type,
save_interval,
flip_index,
dir_eval,
dir_output,
pretraining,
learning_rate,
task,
f1_threshold_classification,
classification_classes_name,
dir_img_bin,
number_of_backgrounds_per_image,
dir_rgb_backgrounds,
dir_rgb_foregrounds,
characters_txt_file,
color_padding_rotation,
bin_deg,
image_inversion,
white_noise_strap,
textline_skewing,
textline_skewing_bin,
textline_left_in_depth,
textline_left_in_depth_bin,
textline_right_in_depth,
textline_right_in_depth_bin,
textline_up_in_depth,
textline_up_in_depth_bin,
textline_down_in_depth,
textline_down_in_depth_bin,
pepper_bin_aug,
pepper_aug,
padd_colors,
pepper_indexes,
white_padds,
skewing_amplitudes,
max_len,
):
if dir_rgb_backgrounds:
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@ -201,6 +273,10 @@ def run(_config, n_classes, n_epochs, input_height,
list_all_possible_foreground_rgbs = os.listdir(dir_rgb_foregrounds)
else:
list_all_possible_foreground_rgbs = None
dir_seg = None
weights = None
model = None
if task == "segmentation" or task == "enhancement" or task == "binarization":
if data_is_provided:
@ -285,6 +361,7 @@ def run(_config, n_classes, n_epochs, input_height,
pass
else:
assert dir_seg is not None
for obj in os.listdir(dir_seg):
try:
label_obj = cv2.imread(dir_seg + '/' + obj)
@ -314,6 +391,8 @@ def run(_config, n_classes, n_epochs, input_height,
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss:
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
else:
raise ValueError("backbone_type must be 'nontransformer' or 'transformer'")
else:
index_start = 0
if backbone_type=='nontransformer':
@ -348,6 +427,7 @@ def run(_config, n_classes, n_epochs, input_height,
sys.exit(1)
model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
assert model is not None
#if you want to see the model structure just uncomment model summary.
model.summary()
@ -377,9 +457,7 @@ def run(_config, n_classes, n_epochs, input_height,
##score_best=[]
##score_best.append(0)
if save_interval:
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
for i in tqdm(range(index_start, n_epochs + index_start)):
if save_interval:
@ -459,8 +537,7 @@ def run(_config, n_classes, n_epochs, input_height,
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
model.compile(optimizer=opt)
if save_interval:
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
for i in tqdm(range(index_start, n_epochs + index_start)):
if save_interval:
@ -559,8 +636,7 @@ def run(_config, n_classes, n_epochs, input_height,
model.compile(loss="binary_crossentropy",
optimizer = opt_adam,metrics=['accuracy'])
if save_interval:
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
for i in range(n_epochs):
if save_interval:

File diff suppressed because it is too large Load diff