torch model ensembling is integrated

This commit is contained in:
vahidrezanezhad 2026-02-04 21:16:08 +01:00 committed by kba
parent aba0138216
commit 4776ea9fc4
2 changed files with 53 additions and 22 deletions

View file

@ -825,7 +825,7 @@ def run(_config,
usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1)) usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
for epoch in usable_checkpoints] for epoch in usable_checkpoints]
ens_path = os.path.join(dir_output, 'model_ens_avg') ens_path = os.path.join(dir_output, 'model_ens_avg')
run_ensembling(usable_checkpoints, ens_path) run_ensembling(usable_checkpoints, ens_path, framework='tensorflow')
_log.info("ensemble model saved under '%s'", ens_path) _log.info("ensemble model saved under '%s'", ens_path)
elif task=='reading_order': elif task=='reading_order':

View file

@ -1,4 +1,5 @@
import os import os
from typing import Optional
from warnings import catch_warnings, simplefilter from warnings import catch_warnings, simplefilter
import click import click
@ -11,33 +12,56 @@ from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs() tf_disable_interactive_logs()
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
import torch
from transformers import VisionEncoderDecoderModel
from ..patch_encoder import ( from ..patch_encoder import (
PatchEncoder, PatchEncoder,
Patches, Patches,
) )
def run_ensembling(model_dirs, out_dir): def run_ensembling(dir_models, out, framework):
all_weights = [] ls_models = os.listdir(dir_models)
# model: Optional[VisionEncoderDecoderModel] = None
# model_name: Optional[str] = None
if framework=="torch":
models = []
sd_models = []
for model_dir in model_dirs: for model_name in ls_models:
assert os.path.isdir(model_dir), model_dir model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models, model_name))
model = load_model(model_dir, compile=False, models.append(model)
custom_objects=dict(PatchEncoder=PatchEncoder, sd_models.append(model.state_dict())
Patches=Patches)) for key in sd_models[0]:
all_weights.append(model.get_weights()) sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
new_weights = [] model.load_state_dict(sd_models[0])
for layer_weights in zip(*all_weights): os.system("mkdir "+out)
layer_weights = np.array([np.array(weights).mean(axis=0) torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
for weights in zip(*layer_weights)]) os.system('cp ' + os.path.join(os.path.join(dir_models, model_name), "config.json") + " " + out)
new_weights.append(layer_weights)
else:
weights=[]
for model_name in ls_models:
model = load_model(os.path.join(dir_models, model_name), compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
weights.append(model.get_weights())
new_weights = list()
for weights_list_tuple in zip(*weights):
new_weights.append(
[np.array(weights_).mean(axis=0)\
for weights_ in zip(*weights_list_tuple)])
new_weights = [np.array(x) for x in new_weights]
#model = tf.keras.models.clone_model(model)
model.set_weights(new_weights) model.set_weights(new_weights)
model.save(out)
model.save(out_dir) os.system('cp '+os.path.join(os.path.join(dir_models, model_name), "config.json") + " " + out)
os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/") os.system('cp '+os.path.join(os.path.join(dir_models, model_name), "characters_org.txt") + " " + out)
@click.command() @click.command()
@click.option( @click.option(
@ -56,12 +80,19 @@ def run_ensembling(model_dirs, out_dir):
required=True, required=True,
type=click.Path(exists=False, file_okay=False), type=click.Path(exists=False, file_okay=False),
) )
def ensemble_cli(in_, out): @click.option(
"--framework",
"-fw",
help="this parameter gets tensorflow or torch as model framework",
type=click.Choice(['torch', 'tensorflow']),
default="tensorflow"
)
def ensemble_cli(in_, out, framework):
""" """
mix multiple model weights mix multiple model weights
Load a sequence of models and mix them into a single ensemble model Load a sequence of models and mix them into a single ensemble model
by averaging their weights. Write the resulting model. by averaging their weights. Write the resulting model.
""" """
run_ensembling(in_, out) run_ensembling(in_, out, framework)