mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-06 14:39:55 +02:00
eynollah-training: all training CLI into single click group
This commit is contained in:
parent
690d47444c
commit
1c043c586a
6 changed files with 41 additions and 17 deletions
|
@ -31,6 +31,7 @@ classifiers = [
|
|||
|
||||
[project.scripts]
|
||||
eynollah = "eynollah.cli:main"
|
||||
eynollah-training = "eynollah.training.cli:main"
|
||||
ocrd-eynollah-segment = "eynollah.ocrd_cli:main"
|
||||
ocrd-sbb-binarize = "eynollah.ocrd_cli_binarization:main"
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import click
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.optimizers import *
|
||||
|
||||
from .models import resnet50_unet
|
||||
|
||||
|
@ -8,8 +8,8 @@ def configuration():
|
|||
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
|
||||
session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@click.command()
|
||||
def build_model_load_pretrained_weights_and_save():
|
||||
n_classes = 2
|
||||
input_height = 224
|
||||
input_width = 448
|
||||
|
|
26
src/eynollah/training/cli.py
Normal file
26
src/eynollah/training/cli.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
import os
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
import click
|
||||
import sys
|
||||
|
||||
from .build_model_load_pretrained_weights_and_save import build_model_load_pretrained_weights_and_save
|
||||
from .generate_gt_for_training import main as generate_gt_cli
|
||||
from .inference import main as inference_cli
|
||||
from .train import ex
|
||||
|
||||
@click.command(context_settings=dict(
|
||||
ignore_unknown_options=True,
|
||||
))
|
||||
@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED)
|
||||
def train_cli(sacred_args):
|
||||
ex.run_commandline([sys.argv[0]] + list(sacred_args))
|
||||
|
||||
@click.group('training')
|
||||
def main():
|
||||
pass
|
||||
|
||||
main.add_command(build_model_load_pretrained_weights_and_save)
|
||||
main.add_command(generate_gt_cli, 'generate-gt')
|
||||
main.add_command(inference_cli, 'inference')
|
||||
main.add_command(train_cli, 'train')
|
|
@ -581,6 +581,3 @@ def visualize_ocr_text(xml_file, dir_xml, dir_out):
|
|||
# Draw the text
|
||||
draw.text((text_x, text_y), ocr_texts[index], fill="black", font=font)
|
||||
image_text.save(os.path.join(dir_out, f_name+'.png'))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -20,7 +20,10 @@ from .gt_gen_utils import (
|
|||
resize_image,
|
||||
update_list_and_return_first_with_length_bigger_than_one
|
||||
)
|
||||
from .models import PatchEncoder, Patches
|
||||
from .models import (
|
||||
PatchEncoder,
|
||||
Patches
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
@ -675,9 +678,3 @@ def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_fil
|
|||
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
|
||||
x.run()
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -2,9 +2,13 @@ import os
|
|||
import sys
|
||||
import json
|
||||
|
||||
from eynollah.training.metrics import soft_dice_loss, weighted_categorical_crossentropy
|
||||
import click
|
||||
|
||||
from .models import (
|
||||
from eynollah.training.metrics import (
|
||||
soft_dice_loss,
|
||||
weighted_categorical_crossentropy
|
||||
)
|
||||
from eynollah.training.models import (
|
||||
PatchEncoder,
|
||||
Patches,
|
||||
machine_based_reading_order_model,
|
||||
|
@ -13,7 +17,7 @@ from .models import (
|
|||
vit_resnet50_unet,
|
||||
vit_resnet50_unet_transformer_before_cnn
|
||||
)
|
||||
from .utils import (
|
||||
from eynollah.training.utils import (
|
||||
data_gen,
|
||||
generate_arrays_from_folder_reading_order,
|
||||
generate_data_from_folder_evaluation,
|
||||
|
@ -142,7 +146,6 @@ def config_params():
|
|||
dir_rgb_backgrounds = None
|
||||
dir_rgb_foregrounds = None
|
||||
|
||||
|
||||
@ex.automain
|
||||
def run(_config, n_classes, n_epochs, input_height,
|
||||
input_width, weight_decay, weighted_loss,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue