mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
torch model ensembling is integrated
This commit is contained in:
parent
498ff8f7a5
commit
fbf252db13
1 changed file with 46 additions and 21 deletions
|
|
@@ -21,6 +21,11 @@ from tensorflow.keras.layers import *
|
||||||
import click
|
import click
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from transformers import TrOCRProcessor
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from transformers import VisionEncoderDecoderModel
|
||||||
|
|
||||||
|
|
||||||
class Patches(layers.Layer):
|
class Patches(layers.Layer):
|
||||||
def __init__(self, patch_size_x, patch_size_y):
|
def __init__(self, patch_size_x, patch_size_y):
|
||||||
|
|
@@ -92,30 +97,45 @@ def start_new_session():
|
||||||
tensorflow_backend.set_session(session)
|
tensorflow_backend.set_session(session)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def run_ensembling(dir_models, out):
|
def run_ensembling(dir_models, out, framework):
|
||||||
ls_models = os.listdir(dir_models)
|
ls_models = os.listdir(dir_models)
|
||||||
|
if framework=="torch":
|
||||||
|
models = []
|
||||||
weights=[]
|
sd_models = []
|
||||||
|
|
||||||
for model_name in ls_models:
|
|
||||||
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
|
|
||||||
weights.append(model.get_weights())
|
|
||||||
|
|
||||||
new_weights = list()
|
for model_name in ls_models:
|
||||||
|
model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name))
|
||||||
for weights_list_tuple in zip(*weights):
|
models.append(model)
|
||||||
new_weights.append(
|
sd_models.append(model.state_dict())
|
||||||
[np.array(weights_).mean(axis=0)\
|
for key in sd_models[0]:
|
||||||
for weights_ in zip(*weights_list_tuple)])
|
sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
|
||||||
|
|
||||||
|
model.load_state_dict(sd_models[0])
|
||||||
|
os.system("mkdir "+out)
|
||||||
|
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
|
||||||
|
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
||||||
|
|
||||||
|
else:
|
||||||
|
weights=[]
|
||||||
|
|
||||||
|
for model_name in ls_models:
|
||||||
|
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
|
||||||
|
weights.append(model.get_weights())
|
||||||
|
|
||||||
|
new_weights = list()
|
||||||
|
|
||||||
|
for weights_list_tuple in zip(*weights):
|
||||||
|
new_weights.append(
|
||||||
|
[np.array(weights_).mean(axis=0)\
|
||||||
|
for weights_ in zip(*weights_list_tuple)])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
new_weights = [np.array(x) for x in new_weights]
|
new_weights = [np.array(x) for x in new_weights]
|
||||||
|
|
||||||
model.set_weights(new_weights)
|
model.set_weights(new_weights)
|
||||||
model.save(out)
|
model.save(out)
|
||||||
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
|
|
@@ -130,7 +150,12 @@ def run_ensembling(dir_models, out):
|
||||||
help="output directory where ensembled model will be written.",
|
help="output directory where ensembled model will be written.",
|
||||||
type=click.Path(exists=False, file_okay=False),
|
type=click.Path(exists=False, file_okay=False),
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--framework",
|
||||||
|
"-fw",
|
||||||
|
help="this parameter gets tensorflow or torch as model framework",
|
||||||
|
)
|
||||||
|
|
||||||
def main(dir_models, out):
|
def main(dir_models, out, framework):
|
||||||
run_ensembling(dir_models, out)
|
run_ensembling(dir_models, out, framework)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue