eynollah-training: all training CLI into single click group

This commit is contained in:
kba 2025-10-01 18:52:11 +02:00
parent 690d47444c
commit 1c043c586a
6 changed files with 41 additions and 17 deletions

View file

@ -31,6 +31,7 @@ classifiers = [
[project.scripts]
eynollah = "eynollah.cli:main"
eynollah-training = "eynollah.training.cli:main"
ocrd-eynollah-segment = "eynollah.ocrd_cli:main"
ocrd-sbb-binarize = "eynollah.ocrd_cli_binarization:main"

View file

@ -1,5 +1,5 @@
import click
import tensorflow as tf
from tensorflow.keras.optimizers import *
from .models import resnet50_unet
@ -8,8 +8,8 @@ def configuration():
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
if __name__ == '__main__':
@click.command()
def build_model_load_pretrained_weights_and_save():
n_classes = 2
input_height = 224
input_width = 448

View file

@ -0,0 +1,26 @@
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import click
import sys
from .build_model_load_pretrained_weights_and_save import build_model_load_pretrained_weights_and_save
from .generate_gt_for_training import main as generate_gt_cli
from .inference import main as inference_cli
from .train import ex
@click.command(context_settings=dict(
ignore_unknown_options=True,
))
@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED)
def train_cli(sacred_args):
ex.run_commandline([sys.argv[0]] + list(sacred_args))
@click.group('training')
def main():
pass
main.add_command(build_model_load_pretrained_weights_and_save)
main.add_command(generate_gt_cli, 'generate-gt')
main.add_command(inference_cli, 'inference')
main.add_command(train_cli, 'train')

View file

@ -581,6 +581,3 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
# Draw the text
draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font)
image_text.save(os.path.join(dir_out, f_name+'.png'))
if __name__ == "__main__":
main()

View file

@ -20,7 +20,10 @@ from .gt_gen_utils import (
resize_image,
update_list_and_return_first_with_length_bigger_than_one
)
from .models import PatchEncoder, Patches
from .models import (
PatchEncoder,
Patches
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
@ -675,9 +678,3 @@ def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_fil
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
x.run()
if __name__=="__main__":
main()

View file

@ -2,9 +2,13 @@ import os
import sys
import json
from eynollah.training.metrics import soft_dice_loss, weighted_categorical_crossentropy
import click
from .models import (
from eynollah.training.metrics import (
soft_dice_loss,
weighted_categorical_crossentropy
)
from eynollah.training.models import (
PatchEncoder,
Patches,
machine_based_reading_order_model,
@ -13,7 +17,7 @@ from .models import (
vit_resnet50_unet,
vit_resnet50_unet_transformer_before_cnn
)
from .utils import (
from eynollah.training.utils import (
data_gen,
generate_arrays_from_folder_reading_order,
generate_data_from_folder_evaluation,
@ -142,7 +146,6 @@ def config_params():
dir_rgb_backgrounds = None
dir_rgb_foregrounds = None
@ex.automain
def run(_config, n_classes, n_epochs, input_height,
input_width, weight_decay, weighted_loss,