from transformers import TrOCRProcessor
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel


# NOTE(review): the Patches / PatchEncoder layers and start_new_session() live
# between these imports and run_ensembling(). The diff only shows truncated
# fragments of them as hunk context and does not modify them, so they are
# intentionally not reproduced here.


def _list_model_paths(dir_models):
    """Return the paths of all entries in *dir_models*, sorted by name.

    Sorting makes the run reproducible: the last path is the model whose
    architecture and ``config.json`` are reused for the ensembled output.

    Raises:
        ValueError: if *dir_models* contains no entries.
    """
    names = sorted(os.listdir(dir_models))
    if not names:
        raise ValueError(f"no models found in {dir_models!r}")
    return [os.path.join(dir_models, name) for name in names]


def _copy_config(model_path, out):
    """Copy ``config.json`` from *model_path* to *out*, if the model has one."""
    import shutil  # local import: the file's top-level import block is not visible here

    config_src = os.path.join(model_path, "config.json")
    # Best effort, like the former `cp` call: a model without a config.json
    # (e.g. a single-file checkpoint) is not an error.
    if os.path.isfile(config_src):
        shutil.copy(config_src, out)


def _ensemble_torch(model_paths, out):
    """Average the state dicts of Hugging Face VisionEncoderDecoder models.

    The result is written to ``<out>/pytorch_model.bin``.
    """
    model = None
    totals = None
    for path in model_paths:
        # Rebinding `model` lets the previous one be freed; only the running
        # sum and the current model are alive at any time.
        model = VisionEncoderDecoderModel.from_pretrained(path)
        state = model.state_dict()
        if totals is None:
            totals = {key: tensor.clone() for key, tensor in state.items()}
        else:
            # The key set of the first model is authoritative.
            for key in totals:
                totals[key] += state[key]

    count = len(model_paths)
    averaged = {key: total / count for key, total in totals.items()}

    model.load_state_dict(averaged)
    os.makedirs(out, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(out, "pytorch_model.bin"))
    _copy_config(model_paths[-1], out)


def _ensemble_tensorflow(model_paths, out):
    """Average the weights of Keras models and save the result to *out*."""
    model = None
    weights = []
    for path in model_paths:
        model = load_model(
            path,
            compile=False,
            custom_objects={'PatchEncoder': PatchEncoder, 'Patches': Patches},
        )
        weights.append(model.get_weights())

    # zip(*weights) groups the i-th weight array of every model together;
    # stacking and averaging over axis 0 also handles 0-d (scalar) weights.
    new_weights = [
        np.mean(np.stack(per_model), axis=0) for per_model in zip(*weights)
    ]

    model.set_weights(new_weights)
    model.save(out)
    _copy_config(model_paths[-1], out)


def run_ensembling(dir_models, out, framework="tensorflow"):
    """Ensemble all models found in *dir_models* by averaging their weights.

    Args:
        dir_models: directory whose entries are the individual models.
        out: output location of the ensembled model.
        framework: ``"torch"`` selects the Hugging Face/PyTorch path; any
            other value (including ``None``) selects the TensorFlow/Keras path.

    Raises:
        ValueError: if *dir_models* is empty.
    """
    model_paths = _list_model_paths(dir_models)
    if framework == "torch":
        _ensemble_torch(model_paths, out)
    else:
        _ensemble_tensorflow(model_paths, out)


# NOTE(review): the --dir_models option and the flag names of --out fall in an
# elided part of the diff; they are reconstructed here from the visible
# fragments and main()'s parameter names - confirm against the real file.
@click.command()
@click.option(
    "--dir_models",
    "-dm",
    help="directory of models",
    type=click.Path(exists=True, file_okay=False),
)
@click.option(
    "--out",
    "-o",
    help="output directory where ensembled model will be written.",
    type=click.Path(exists=False, file_okay=False),
)
@click.option(
    "--framework",
    "-fw",
    default="tensorflow",
    show_default=True,
    help="this parameter gets tensorflow or torch as model framework",
)
def main(dir_models, out, framework):
    """Average the weights of several trained models into a single model."""
    run_ensembling(dir_models, out, framework)