From 1c043c586a972c4088d204b179b37d64eb44a39f Mon Sep 17 00:00:00 2001 From: kba Date: Wed, 1 Oct 2025 18:52:11 +0200 Subject: [PATCH] eynollah-training: all training CLI into single click group --- pyproject.toml | 1 + ..._model_load_pretrained_weights_and_save.py | 6 ++--- src/eynollah/training/cli.py | 26 +++++++++++++++++++ .../training/generate_gt_for_training.py | 3 --- src/eynollah/training/inference.py | 11 +++----- src/eynollah/training/train.py | 11 +++++--- 6 files changed, 41 insertions(+), 17 deletions(-) create mode 100644 src/eynollah/training/cli.py diff --git a/pyproject.toml b/pyproject.toml index ec3e5f8..ec99c99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ classifiers = [ [project.scripts] eynollah = "eynollah.cli:main" +eynollah-training = "eynollah.training.cli:main" ocrd-eynollah-segment = "eynollah.ocrd_cli:main" ocrd-sbb-binarize = "eynollah.ocrd_cli_binarization:main" diff --git a/src/eynollah/training/build_model_load_pretrained_weights_and_save.py b/src/eynollah/training/build_model_load_pretrained_weights_and_save.py index ce3d955..40fc1fe 100644 --- a/src/eynollah/training/build_model_load_pretrained_weights_and_save.py +++ b/src/eynollah/training/build_model_load_pretrained_weights_and_save.py @@ -1,5 +1,5 @@ +import click import tensorflow as tf -from tensorflow.keras.optimizers import * from .models import resnet50_unet @@ -8,8 +8,8 @@ def configuration(): gpu_options = tf.compat.v1.GPUOptions(allow_growth=True) session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options)) - -if __name__ == '__main__': +@click.command() +def build_model_load_pretrained_weights_and_save(): n_classes = 2 input_height = 224 input_width = 448 diff --git a/src/eynollah/training/cli.py b/src/eynollah/training/cli.py new file mode 100644 index 0000000..8ab754d --- /dev/null +++ b/src/eynollah/training/cli.py @@ -0,0 +1,26 @@ +import os +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +import click +import sys + +from .build_model_load_pretrained_weights_and_save import build_model_load_pretrained_weights_and_save +from .generate_gt_for_training import main as generate_gt_cli +from .inference import main as inference_cli +from .train import ex + +@click.command(context_settings=dict( + ignore_unknown_options=True, +)) +@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED) +def train_cli(sacred_args): + ex.run_commandline([sys.argv[0]] + list(sacred_args)) + +@click.group('training') +def main(): + pass + +main.add_command(build_model_load_pretrained_weights_and_save) +main.add_command(generate_gt_cli, 'generate-gt') +main.add_command(inference_cli, 'inference') +main.add_command(train_cli, 'train') diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py index 3fd93ae..693cab8 100644 --- a/src/eynollah/training/generate_gt_for_training.py +++ b/src/eynollah/training/generate_gt_for_training.py @@ -581,6 +581,3 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out): # Draw the text draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font) image_text.save(os.path.join(dir_out, f_name+'.png')) - -if __name__ == "__main__": - main() diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py index 998c8fc..3fa8fd6 100644 --- a/src/eynollah/training/inference.py +++ b/src/eynollah/training/inference.py @@ -20,7 +20,10 @@ from .gt_gen_utils import ( resize_image, update_list_and_return_first_with_length_bigger_than_one ) -from .models import PatchEncoder, Patches +from .models import ( + PatchEncoder, + Patches +) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -675,9 +678,3 @@ def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_fil x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) x.run() -if __name__=="__main__": - main() - - - - diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index 527bca6..97736e0 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -2,9 +2,13 @@ import os import sys import json -from eynollah.training.metrics import soft_dice_loss, weighted_categorical_crossentropy +import click -from .models import ( +from eynollah.training.metrics import ( + soft_dice_loss, + weighted_categorical_crossentropy +) +from eynollah.training.models import ( PatchEncoder, Patches, machine_based_reading_order_model, @@ -13,7 +17,7 @@ from .models import ( vit_resnet50_unet, vit_resnet50_unet_transformer_before_cnn ) -from .utils import ( +from eynollah.training.utils import ( data_gen, generate_arrays_from_folder_reading_order, generate_data_from_folder_evaluation, @@ -142,7 +146,6 @@ def config_params(): dir_rgb_backgrounds = None dir_rgb_foregrounds = None - @ex.automain def run(_config, n_classes, n_epochs, input_height, input_width, weight_decay, weighted_loss,