import os
from pathlib import Path
from shutil import copy2
import logging

import click


# convert_cli: load a trained model and write it back in another format.
#
# Parameters (all supplied by click from the command line):
#   rebuild  -- if set, construct the model from code via get_model() using the
#               Sacred config merged with the input's config.json, then load
#               the stored weights into it; otherwise load the model as saved
#   format_  -- target format, one of hdf5 / keras / tf / tf-serving / onnx
#   in_      -- path of the existing input model (file or directory)
#   out      -- path the converted model is written to
#
# Raises:
#   click.UsageError -- an input directory lacks keras_metadata.pb, or
#                       --rebuild was requested without a config.json
#   ValueError       -- format_ is none of the known formats (defensive;
#                       click.Choice already restricts the value)
#
# NOTE: the docstring below is what click prints as --help text, so it is
# kept short and user-facing.
@click.command(context_settings=dict(
    help_option_names=['-h', '--help'],
    show_default=True))
@click.option(
    "--rebuild",
    "-r",
    help="build new model from code and then load existing weights (requires input in SavedModel directory format with config.json present)",
    is_flag=True
)
@click.option(
    "--format",
    "-f",
    "format_",
    help="data format to convert to",
    type=click.Choice(["hdf5", "keras", "tf", "tf-serving", "onnx"]),
    default="tf"
)
@click.option(
    "--in",
    "-i",
    "in_",
    help="path to input model (file in hdf5 / keras format, or directory in tf format)",
    required=True,
    type=click.Path(exists=True, dir_okay=True)
)
@click.option(
    "--out",
    "-o",
    help="path to output model (file in hdf5 / keras / onnx format, or directory in tf / tf-serving format)",
    required=True,
    type=click.Path(exists=False, dir_okay=True)
)
def convert_cli(rebuild, format_, in_, out):
    """
    convert models for inference

    Load model from path, optionally by rebuilding, convert to output format and write model to path.
    """
    # must be set before TensorFlow/Keras gets imported below
    os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
    from ocrd_utils import tf_disable_interactive_logs
    tf_disable_interactive_logs()

    # heavy imports are deferred so that `--help` stays fast
    from tensorflow.keras.models import load_model
    from tensorflow.keras.models import Model as KerasModel

    model_path = Path(in_)
    config_path = model_path / "config.json"

    # Validate user input with real errors instead of `assert`: assertions
    # vanish under `python -O` and would otherwise surface as raw tracebacks.
    if model_path.is_dir() and not (model_path / "keras_metadata.pb").exists():
        raise click.UsageError(
            "input directory must be Keras model in SavedModel format")

    if rebuild:
        from .train import ex
        from .models import get_model

        if not config_path.exists():
            raise click.UsageError(
                "rebuilding requires input model in SavedModel format with config.json")

        # merge defaults with existing config file
        ex.add_config(str(config_path))
        # some models deviate between training and inference
        ex.add_config(inference=True)
        # just retrieve final config (via pseudo-run)
        ex.main(lambda: 0)
        config = ex.run(options={'--loglevel': 'ERROR'}).config
        # use the config to capture the model builder
        model = get_model(config, logging.root)
        model.load_weights(model_path).assert_existing_objects_matched().expect_partial()
    else:
        model = load_model(model_path, compile=False)

    if isinstance(model, KerasModel):
        # cnn-rnn-ocr task deviates between training and inference:
        # if a 'ctc_loss' layer is present, keep only the sub-graph from the
        # 'image' input to the 'dense2' output
        try:
            model.get_layer(name='ctc_loss')
        except ValueError:
            # no such layer -- model is used as loaded
            pass
        else:
            model = KerasModel(
                model.get_layer(name='image').input,
                model.get_layer(name='dense2').output)

    if format_ in ["hdf5", "keras", "tf"]:
        # Keras calls the HDF5 format "h5"; the other names pass through
        kwargs = {"save_format": {"hdf5": "h5"}.get(format_, format_)}
        if format_ != "keras":
            kwargs["include_optimizer"] = False
        model.save(out, **kwargs)
    elif format_ == "tf-serving":
        model.export(out)
    elif format_ == "onnx":
        import tf2onnx
        tf2onnx.convert.from_keras(model, opset=18, output_path=out)
    else:
        raise ValueError("unknown output format '%s'" % format_)

    # copy config.json if possible (only the directory-type outputs can hold it)
    if config_path.exists() and format_ in ['tf', 'tf-serving']:
        copy2(config_path, Path(out) / config_path.name)
$(MODELS_DST)/modelens_full_lay_1__4_3_091124 \ - $(MODELS_DST)/modelens_table_0t4_201124 \ - $(MODELS_DST)/modelens_textline_0_1__2_4_16092024 +CURRENT_MODELS := +CURRENT_MODELS += model_eynollah_page_extraction_20250915 +CURRENT_MODELS += model_eynollah_reading_order_20250824 +CURRENT_MODELS += modelens_e_l_all_sp_0_1_2_3_4_171024 +CURRENT_MODELS += modelens_full_lay_1__4_3_091124 +CURRENT_MODELS += modelens_table_0t4_201124 +CURRENT_MODELS += modelens_textline_0_1__2_4_16092024 +CURRENT_MODELS += model_eynollah_ocr_cnnrnn_20250930 +CURRENT_MODELS += eynollah-binarization_20210425 +CURRENT_MODELS += eynollah-column-classifier_20210425 +CURRENT_MODELS += eynollah-enhancement_20210425 -all: $(RELOADABLE_MODELS) +all: tf-serving + +tf-serving: $(CURRENT_MODELS:%=$(MODELS_DST)/%) +keras: $(CURRENT_MODELS:%=$(MODELS_DST)/%.keras) +hdf5: $(CURRENT_MODELS:%=$(MODELS_DST)/%.h5) +onnx: $(CURRENT_MODELS:%=$(MODELS_DST)/%.onnx) $(MODELS_DST)/%: $(MODELS_SRC)/% - test -e $&1 | tee $(notdir $<).log + eynollah-training convert \ + $(and $(wildcard $&1 | tee $(notdir $<).tf-serving.log + +$(MODELS_DST)/%.keras: $(MODELS_SRC)/% + eynollah-training convert \ + $(and $(wildcard $&1 | tee $(notdir $<).keras.log + +$(MODELS_DST)/%.h5: $(MODELS_SRC)/% + eynollah-training convert \ + $(and $(wildcard $&1 | tee $(notdir $<).hdf5.log + +$(MODELS_DST)/%.onnx: $(MODELS_SRC)/% + if jq -e '.task == "segmentation" and .backbone_type == "transformer"' $/dev/null; then \ + echo skipping $@: vision transformer architecture currently does not work with ONNX; else \ + eynollah-training convert \ + $(and $(wildcard $&1 | tee $(notdir $<).onnx.log; fi compare: for i in `find $(MODELS_DST) -mindepth 2`;do \ @@ -44,6 +70,5 @@ compare: du -bs $$n $$i ; \ done - clear: rm -rf $(MODELS_DST) diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index f4cf08b..62d8e51 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -2,6 +2,7 @@ import os import 
# CLI entry point for training: every argument is handed through untouched to
# the Sacred experiment `ex`, which does its own command-line parsing.
# The docstring is the text click prints for --help; the `\b` line in it is
# click's marker for "do not re-wrap the following paragraph".
@click.command(context_settings={"ignore_unknown_options": True})
@click.argument('SACRED_ARGS', nargs=-1, type=click.UNPROCESSED)
def train_cli(sacred_args):
    """
    train model on extracted GT

    SACRED_ARGS as per CLI interface of Sacred, cf.
    https://sacred.readthedocs.io/en/stable/command_line.html:

    \b
    To configure the learning task, pass the string `with`,
    followed by any number of
    - config JSON file paths
    - parameter overrides in the form of key=value
    (where the later settings will override the former).
    """
    # Sacred expects a full argv, so put the program name in front
    argv = [sys.argv[0], *sacred_args]
    ex.run_commandline(argv)