torch model ensembling is integrated

This commit is contained in:
vahidrezanezhad 2026-02-04 21:16:08 +01:00
parent 498ff8f7a5
commit fbf252db13

View file

@ -21,6 +21,11 @@ from tensorflow.keras.layers import *
import click import click
import logging import logging
from transformers import TrOCRProcessor
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel
class Patches(layers.Layer): class Patches(layers.Layer):
def __init__(self, patch_size_x, patch_size_y): def __init__(self, patch_size_x, patch_size_y):
@ -92,30 +97,45 @@ def start_new_session():
tensorflow_backend.set_session(session) tensorflow_backend.set_session(session)
return session return session
def run_ensembling(dir_models, out): def run_ensembling(dir_models, out, framework):
ls_models = os.listdir(dir_models) ls_models = os.listdir(dir_models)
if framework=="torch":
models = []
weights=[] sd_models = []
for model_name in ls_models:
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
weights.append(model.get_weights())
new_weights = list() for model_name in ls_models:
model = VisionEncoderDecoderModel.from_pretrained(os.path.join(dir_models,model_name))
for weights_list_tuple in zip(*weights): models.append(model)
new_weights.append( sd_models.append(model.state_dict())
[np.array(weights_).mean(axis=0)\ for key in sd_models[0]:
for weights_ in zip(*weights_list_tuple)]) sd_models[0][key] = sum(sd[key] for sd in sd_models) / len(sd_models)
model.load_state_dict(sd_models[0])
os.system("mkdir "+out)
torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
else:
weights=[]
for model_name in ls_models:
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
weights.append(model.get_weights())
new_weights = list()
for weights_list_tuple in zip(*weights):
new_weights.append(
[np.array(weights_).mean(axis=0)\
for weights_ in zip(*weights_list_tuple)])
new_weights = [np.array(x) for x in new_weights] new_weights = [np.array(x) for x in new_weights]
model.set_weights(new_weights) model.set_weights(new_weights)
model.save(out) model.save(out)
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out) os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
@click.command() @click.command()
@click.option( @click.option(
@ -130,7 +150,12 @@ def run_ensembling(dir_models, out):
help="output directory where ensembled model will be written.", help="output directory where ensembled model will be written.",
type=click.Path(exists=False, file_okay=False), type=click.Path(exists=False, file_okay=False),
) )
@click.option(
"--framework",
"-fw",
help="this parameter gets tensorflow or torch as model framework",
)
def main(dir_models, out): def main(dir_models, out, framework):
run_ensembling(dir_models, out) run_ensembling(dir_models, out, framework)