diff --git a/src/eynollah/cli_models.py b/src/eynollah/cli_models.py index b67a3ef..595c499 100644 --- a/src/eynollah/cli_models.py +++ b/src/eynollah/cli_models.py @@ -1,21 +1,19 @@ from dataclasses import dataclass -from typing import List, Tuple +from pathlib import Path +from typing import List, Set, Tuple import click + +from eynollah.model_zoo.default_specs import MODELS_VERSION from .model_zoo import EynollahModelZoo + @dataclass() -class EynollahCliCtx(): - model_basedir: str - model_overrides: List[Tuple[str, str, str]] +class EynollahCliCtx: + model_zoo: EynollahModelZoo @click.group() -def models_cli(): - """ - Organize models for the various runners in eynollah. - """ - -@models_cli.command('list') +@click.pass_context @click.option( "--model", "-m", @@ -32,18 +30,64 @@ def models_cli(): type=(str, str, str), multiple=True, ) -@click.pass_context -def list_models( +def models_cli( ctx, model_basedir: str, model_overrides: List[Tuple[str, str, str]], ): """ - List all the models in the zoo + Organize models for the various runners in eynollah. """ - ctx.obj = EynollahCliCtx( - model_basedir=model_basedir, - model_overrides=model_overrides - ) - print(EynollahModelZoo(basedir=ctx.obj.model_basedir, model_overrides=ctx.obj.model_overrides)) + ctx.obj = EynollahCliCtx(model_zoo=EynollahModelZoo(basedir=model_basedir, model_overrides=model_overrides)) + +@models_cli.command('list') +@click.pass_context +def list_models( + ctx, +): + """ + List all the models in the zoo + """ + print(ctx.obj.model_zoo) + + +@models_cli.command('package') +@click.option( + '--set-version', '-V', 'version', help="Version to use for packaging", default=MODELS_VERSION, show_default=True +) +@click.argument('output_dir') +@click.pass_context +def package( + ctx, + version, + output_dir, +): + """ + Generate shell code to copy all the models in the zoo into properly named folders in OUTPUT_DIR for distribution. + + eynollah models -m SRC package OUTPUT_DIR + + SRC should contain a directory "models_eynollah" containing all the models. + """ + mkdirs: Set[Path] = set([]) + copies: Set[Tuple[Path, Path]] = set([]) + for spec in ctx.obj.model_zoo.specs.specs: + # skip these as they are dependent on the ocr model + if spec.category in ('num_to_char', 'characters'): + continue + src: Path = ctx.obj.model_zoo.model_path(spec.category, spec.variant) + # Only copy the top-most directory relative to models_eynollah + while src.parent.name != 'models_eynollah': + src = src.parent + for dist in spec.dists: + dist_dir = Path(f"{output_dir}/models_{dist}_{version}/models_eynollah") + copies.add((src, dist_dir)) + mkdirs.add(dist_dir) + for dir in mkdirs: + print(f"mkdir -p {dir}") + for (src, dst) in copies: + print(f"cp -r {src} {dst}") + for dir in mkdirs: + zip_path = Path(f'../{dir.parent.name}.zip') + print(f"(cd {dir}/..; zip -r {zip_path} models_eynollah)") diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 7cfaa3a..8948a1f 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -166,7 +166,7 @@ class EynollahModelZoo: else f'No, download {spec.dist_url}', # self.model_path(spec.category, spec.variant), ] - for spec in sorted(self.specs.specs, key=lambda x: x.category + '0' + x.variant) + for spec in self.specs.specs ], headers=[ 'Type', diff --git a/src/eynollah/model_zoo/specs.py b/src/eynollah/model_zoo/specs.py index 4f8cffa..322afa4 100644 --- a/src/eynollah/model_zoo/specs.py +++ b/src/eynollah/model_zoo/specs.py @@ -26,7 +26,7 @@ class EynollahModelSpecSet(): specs: List[EynollahModelSpec] def __init__(self, specs: List[EynollahModelSpec]) -> None: - self.specs = specs + self.specs = sorted(specs, key=lambda x: x.category + '0' + x.variant) self.categories: Set[str] = set([spec.category for spec in self.specs]) self.variants: Dict[str, Set[str]] = { spec.category: set([x.variant for x in self.specs if x.category == spec.category])