From 3500167870fa7963e291857031bcab9df0c7fb5c Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Wed, 28 Jan 2026 11:52:12 +0100 Subject: [PATCH] weights ensembling for tensorflow models is integrated --- src/eynollah/training/cli.py | 2 + src/eynollah/training/weights_ensembling.py | 136 ++++++++++++++++++++ 2 files changed, 138 insertions(+) create mode 100644 src/eynollah/training/weights_ensembling.py diff --git a/src/eynollah/training/cli.py b/src/eynollah/training/cli.py index 65a7a8a..3718275 100644 --- a/src/eynollah/training/cli.py +++ b/src/eynollah/training/cli.py @@ -9,6 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli from .inference import main as inference_cli from .train import ex from .extract_line_gt import linegt_cli +from .weights_ensembling import main as ensemble_cli @click.command(context_settings=dict( ignore_unknown_options=True, @@ -26,3 +27,4 @@ main.add_command(generate_gt_cli, 'generate-gt') main.add_command(inference_cli, 'inference') main.add_command(train_cli, 'train') main.add_command(linegt_cli, 'export_textline_images_and_text') +main.add_command(ensemble_cli, 'ensembling') diff --git a/src/eynollah/training/weights_ensembling.py b/src/eynollah/training/weights_ensembling.py new file mode 100644 index 0000000..6dce7fd --- /dev/null +++ b/src/eynollah/training/weights_ensembling.py @@ -0,0 +1,136 @@ +import sys +from glob import glob +from os import environ, devnull +from os.path import join +from warnings import catch_warnings, simplefilter +import os + +import numpy as np +from PIL import Image +import cv2 +environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +stderr = sys.stderr +sys.stderr = open(devnull, 'w') +import tensorflow as tf +from tensorflow.keras.models import load_model +from tensorflow.python.keras import backend as tensorflow_backend +sys.stderr = stderr +from tensorflow.keras import layers +import tensorflow.keras.losses +from tensorflow.keras.layers import * +import click +import logging + + +class Patches(layers.Layer): + def __init__(self, patch_size_x, patch_size_y): + super(Patches, self).__init__() + self.patch_size_x = patch_size_x + self.patch_size_y = patch_size_y + + def call(self, images): + #print(tf.shape(images)[1],'images') + #print(self.patch_size,'self.patch_size') + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size_y, self.patch_size_x, 1], + strides=[1, self.patch_size_y, self.patch_size_x, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + #patch_dims = patches.shape[-1] + patch_dims = tf.shape(patches)[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'patch_size_x': self.patch_size_x, + 'patch_size_y': self.patch_size_y, + }) + return config + + + +class PatchEncoder(layers.Layer): + def __init__(self, **kwargs): + super(PatchEncoder, self).__init__() + self.num_patches = num_patches + self.projection = layers.Dense(units=projection_dim) + self.position_embedding = layers.Embedding( + input_dim=num_patches, output_dim=projection_dim + ) + + def call(self, patch): + positions = tf.range(start=0, limit=self.num_patches, delta=1) + encoded = self.projection(patch) + self.position_embedding(positions) + return encoded + def get_config(self): + + config = super().get_config().copy() + config.update({ + 'num_patches': self.num_patches, + 'projection': self.projection, + 'position_embedding': self.position_embedding, + }) + return config + + +def start_new_session(): + ###config = tf.compat.v1.ConfigProto() + ###config.gpu_options.allow_growth = True + + ###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() + ###tensorflow_backend.set_session(self.session) + + config = tf.compat.v1.ConfigProto() + config.gpu_options.allow_growth = True + + session = tf.compat.v1.Session(config=config) # tf.InteractiveSession() + tensorflow_backend.set_session(session) + return session + +def run_ensembling(dir_models, out): + ls_models = os.listdir(dir_models) + + + weights=[] + + for model_name in ls_models: + model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches}) + weights.append(model.get_weights()) + + new_weights = list() + + for weights_list_tuple in zip(*weights): + new_weights.append( + [np.array(weights_).mean(axis=0)\ + for weights_ in zip(*weights_list_tuple)]) + + + + new_weights = [np.array(x) for x in new_weights] + + model.set_weights(new_weights) + model.save(out) + os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out) + +@click.command() +@click.option( + "--dir_models", + "-dm", + help="directory of models", + type=click.Path(exists=True, file_okay=False), +) +@click.option( + "--out", + "-o", + help="output directory where ensembled model will be written.", + type=click.Path(exists=False, file_okay=False), +) + +def main(dir_models, out): + run_ensembling(dir_models, out) +