eynollah models package

This commit is contained in:
kba 2025-10-22 16:38:05 +02:00
parent 04bc4a63d0
commit 883546a6b8
3 changed files with 64 additions and 20 deletions

View file

@ -1,21 +1,19 @@
from dataclasses import dataclass
from typing import List, Tuple
from pathlib import Path
from typing import List, Set, Tuple
import click
from eynollah.model_zoo.default_specs import MODELS_VERSION
from .model_zoo import EynollahModelZoo
# Context object that the `models_cli` group callback stores on click's `ctx.obj`.
# NOTE(review): this span comes from a rendered diff without +/- markers, so two
# class headers appear back to back. Presumably `class EynollahCliCtx():` with
# model_basedir/model_overrides is the removed declaration and
# `class EynollahCliCtx:` with model_zoo is the added one - confirm against the commit.
@dataclass()
class EynollahCliCtx():
model_basedir: str
model_overrides: List[Tuple[str, str, str]]
class EynollahCliCtx:
# Zoo instance that the subcommands read via `ctx.obj.model_zoo`.
model_zoo: EynollahModelZoo
@click.group()
def models_cli():
"""
Organize models for the various runners in eynollah.
"""
@models_cli.command('list')
@click.pass_context
@click.option(
"--model",
"-m",
@ -32,18 +30,64 @@ def models_cli():
type=(str, str, str),
multiple=True,
)
@click.pass_context
def list_models(
def models_cli(
ctx,
model_basedir: str,
model_overrides: List[Tuple[str, str, str]],
):
"""
List all the models in the zoo
Organize models for the various runners in eynollah.
"""
ctx.obj = EynollahCliCtx(
model_basedir=model_basedir,
model_overrides=model_overrides
)
print(EynollahModelZoo(basedir=ctx.obj.model_basedir, model_overrides=ctx.obj.model_overrides))
ctx.obj = EynollahCliCtx(model_zoo=EynollahModelZoo(basedir=model_basedir, model_overrides=model_overrides))
# Subcommand `list`: prints the model zoo that the `models_cli` group callback
# placed on the click context. The docstring below doubles as the `--help` text,
# so its wording is left untouched.
@models_cli.command('list')
@click.pass_context
def list_models(ctx):
    """
    List all the models in the zoo
    """
    # print() renders the zoo through its string conversion.
    zoo = ctx.obj.model_zoo
    print(zoo)
def _topmost_model_dir(src: Path, root_name: str = 'models_eynollah') -> Path:
    """
    Return the ancestor of SRC (or SRC itself) whose parent directory is named ROOT_NAME.

    Raises:
        click.ClickException: if no such ancestor exists. Without this guard the
            walk would spin forever, because the parent of a filesystem root
            (or of ".") is the path itself.
    """
    top = src
    while top.parent.name != root_name:
        if top.parent == top:
            raise click.ClickException(
                f"Model path {src} is not located below a '{root_name}' directory"
            )
        top = top.parent
    return top


# Subcommand `package`: emits (does not execute) a shell script on stdout that
#   1. creates one `models_<dist>_<version>/models_eynollah` folder per distribution,
#   2. copies every model's top-level directory into the folders of its distributions,
#   3. zips each distribution folder next to it.
# The docstring below doubles as the `--help` text, so its wording is left untouched.
@models_cli.command('package')
@click.option(
    '--set-version', '-V', 'version', help="Version to use for packaging", default=MODELS_VERSION, show_default=True
)
@click.argument('output_dir')
@click.pass_context
def package(
    ctx,
    version,
    output_dir,
):
    """
    Generate shell code to copy all the models in the zoo into properly named folders in OUTPUT_DIR for distribution.
    eynollah models -m SRC package OUTPUT_DIR
    SRC should contain a directory "models_eynollah" containing all the models.
    """
    # Local import keeps this block self-contained; quoting protects the generated
    # script against spaces and shell metacharacters in user-supplied paths.
    import shlex

    zoo = ctx.obj.model_zoo
    mkdirs: Set[Path] = set()
    copies: Set[Tuple[Path, Path]] = set()
    for spec in zoo.specs.specs:
        # skip these as they are dependent on the ocr model
        if spec.category in ('num_to_char', 'characters'):
            continue
        # Only copy the top-most directory relative to models_eynollah
        src = _topmost_model_dir(zoo.model_path(spec.category, spec.variant))
        for dist in spec.dists:
            dist_dir = Path(f"{output_dir}/models_{dist}_{version}/models_eynollah")
            # Sets deduplicate: several specs may share a source directory and a target.
            copies.add((src, dist_dir))
            mkdirs.add(dist_dir)

    # Sorted iteration makes the emitted script reproducible across runs
    # (plain set order depends on per-process hash randomisation).
    targets = sorted(mkdirs)
    for dist_dir in targets:
        print(f"mkdir -p {shlex.quote(str(dist_dir))}")
    for src, dst in sorted(copies):
        print(f"cp -r {shlex.quote(str(src))} {shlex.quote(str(dst))}")
    for dist_dir in targets:
        # The archive is written two levels above models_eynollah, i.e. into OUTPUT_DIR,
        # and named after the distribution folder.
        zip_path = Path(f'../{dist_dir.parent.name}.zip')
        workdir = shlex.quote(f"{dist_dir}/..")
        print(f"(cd {workdir}; zip -r {shlex.quote(str(zip_path))} models_eynollah)")

View file

@ -166,7 +166,7 @@ class EynollahModelZoo:
else f'No, download {spec.dist_url}',
# self.model_path(spec.category, spec.variant),
]
for spec in sorted(self.specs.specs, key=lambda x: x.category + '0' + x.variant)
for spec in self.specs.specs
],
headers=[
'Type',

View file

@ -26,7 +26,7 @@ class EynollahModelSpecSet():
specs: List[EynollahModelSpec]
def __init__(self, specs: List[EynollahModelSpec]) -> None:
self.specs = specs
self.specs = sorted(specs, key=lambda x: x.category + '0' + x.variant)
self.categories: Set[str] = set([spec.category for spec in self.specs])
self.variants: Dict[str, Set[str]] = {
spec.category: set([x.variant for x in self.specs if x.category == spec.category])