diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index da2cbdb..93b1588 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -825,7 +825,7 @@ def run(_config, usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1)) for epoch in usable_checkpoints] ens_path = os.path.join(dir_output, 'model_ens_avg') - run_ensembling(usable_checkpoints, ens_path) + run_ensembling(usable_checkpoints, ens_path, framework='tensorflow') _log.info("ensemble model saved under '%s'", ens_path) elif task=='reading_order': diff --git a/src/eynollah/training/weights_ensembling.py b/src/eynollah/training/weights_ensembling.py index f651c56..1c175c0 100644 --- a/src/eynollah/training/weights_ensembling.py +++ b/src/eynollah/training/weights_ensembling.py @@ -1,4 +1,5 @@ import os +from typing import Optional from warnings import catch_warnings, simplefilter import click @@ -11,33 +12,56 @@ from ocrd_utils import tf_disable_interactive_logs tf_disable_interactive_logs() import tensorflow as tf from tensorflow.keras.models import load_model +import torch +from transformers import VisionEncoderDecoderModel from ..patch_encoder import ( PatchEncoder, Patches, ) -def run_ensembling(model_dirs, out_dir): - all_weights = [] - - for model_dir in model_dirs: - assert os.path.isdir(model_dir), model_dir - model = load_model(model_dir, compile=False, - custom_objects=dict(PatchEncoder=PatchEncoder, - Patches=Patches)) - all_weights.append(model.get_weights()) +def run_ensembling(dir_models, out, framework): + ls_models = os.listdir(dir_models) + # model: Optional[VisionEncoderDecoderModel] = None + # model_name: Optional[str] = None + if framework=="torch": + models = [] + sd_models = [] - new_weights = [] - for layer_weights in zip(*all_weights): - layer_weights = np.array([np.array(weights).mean(axis=0) - for weights in zip(*layer_weights)]) - new_weights.append(layer_weights) + for model_name in ls_models: + model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models, model_name)) + models.append(model) + sd_models.append(model.state_dict()) + for key in sd_models[0]: + sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models) + + model.load_state_dict(sd_models[0]) + os.system("mkdir "+out) + torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin")) + os.system('cp ' + os.path.join(os.path.join(dir_models, model_name), "config.json") + " " + out) + + else: + weights=[] - #model = tf.keras.models.clone_model(model) - model.set_weights(new_weights) + for model_name in ls_models: + model = load_model(os.path.join(dir_models, model_name), compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches}) + weights.append(model.get_weights()) + + new_weights = list() - model.save(out_dir) - os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/") + for weights_list_tuple in zip(*weights): + new_weights.append( + [np.array(weights_).mean(axis=0)\ + for weights_ in zip(*weights_list_tuple)]) + + + + new_weights = [np.array(x) for x in new_weights] + + model.set_weights(new_weights) + model.save(out) + os.system('cp '+os.path.join(os.path.join(dir_models, model_name), "config.json") + " " + out) + os.system('cp '+os.path.join(os.path.join(dir_models, model_name), "characters_org.txt") + " " + out) @click.command() @click.option( @@ -56,12 +80,19 @@ def run_ensembling(model_dirs, out_dir): required=True, type=click.Path(exists=False, file_okay=False), ) -def ensemble_cli(in_, out): +@click.option( + "--framework", + "-fw", + help="this parameter gets tensorflow or torch as model framework", + type=click.Choice(['torch', 'tensorflow']), + default="tensorflow" +) + +def ensemble_cli(in_, out, framework): """ mix multiple model weights Load a sequence of models and mix them into a single ensemble model by averaging their weights. Write the resulting model. """ - run_ensembling(in_, out) - + run_ensembling(in_, out, framework)