mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-06-16 09:59:13 +02:00
torch model ensembling is integrated
This commit is contained in:
parent
aba0138216
commit
4776ea9fc4
2 changed files with 53 additions and 22 deletions
|
|
@ -825,7 +825,7 @@ def run(_config,
|
||||||
usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
|
usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
|
||||||
for epoch in usable_checkpoints]
|
for epoch in usable_checkpoints]
|
||||||
ens_path = os.path.join(dir_output, 'model_ens_avg')
|
ens_path = os.path.join(dir_output, 'model_ens_avg')
|
||||||
run_ensembling(usable_checkpoints, ens_path)
|
run_ensembling(usable_checkpoints, ens_path, framework='tensorflow')
|
||||||
_log.info("ensemble model saved under '%s'", ens_path)
|
_log.info("ensemble model saved under '%s'", ens_path)
|
||||||
|
|
||||||
elif task=='reading_order':
|
elif task=='reading_order':
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
from typing import Optional
|
||||||
from warnings import catch_warnings, simplefilter
|
from warnings import catch_warnings, simplefilter
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
@ -11,33 +12,56 @@ from ocrd_utils import tf_disable_interactive_logs
|
||||||
tf_disable_interactive_logs()
|
tf_disable_interactive_logs()
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
|
import torch
|
||||||
|
from transformers import VisionEncoderDecoderModel
|
||||||
|
|
||||||
from ..patch_encoder import (
|
from ..patch_encoder import (
|
||||||
PatchEncoder,
|
PatchEncoder,
|
||||||
Patches,
|
Patches,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_ensembling(model_dirs, out_dir):
|
def run_ensembling(dir_models, out, framework):
|
||||||
all_weights = []
|
ls_models = os.listdir(dir_models)
|
||||||
|
# model: Optional[VisionEncoderDecoderModel] = None
|
||||||
for model_dir in model_dirs:
|
# model_name: Optional[str] = None
|
||||||
assert os.path.isdir(model_dir), model_dir
|
if framework=="torch":
|
||||||
model = load_model(model_dir, compile=False,
|
models = []
|
||||||
custom_objects=dict(PatchEncoder=PatchEncoder,
|
sd_models = []
|
||||||
Patches=Patches))
|
|
||||||
all_weights.append(model.get_weights())
|
|
||||||
|
|
||||||
new_weights = []
|
for model_name in ls_models:
|
||||||
for layer_weights in zip(*all_weights):
|
model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models, model_name))
|
||||||
layer_weights = np.array([np.array(weights).mean(axis=0)
|
models.append(model)
|
||||||
for weights in zip(*layer_weights)])
|
sd_models.append(model.state_dict())
|
||||||
new_weights.append(layer_weights)
|
for key in sd_models[0]:
|
||||||
|
sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
|
||||||
|
|
||||||
|
model.load_state_dict(sd_models[0])
|
||||||
|
os.system("mkdir "+out)
|
||||||
|
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
|
||||||
|
os.system('cp ' + os.path.join(os.path.join(dir_models, model_name), "config.json") + " " + out)
|
||||||
|
|
||||||
|
else:
|
||||||
|
weights=[]
|
||||||
|
|
||||||
#model = tf.keras.models.clone_model(model)
|
for model_name in ls_models:
|
||||||
model.set_weights(new_weights)
|
model = load_model(os.path.join(dir_models, model_name), compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
|
||||||
|
weights.append(model.get_weights())
|
||||||
|
|
||||||
|
new_weights = list()
|
||||||
|
|
||||||
model.save(out_dir)
|
for weights_list_tuple in zip(*weights):
|
||||||
os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/")
|
new_weights.append(
|
||||||
|
[np.array(weights_).mean(axis=0)\
|
||||||
|
for weights_ in zip(*weights_list_tuple)])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
new_weights = [np.array(x) for x in new_weights]
|
||||||
|
|
||||||
|
model.set_weights(new_weights)
|
||||||
|
model.save(out)
|
||||||
|
os.system('cp '+os.path.join(os.path.join(dir_models, model_name), "config.json") + " " + out)
|
||||||
|
os.system('cp '+os.path.join(os.path.join(dir_models, model_name), "characters_org.txt") + " " + out)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
|
|
@ -56,12 +80,19 @@ def run_ensembling(model_dirs, out_dir):
|
||||||
required=True,
|
required=True,
|
||||||
type=click.Path(exists=False, file_okay=False),
|
type=click.Path(exists=False, file_okay=False),
|
||||||
)
|
)
|
||||||
def ensemble_cli(in_, out):
|
@click.option(
|
||||||
|
"--framework",
|
||||||
|
"-fw",
|
||||||
|
help="this parameter gets tensorflow or torch as model framework",
|
||||||
|
type=click.Choice(['torch', 'tensorflow']),
|
||||||
|
default="tensorflow"
|
||||||
|
)
|
||||||
|
|
||||||
|
def ensemble_cli(in_, out, framework):
|
||||||
"""
|
"""
|
||||||
mix multiple model weights
|
mix multiple model weights
|
||||||
|
|
||||||
Load a sequence of models and mix them into a single ensemble model
|
Load a sequence of models and mix them into a single ensemble model
|
||||||
by averaging their weights. Write the resulting model.
|
by averaging their weights. Write the resulting model.
|
||||||
"""
|
"""
|
||||||
run_ensembling(in_, out)
|
run_ensembling(in_, out, framework)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue