mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-09 20:00:05 +02:00
saving model by steps is added to reading order and pixel wise segmentation use cases training
This commit is contained in:
parent
8ae42b7c6e
commit
1454bc4f58
1 changed files with 51 additions and 9 deletions
60
train.py
60
train.py
|
@ -13,8 +13,29 @@ from tensorflow.keras.models import load_model
|
|||
from tqdm import tqdm
|
||||
import json
|
||||
from sklearn.metrics import f1_score
|
||||
from tensorflow.keras.callbacks import Callback
|
||||
|
||||
class SaveWeightsAfterSteps(Callback):
|
||||
def __init__(self, save_interval, save_path, _config):
|
||||
super(SaveWeightsAfterSteps, self).__init__()
|
||||
self.save_interval = save_interval
|
||||
self.save_path = save_path
|
||||
self.step_count = 0
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
self.step_count += 1
|
||||
|
||||
if self.step_count % self.save_interval ==0:
|
||||
save_file = f"{self.save_path}/model_step_{self.step_count}"
|
||||
#os.system('mkdir '+save_file)
|
||||
|
||||
self.model.save(save_file)
|
||||
|
||||
with open(os.path.join(os.path.join(save_path, "model_step_{self.step_count}"),"config.json"), "w") as fp:
|
||||
json.dump(_config, fp) # encode dict into JSON
|
||||
print(f"saved model as steps {self.step_count} to {save_file}")
|
||||
|
||||
|
||||
def configuration():
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
|
@ -93,7 +114,7 @@ def config_params():
|
|||
f1_threshold_classification = None # This threshold is used to consider models with an evaluation f1 scores bigger than it. The selected model weights undergo a weights ensembling. And avreage ensembled model will be written to output.
|
||||
classification_classes_name = None # Dictionary of classification classes names.
|
||||
backbone_type = None # As backbone we have 2 types of backbones. A vision transformer alongside a CNN and we call it "transformer" and only CNN called "nontransformer"
|
||||
|
||||
save_interval = None
|
||||
dir_img_bin = None
|
||||
number_of_backgrounds_per_image = 1
|
||||
dir_rgb_backgrounds = None
|
||||
|
@ -112,7 +133,7 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
thetha, scaling_flip, continue_training, transformer_projection_dim,
|
||||
transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
|
||||
transformer_patchsize_x, transformer_patchsize_y,
|
||||
transformer_num_patches_xy, backbone_type, flip_index, dir_eval, dir_output,
|
||||
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
|
||||
pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
|
||||
|
||||
if dir_rgb_backgrounds:
|
||||
|
@ -299,13 +320,27 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
##img_validation_patches = os.listdir(dir_flow_eval_imgs)
|
||||
##score_best=[]
|
||||
##score_best.append(0)
|
||||
|
||||
if save_interval:
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
||||
|
||||
|
||||
for i in tqdm(range(index_start, n_epochs + index_start)):
|
||||
model.fit(
|
||||
train_gen,
|
||||
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
|
||||
validation_data=val_gen,
|
||||
validation_steps=1,
|
||||
epochs=1)
|
||||
if save_interval:
|
||||
model.fit(
|
||||
train_gen,
|
||||
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
|
||||
validation_data=val_gen,
|
||||
validation_steps=1,
|
||||
epochs=1, callbacks=[save_weights_callback])
|
||||
else:
|
||||
model.fit(
|
||||
train_gen,
|
||||
steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
|
||||
validation_data=val_gen,
|
||||
validation_steps=1,
|
||||
epochs=1)
|
||||
|
||||
model.save(os.path.join(dir_output,'model_'+str(i)))
|
||||
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||
|
@ -392,8 +427,15 @@ def run(_config, n_classes, n_epochs, input_height,
|
|||
opt_adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
|
||||
model.compile(loss="binary_crossentropy",
|
||||
optimizer = opt_adam,metrics=['accuracy'])
|
||||
|
||||
if save_interval:
|
||||
save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
|
||||
|
||||
for i in range(n_epochs):
|
||||
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
|
||||
if save_interval:
|
||||
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1, callbacks=[save_weights_callback])
|
||||
else:
|
||||
history = model.fit(generate_arrays_from_folder_reading_order(dir_flow_train_labels, dir_flow_train_imgs, n_batch, input_height, input_width, n_classes, thetha, augmentation), steps_per_epoch=num_rows / n_batch, verbose=1)
|
||||
model.save( os.path.join(dir_output,'model_'+str(i+indexer_start) ))
|
||||
|
||||
with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue