This commit is contained in:
kba 2025-12-19 14:57:10 +01:00
parent 4651000191
commit 9ccc495b4a
4 changed files with 505 additions and 755 deletions

View file

@ -147,6 +147,7 @@ def generalized_dice_loss(y_true, y_pred):
return 1 - generalized_dice_coeff2(y_true, y_pred) return 1 - generalized_dice_coeff2(y_true, y_pred)
# TODO: document where this is from
def soft_dice_loss(y_true, y_pred, epsilon=1e-6): def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
""" """
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions. Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
@ -175,6 +176,7 @@ def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
# TODO: document where this is from
def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False, def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False,
verbose=False): verbose=False):
""" """
@ -267,6 +269,8 @@ def seg_metrics(y_true, y_pred, metric_name, metric_type='standard', drop_last=T
return K.mean(non_zero_sum / non_zero_count) return K.mean(non_zero_sum / non_zero_count)
# TODO: document where this is from
# TODO: Why a different implementation than IoU from utils?
def mean_iou(y_true, y_pred, **kwargs): def mean_iou(y_true, y_pred, **kwargs):
""" """
Compute mean Intersection over Union of two segmentation masks, via Keras. Compute mean Intersection over Union of two segmentation masks, via Keras.
@ -311,6 +315,7 @@ def iou_vahid(y_true, y_pred):
return K.mean(iou) return K.mean(iou)
# TODO: copy from utils?
def IoU_metric(Yi, y_predi): def IoU_metric(Yi, y_predi):
# mean Intersection over Union # mean Intersection over Union
# Mean IoU = TP/(FN + TP + FP) # Mean IoU = TP/(FN + TP + FP)
@ -337,6 +342,7 @@ def IoU_metric_keras(y_true, y_pred):
return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess)) return IoU_metric(y_true.eval(session=sess), y_pred.eval(session=sess))
# TODO: unused, remove?
def jaccard_distance_loss(y_true, y_pred, smooth=100): def jaccard_distance_loss(y_true, y_pred, smooth=100):
""" """
Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)

View file

@ -5,6 +5,8 @@ from tensorflow.keras.layers import *
from tensorflow.keras import layers from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2 from tensorflow.keras.regularizers import l2
from eynollah.patch_encoder import Patches, PatchEncoder
##mlp_head_units = [512, 256]#[2048, 1024] ##mlp_head_units = [512, 256]#[2048, 1024]
###projection_dim = 64 ###projection_dim = 64
##transformer_layers = 2#8 ##transformer_layers = 2#8
@ -38,87 +40,6 @@ def mlp(x, hidden_units, dropout_rate):
x = layers.Dropout(dropout_rate)(x) x = layers.Dropout(dropout_rate)(x)
return x return x
class Patches(layers.Layer):
def __init__(self, patch_size_x, patch_size_y):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size_x = patch_size_x
self.patch_size_y = patch_size_y
def call(self, images):
#print(tf.shape(images)[1],'images')
#print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
strides=[1, self.patch_size_y, self.patch_size_x, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
#patch_dims = patches.shape[-1]
patch_dims = tf.shape(patches)[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size_x': self.patch_size_x,
'patch_size_y': self.patch_size_y,
})
return config
class Patches_old(layers.Layer):
def __init__(self, patch_size):#__init__(self, **kwargs):#:__init__(self, patch_size):#__init__(self, **kwargs):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
#print(tf.shape(images)[1],'images')
#print(self.patch_size,'self.patch_size')
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
#print(patches.shape,patch_dims,'patch_dims')
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
def get_config(self):
config = super().get_config().copy()
config.update({
'patch_size': self.patch_size,
})
return config
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config().copy()
config.update({
'num_patches': self.num_patches,
'projection': self.projection,
'position_embedding': self.position_embedding,
})
return config
def one_side_pad(x): def one_side_pad(x):

View file

@ -175,22 +175,94 @@ def config_params():
characters_txt_file = None # Directory of characters text file needed for cnn_rnn_ocr model training. The file ends with .txt characters_txt_file = None # Directory of characters text file needed for cnn_rnn_ocr model training. The file ends with .txt
@ex.automain @ex.automain
def run(_config, n_classes, n_epochs, input_height, def run(
input_width, weight_decay, weighted_loss, _config,
index_start, dir_of_start_model, is_loss_soft_dice, n_classes,
n_batch, patches, augmentation, flip_aug, n_epochs,
blur_aug, padding_white, padding_black, scaling, shifting, degrading,channels_shuffling, input_height,
brightening, binarization, adding_rgb_background, adding_rgb_foreground, add_red_textlines, blur_k, scales, degrade_scales,shuffle_indexes, input_width,
brightness, dir_train, data_is_provided, scaling_bluring, weight_decay,
scaling_brightness, scaling_binarization, rotation, rotation_not_90, weighted_loss,
thetha, thetha_padd, scaling_flip, continue_training, transformer_projection_dim, index_start,
transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first, dir_of_start_model,
transformer_patchsize_x, transformer_patchsize_y, is_loss_soft_dice,
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output, n_batch,
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, patches,
dir_rgb_foregrounds, characters_txt_file, color_padding_rotation, bin_deg, image_inversion, white_noise_strap, textline_skewing, textline_skewing_bin, augmentation,
textline_left_in_depth, textline_left_in_depth_bin, textline_right_in_depth, textline_right_in_depth_bin, textline_up_in_depth, textline_up_in_depth_bin, flip_aug,
textline_down_in_depth, textline_down_in_depth_bin, pepper_bin_aug, pepper_aug, padd_colors, pepper_indexes, white_padds, skewing_amplitudes, max_len): blur_aug,
padding_white,
padding_black,
scaling,
shifting,
degrading,
channels_shuffling,
brightening,
binarization,
adding_rgb_background,
adding_rgb_foreground,
add_red_textlines,
blur_k,
scales,
degrade_scales,
shuffle_indexes,
brightness,
dir_train,
data_is_provided,
scaling_bluring,
scaling_brightness,
scaling_binarization,
rotation,
rotation_not_90,
thetha,
thetha_padd,
scaling_flip,
continue_training,
transformer_projection_dim,
transformer_mlp_head_units,
transformer_layers,
transformer_num_heads,
transformer_cnn_first,
transformer_patchsize_x,
transformer_patchsize_y,
transformer_num_patches_xy,
backbone_type,
save_interval,
flip_index,
dir_eval,
dir_output,
pretraining,
learning_rate,
task,
f1_threshold_classification,
classification_classes_name,
dir_img_bin,
number_of_backgrounds_per_image,
dir_rgb_backgrounds,
dir_rgb_foregrounds,
characters_txt_file,
color_padding_rotation,
bin_deg,
image_inversion,
white_noise_strap,
textline_skewing,
textline_skewing_bin,
textline_left_in_depth,
textline_left_in_depth_bin,
textline_right_in_depth,
textline_right_in_depth_bin,
textline_up_in_depth,
textline_up_in_depth_bin,
textline_down_in_depth,
textline_down_in_depth_bin,
pepper_bin_aug,
pepper_aug,
padd_colors,
pepper_indexes,
white_padds,
skewing_amplitudes,
max_len,
):
if dir_rgb_backgrounds: if dir_rgb_backgrounds:
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds) list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
@ -201,6 +273,10 @@ def run(_config, n_classes, n_epochs, input_height,
list_all_possible_foreground_rgbs = os.listdir(dir_rgb_foregrounds) list_all_possible_foreground_rgbs = os.listdir(dir_rgb_foregrounds)
else: else:
list_all_possible_foreground_rgbs = None list_all_possible_foreground_rgbs = None
dir_seg = None
weights = None
model = None
if task == "segmentation" or task == "enhancement" or task == "binarization": if task == "segmentation" or task == "enhancement" or task == "binarization":
if data_is_provided: if data_is_provided:
@ -285,6 +361,7 @@ def run(_config, n_classes, n_epochs, input_height,
pass pass
else: else:
assert dir_seg is not None
for obj in os.listdir(dir_seg): for obj in os.listdir(dir_seg):
try: try:
label_obj = cv2.imread(dir_seg + '/' + obj) label_obj = cv2.imread(dir_seg + '/' + obj)
@ -314,6 +391,8 @@ def run(_config, n_classes, n_epochs, input_height,
model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)}) model = load_model(dir_of_start_model, compile=True, custom_objects={'loss': weighted_categorical_crossentropy(weights)})
if not is_loss_soft_dice and not weighted_loss: if not is_loss_soft_dice and not weighted_loss:
model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) model = load_model(dir_of_start_model , compile=True,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
else:
raise ValueError("backbone_type must be 'nontransformer' or 'transformer'")
else: else:
index_start = 0 index_start = 0
if backbone_type=='nontransformer': if backbone_type=='nontransformer':
@ -348,6 +427,7 @@ def run(_config, n_classes, n_epochs, input_height,
sys.exit(1) sys.exit(1)
model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining) model = vit_resnet50_unet_transformer_before_cnn(n_classes, transformer_patchsize_x, transformer_patchsize_y, num_patches, transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_projection_dim, input_height, input_width, task, weight_decay, pretraining)
assert model is not None
#if you want to see the model structure just uncomment model summary. #if you want to see the model structure just uncomment model summary.
model.summary() model.summary()
@ -377,9 +457,7 @@ def run(_config, n_classes, n_epochs, input_height,
##score_best=[] ##score_best=[]
##score_best.append(0) ##score_best.append(0)
if save_interval: save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
for i in tqdm(range(index_start, n_epochs + index_start)): for i in tqdm(range(index_start, n_epochs + index_start)):
if save_interval: if save_interval:
@ -459,8 +537,7 @@ def run(_config, n_classes, n_epochs, input_height,
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule) opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)#1e-4)#(lr_schedule)
model.compile(optimizer=opt) model.compile(optimizer=opt)
if save_interval: save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
for i in tqdm(range(index_start, n_epochs + index_start)): for i in tqdm(range(index_start, n_epochs + index_start)):
if save_interval: if save_interval:
@ -559,8 +636,7 @@ def run(_config, n_classes, n_epochs, input_height,
model.compile(loss="binary_crossentropy", model.compile(loss="binary_crossentropy",
optimizer = opt_adam,metrics=['accuracy']) optimizer = opt_adam,metrics=['accuracy'])
if save_interval: save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config) if save_interval else None
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
for i in range(n_epochs): for i in range(n_epochs):
if save_interval: if save_interval:

File diff suppressed because it is too large Load diff