training: add CLI command convert

- move `train_cli` from cli.py to train.py,
  add docstring
- add `convert_cli`:
  - load any (supported) model format
    (i.e. not exported TF-Serving or ONNX)
  - if SavedModel format with `config.json` present,
    and `--rebuild` is requested, create new model
    from `models.get_model()` for this configuration,
    and load weights
  - if model type is `cnn-rnn-ocr` and configuration
    is still for training (`ctc_loss`), then extract
    inference model
  - apply requested `--format` conversion:
    HDF5, Keras native, Keras SavedModel, TF-Serving SavedModel
    or ONNX
  - if output format is directory (i.e. SavedModel),
    then copy over `config.json`, too
- reload-models-v0.8.mk:
  - adapt recipe for converter CLI (i.e. `--format tf-serving`
    w/ `--rebuild` if possible)
  - add targets for other useful data formats
  - extend list of model names to all current models
    (as all benefit from TF-Serving export)
  - skip ONNX conversion for vision transformer models
    (as these do not work yet)
This commit is contained in:
Robert Sachunsky 2026-05-28 17:48:21 +02:00
parent 62b55a3809
commit f833a516e7
5 changed files with 187 additions and 38 deletions

View file

@ -7,17 +7,11 @@ import sys
from .build_model_load_pretrained_weights_and_save import build_model_load_pretrained_weights_and_save from .build_model_load_pretrained_weights_and_save import build_model_load_pretrained_weights_and_save
from .generate_gt_for_training import main as generate_gt_cli from .generate_gt_for_training import main as generate_gt_cli
from .inference import main as inference_cli from .inference import main as inference_cli
from .train import ex from .train import train_cli
from .convert import convert_cli
from .extract_line_gt import linegt_cli from .extract_line_gt import linegt_cli
from .weights_ensembling import ensemble_cli from .weights_ensembling import ensemble_cli
@click.command(context_settings={'ignore_unknown_options': True})
@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED)
def train_cli(sacred_args):
    # Click leaves all arguments unprocessed; they are handed to Sacred's
    # own command-line parser, with the program name in front as argv[0].
    argv = [sys.argv[0], *sacred_args]
    ex.run_commandline(argv)
@click.group('training') @click.group('training')
def main(): def main():
pass pass
@ -26,5 +20,6 @@ main.add_command(build_model_load_pretrained_weights_and_save)
main.add_command(generate_gt_cli, 'generate-gt') main.add_command(generate_gt_cli, 'generate-gt')
main.add_command(inference_cli, 'inference') main.add_command(inference_cli, 'inference')
main.add_command(train_cli, 'train') main.add_command(train_cli, 'train')
main.add_command(convert_cli, 'convert')
main.add_command(linegt_cli, 'export_textline_images_and_text') main.add_command(linegt_cli, 'export_textline_images_and_text')
main.add_command(ensemble_cli, 'ensembling') main.add_command(ensemble_cli, 'ensembling')

View file

@ -0,0 +1,107 @@
import os
from pathlib import Path
from shutil import copy2
import logging
import click
@click.command(context_settings=dict(
    help_option_names=['-h', '--help'],
    show_default=True))
@click.option(
    "--rebuild",
    "-r",
    help="build new model from code and then load existing weights (requires input in SavedModel directory format with config.json present)",
    is_flag=True
)
@click.option(
    "--format",
    "-f",
    "format_",
    help="data format to convert to",
    type=click.Choice(["hdf5", "keras", "tf", "tf-serving", "onnx"]),
    default="tf"
)
@click.option(
    "--in",
    "-i",
    "in_",
    help="path to input model (file in hdf5 / keras format, or directory in tf format)",
    required=True,
    type=click.Path(exists=True, dir_okay=True)
)
@click.option(
    "--out",
    "-o",
    help="path to output model (file in hdf5 / keras / onnx format, or directory in tf / tf-serving format)",
    required=True,
    type=click.Path(exists=False, dir_okay=True)
)
def convert_cli(rebuild, format_, in_, out):
    """
    convert models for inference

    Load model from path, optionally by rebuilding, convert to output format and write model to path.
    """
    # Parameters (as bound by the click options above):
    #   rebuild -- if set, build a fresh model via `models.get_model()` from the
    #              input's config.json and load the stored weights into it
    #   format_ -- one of "hdf5", "keras", "tf", "tf-serving", "onnx"
    #   in_     -- existing input model path (file or directory)
    #   out     -- output model path (file or directory, depending on format_)
    # Raises click.UsageError if the input does not meet the requirements of
    # the requested operation.
    model_path = Path(in_)
    config_path = model_path / "config.json"

    # Validate the input up front (before the slow TensorFlow import), with
    # explicit exceptions: `assert` would be stripped when running under -O.
    if model_path.is_dir() and not (model_path / "keras_metadata.pb").exists():
        raise click.UsageError(
            "input directory must be Keras model in SavedModel format")
    if rebuild and not config_path.exists():
        raise click.UsageError(
            "rebuilding requires input model in SavedModel format with config.json")

    # must happen before TensorFlow gets imported
    os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
    from ocrd_utils import tf_disable_interactive_logs
    tf_disable_interactive_logs()
    import tensorflow as tf
    from tensorflow.keras.models import load_model
    from tensorflow.keras.models import Model as KerasModel

    if rebuild:
        from .train import ex
        from .models import get_model
        # merge defaults with existing config file
        ex.add_config(str(config_path))
        # some models deviate between training and inference
        ex.add_config(inference=True)
        # just retrieve final config (via pseudo-run)
        ex.main(lambda: 0)
        config = ex.run(options={'--loglevel': 'ERROR'}).config
        # use the config to capture the model builder
        model = get_model(config, logging.root)
        model.load_weights(model_path).assert_existing_objects_matched().expect_partial()
    else:
        model = load_model(model_path, compile=False)

    if isinstance(model, KerasModel):
        # cnn-rnn-ocr task deviates between training and inference:
        # if the training-only `ctc_loss` layer is present, cut out the
        # sub-model from the `image` input to the `dense2` output
        try:
            model.get_layer(name='ctc_loss')
        except ValueError:
            # no such layer, so nothing to strip
            pass
        else:
            model = KerasModel(
                model.get_layer(name='image').input,
                model.get_layer(name='dense2').output)

    if format_ in ("hdf5", "keras", "tf"):
        # Keras calls the HDF5 format "h5"; the other names pass through
        save_kwargs = {"save_format": "h5" if format_ == "hdf5" else format_}
        if format_ != "keras":
            # NOTE(review): presumably the native Keras format does not
            # accept `include_optimizer` -- confirm against the Keras docs
            save_kwargs["include_optimizer"] = False
        model.save(out, **save_kwargs)
    elif format_ == "tf-serving":
        model.export(out)
    elif format_ == "onnx":
        import tf2onnx
        tf2onnx.convert.from_keras(model, opset=18, output_path=out)
    else:
        raise ValueError("unknown output format '%s'" % format_)

    # copy config.json if possible (only directory outputs can hold it)
    if config_path.exists() and format_ in ('tf', 'tf-serving'):
        copy2(config_path, Path(out) / config_path.name)

View file

@ -4,39 +4,65 @@ MODELS_SRC = models_eynollah
MODELS_DST = reloaded/models_eynollah MODELS_DST = reloaded/models_eynollah
# $(MODELS_DST)/eynollah-binarization_20210425 \ # eynollah-main-regions-aug-rotation_20210425
# $(MODELS_DST)/eynollah-column-classifier_20210425 \ # eynollah-main-regions-aug-scaling_20210425
# $(MODELS_DST)/eynollah-enhancement_20210425 \ # eynollah-main-regions-ensembled_20210425
# $(MODELS_DST)/eynollah-main-regions-aug-rotation_20210425 \ # eynollah-main-regions_20220314
# $(MODELS_DST)/eynollah-main-regions-aug-scaling_20210425 \ # eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18
# $(MODELS_DST)/eynollah-main-regions-ensembled_20210425 \ # eynollah-tables_20210319
# $(MODELS_DST)/eynollah-main-regions_20220314 \
# $(MODELS_DST)/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18 \
# $(MODELS_DST)/eynollah-tables_20210319 \
# $(MODELS_DST)/model_eynollah_ocr_cnnrnn_20250930 \
RELOADABLE_MODELS = \ CURRENT_MODELS :=
$(MODELS_DST)/model_eynollah_page_extraction_20250915 \ CURRENT_MODELS += model_eynollah_page_extraction_20250915
$(MODELS_DST)/model_eynollah_reading_order_20250824 \ CURRENT_MODELS += model_eynollah_reading_order_20250824
$(MODELS_DST)/modelens_e_l_all_sp_0_1_2_3_4_171024 \ CURRENT_MODELS += modelens_e_l_all_sp_0_1_2_3_4_171024
$(MODELS_DST)/modelens_full_lay_1__4_3_091124 \ CURRENT_MODELS += modelens_full_lay_1__4_3_091124
$(MODELS_DST)/modelens_table_0t4_201124 \ CURRENT_MODELS += modelens_table_0t4_201124
$(MODELS_DST)/modelens_textline_0_1__2_4_16092024 CURRENT_MODELS += modelens_textline_0_1__2_4_16092024
CURRENT_MODELS += model_eynollah_ocr_cnnrnn_20250930
CURRENT_MODELS += eynollah-binarization_20210425
CURRENT_MODELS += eynollah-column-classifier_20210425
CURRENT_MODELS += eynollah-enhancement_20210425
all: $(RELOADABLE_MODELS) all: tf-serving
tf-serving: $(CURRENT_MODELS:%=$(MODELS_DST)/%)
keras: $(CURRENT_MODELS:%=$(MODELS_DST)/%.keras)
hdf5: $(CURRENT_MODELS:%=$(MODELS_DST)/%.h5)
onnx: $(CURRENT_MODELS:%=$(MODELS_DST)/%.onnx)
$(MODELS_DST)/%: $(MODELS_SRC)/% $(MODELS_DST)/%: $(MODELS_SRC)/%
test -e $</config.json || exit 1 eynollah-training convert \
{ mkdir -p $@ \ $(and $(wildcard $</config.json),--rebuild) \
&& eynollah-training train --force \ --in $< \
with $</config.json \ --format tf-serving \
reload_weights=True \ --out $@ \
continue_training=False \ 2>&1 | tee $(notdir $<).tf-serving.log
dir_output=$(dir $@) \
dir_of_start_model=$< \ $(MODELS_DST)/%.keras: $(MODELS_SRC)/%
&& cp $</config.json $@/config.json \ eynollah-training convert \
|| { rm -rf $@; false; }; } \ $(and $(wildcard $</config.json),--rebuild) \
2>&1 | tee $(notdir $<).log --in $< \
--format keras \
--out $@ \
2>&1 | tee $(notdir $<).keras.log
# Convert one source model directory into a single HDF5 file.
# `--rebuild` is passed only if the source directory contains config.json.
# Converter output (stdout and stderr) is shown and also written to
# <model name>.hdf5.log in the current directory.
# NOTE(review): the pipeline's exit status is that of `tee`, so a failing
# conversion presumably does not fail the target -- confirm.
$(MODELS_DST)/%.h5: $(MODELS_SRC)/%
	eynollah-training convert \
	$(and $(wildcard $</config.json),--rebuild) \
	--in $< \
	--format hdf5 \
	--out $@ \
	2>&1 | tee $(notdir $<).hdf5.log
# Convert one source model directory into a single ONNX file.
# Models whose config.json declares a segmentation task with a transformer
# backbone are skipped (with a message) instead of converted.
# `--rebuild` is passed only if the source directory contains config.json.
# The jq output is silenced with POSIX redirection (`>/dev/null 2>&1`):
# the bash-only `&>` would, under a plain /bin/sh, run jq in the background
# and make the condition always true, i.e. skip every model.
$(MODELS_DST)/%.onnx: $(MODELS_SRC)/%
	if jq -e '.task == "segmentation" and .backbone_type == "transformer"' $</config.json >/dev/null 2>&1; then \
	echo skipping $@: vision transformer architecture currently does not work with ONNX; else \
	eynollah-training convert \
	$(and $(wildcard $</config.json),--rebuild) \
	--in $< \
	--format onnx \
	--out $@ \
	2>&1 | tee $(notdir $<).onnx.log; fi
compare: compare:
for i in `find $(MODELS_DST) -mindepth 2`;do \ for i in `find $(MODELS_DST) -mindepth 2`;do \
@ -44,6 +70,5 @@ compare:
du -bs $$n $$i ; \ du -bs $$n $$i ; \
done done
clear: clear:
rm -rf $(MODELS_DST) rm -rf $(MODELS_DST)

View file

@ -2,6 +2,7 @@ import os
import sys import sys
import io import io
import json import json
import click
from tqdm import tqdm from tqdm import tqdm
import requests import requests
@ -791,3 +792,23 @@ def run(_config,
model_dir = os.path.join(dir_out,'model_best') model_dir = os.path.join(dir_out,'model_best')
model.save(model_dir) model.save(model_dir)
''' '''
@click.command(context_settings={'ignore_unknown_options': True})
@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED)
def train_cli(sacred_args):
    """
    train model on extracted GT

    SACRED_ARGS as per CLI interface of Sacred, cf.
    https://sacred.readthedocs.io/en/stable/command_line.html:

    \b
    To configure the learning task, pass the string `with`,
    followed by any number of
    - config JSON file paths
    - parameter overrides in the form of key=value
    (where the later settings will override the former).
    """
    # Click leaves all arguments unprocessed; they are handed to Sacred's
    # own command-line parser, with the program name in front as argv[0].
    argv = [sys.argv[0], *sacred_args]
    ex.run_commandline(argv)

View file

@ -43,6 +43,7 @@ def run_ensembling(model_dirs, out_dir):
@click.option( @click.option(
"--in", "--in",
"-i", "-i",
"in_",
help="input directory of checkpoint models to be read", help="input directory of checkpoint models to be read",
multiple=True, multiple=True,
required=True, required=True,