From bd282a594d7dac9adcbcce55b09fbd1e1a7f85a9 Mon Sep 17 00:00:00 2001
From: Robert Sachunsky
Date: Sat, 7 Feb 2026 16:34:55 +0100
Subject: [PATCH] training follow-up:

- use relative imports
- use tf.keras everywhere (and ensure v2)
- `weights_ensembling`:
  * use `Patches` and `PatchEncoder` from .models
  * drop TF1 stuff
  * make function / CLI more flexible
    (expect list of checkpoint dirs instead of a single top-level directory)
- train for `classification`: delegate to `weights_ensembling.run_ensembling`
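
For example, to average a set of classification checkpoints into a
single model (invocation sketch: the actual command name depends on how
`ensemble_cli` gets registered in `cli.py`):

    ensemble -i output/model_01 -i output/model_02 -o output/model_ens_avg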
---
 src/eynollah/eynollah_imports.py              |   3 +
 src/eynollah/training/cli.py                  |   2 +-
 .../training/generate_gt_for_training.py      |  14 +-
 src/eynollah/training/inference.py            |   4 +-
 src/eynollah/training/train.py                | 116 ++++++-------
 src/eynollah/training/weights_ensembling.py   | 163 +++++-------------
 6 files changed, 119 insertions(+), 183 deletions(-)

diff --git a/src/eynollah/eynollah_imports.py b/src/eynollah/eynollah_imports.py
index f04cfdc..496406c 100644
--- a/src/eynollah/eynollah_imports.py
+++ b/src/eynollah/eynollah_imports.py
@@ -1,6 +1,9 @@
 """
 Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
 """
+import os
+os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
+
 from ocrd_utils import tf_disable_interactive_logs
 from torch import *
 tf_disable_interactive_logs()
diff --git a/src/eynollah/training/cli.py b/src/eynollah/training/cli.py
index 3718275..ae14f04 100644
--- a/src/eynollah/training/cli.py
+++ b/src/eynollah/training/cli.py
@@ -9,7 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli
 from .inference import main as inference_cli
 from .train import ex
 from .extract_line_gt import linegt_cli
-from .weights_ensembling import main as ensemble_cli
+from .weights_ensembling import ensemble_cli
 
 @click.command(context_settings=dict(
     ignore_unknown_options=True,
diff --git a/src/eynollah/training/generate_gt_for_training.py b/src/eynollah/training/generate_gt_for_training.py
index 2c076d3..2422cc2 100644
--- a/src/eynollah/training/generate_gt_for_training.py
+++ b/src/eynollah/training/generate_gt_for_training.py
@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
 import cv2
 import numpy as np
 
-from eynollah.training.gt_gen_utils import (
+from .gt_gen_utils import (
     filter_contours_area_of_image,
     find_format_of_given_filename_in_dir,
     find_new_features_of_contours,
@@ -26,6 +26,9 @@ from eynollah.training.gt_gen_utils import (
 
 @click.group()
 def main():
+    """
+    extract GT data suitable for training models on various tasks
+    """
     pass
 
 @main.command()
@@ -74,6 +77,9 @@ def main():
 )
 def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
+    """
+    extract PAGE-XML GT data suitable for training segmentation models
+    """
     if config:
         with open(config) as f:
             config_params = json.load(f)
@@ -110,6 +116,9 @@
     type=click.Path(exists=True, dir_okay=False),
 )
 def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
+    """
+    extract image GT data suitable for training image enhancement models
+    """
     ls_imgs = os.listdir(dir_imgs)
     with open(scales) as f:
         scale_dict = json.load(f)
@@ -175,6 +184,9 @@
 )
 def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early):
+    """
+    extract PAGE-XML GT data suitable for training reading-order models
+    """
     xml_files_ind = os.listdir(dir_xml)
     xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
     input_height = int(input_height)
diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py
index 454c689..2b26210 100644
--- a/src/eynollah/training/inference.py
+++ b/src/eynollah/training/inference.py
@@ -33,9 +33,9 @@ from .metrics import (
     soft_dice_loss,
     weighted_categorical_crossentropy,
 )
+from .utils import scale_padd_image_for_ocr
+from ..utils.utils_ocr import decode_batch_predictions
 
-from.utils import (scale_padd_image_for_ocr)
-from eynollah.utils.utils_ocr import (decode_batch_predictions)
 
 with warnings.catch_warnings():
     warnings.simplefilter("ignore")
diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py
index 61dbdf7..217ab35 100644
--- a/src/eynollah/training/train.py
+++ b/src/eynollah/training/train.py
@@ -3,32 +3,8 @@ import sys
 import json
 
 import requests
-import click
-from eynollah.training.metrics import (
-    soft_dice_loss,
-    weighted_categorical_crossentropy
-)
-from eynollah.training.models import (
-    PatchEncoder,
-    Patches,
-    machine_based_reading_order_model,
-    resnet50_classifier,
-    resnet50_unet,
-    vit_resnet50_unet,
-    vit_resnet50_unet_transformer_before_cnn,
-    cnn_rnn_ocr_model,
-    RESNET50_WEIGHTS_PATH,
-    RESNET50_WEIGHTS_URL
-)
-from eynollah.training.utils import (
-    data_gen,
-    generate_arrays_from_folder_reading_order,
-    get_one_hot,
-    preprocess_imgs,
-)
-
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
 import tensorflow as tf
 from tensorflow.keras.optimizers import SGD, Adam
@@ -43,6 +19,31 @@ from sacred.config import create_captured_function
 import numpy as np
 import cv2
 
+from .metrics import (
+    soft_dice_loss,
+    weighted_categorical_crossentropy
+)
+from .models import (
+    PatchEncoder,
+    Patches,
+    machine_based_reading_order_model,
+    resnet50_classifier,
+    resnet50_unet,
+    vit_resnet50_unet,
+    vit_resnet50_unet_transformer_before_cnn,
+    cnn_rnn_ocr_model,
+    RESNET50_WEIGHTS_PATH,
+    RESNET50_WEIGHTS_URL
+)
+from .utils import (
+    data_gen,
+    generate_arrays_from_folder_reading_order,
+    get_one_hot,
+    preprocess_imgs,
+)
+from .weights_ensembling import run_ensembling
+
+
 class SaveWeightsAfterSteps(ModelCheckpoint):
     def __init__(self, save_interval, save_path, _config, **kwargs):
         if save_interval:
@@ -65,9 +66,7 @@
             super()._save_handler(filepath)
             with open(os.path.join(filepath, "config.json"), "w") as fp:
                 json.dump(self._config, fp) # encode dict into JSON
-
-
-
+
 def configuration():
     try:
         for device in tf.config.list_physical_devices('GPU'):
@@ -272,6 +271,9 @@ def run(_config,
         skewing_amplitudes=None,
         max_len=None,
 ):
+    """
+    run the configured experiment via Sacred
+    """
     if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
         _log.info("downloading RESNET50 pretrained weights to %s",
                   RESNET50_WEIGHTS_PATH)
@@ -312,7 +314,7 @@ def run(_config,
 
         imgs_list = list(os.listdir(dir_img))
         segs_list = list(os.listdir(dir_seg))
-        
+
         imgs_list_test = list(os.listdir(dir_img_val))
         segs_list_test = list(os.listdir(dir_seg_val))
@@ -380,7 +382,7 @@ def run(_config,
         num_patches_x = transformer_num_patches_xy[0]
         num_patches_y = transformer_num_patches_xy[1]
         num_patches = num_patches_x * num_patches_y
-        
+
         if transformer_cnn_first:
             model_builder = vit_resnet50_unet
             multiple_of_32 = True
@@ -413,13 +415,13 @@ def run(_config,
             model_builder.config = _config
             model_builder.logger = _log
             model = model_builder(num_patches)
-        
+
         assert model is not None
         #if you want to see the model structure just uncomment model summary.
         #model.summary()
-        
+
         if task in ["segmentation", "binarization"]:
-            if is_loss_soft_dice: 
+            if is_loss_soft_dice:
                 loss = soft_dice_loss
             elif weighted_loss:
                 loss = weighted_categorical_crossentropy(weights)
@@ -434,7 +436,7 @@ def run(_config,
                 ignore_class=0,
                 sparse_y_true=False,
                 sparse_y_pred=False)])
-            
+
             # generating train and evaluation data
             gen_kwargs = dict(batch_size=n_batch,
                               input_height=input_height,
@@ -447,7 +449,7 @@ def run(_config,
             ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
             ##score_best=[]
             ##score_best.append(0)
-            
+
             callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                          SaveWeightsAfterSteps(0, dir_output, _config)]
             if save_interval:
@@ -471,7 +473,7 @@ def run(_config,
 
             #os.system('rm -rf '+dir_eval_flowing)
             #model.save(dir_output+'/'+'model'+'.h5')
-        
+
         elif task=="cnn-rnn-ocr":
             dir_img, dir_lab = get_dirs_or_files(dir_train)
 
@@ -480,7 +482,7 @@ def run(_config,
             labs_list = list(os.listdir(dir_lab))
             imgs_list_val = list(os.listdir(dir_img_val))
             labs_list_val = list(os.listdir(dir_lab_val))
-            
+
             with open(characters_txt_file, 'r') as char_txt_f:
                 characters = json.load(char_txt_f)
             padding_token = len(characters) + 5
@@ -533,7 +535,7 @@ def run(_config,
             #tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha)
             opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
             model.compile(optimizer=opt) # rs: loss seems to be (ctc_batch_cost) in last layer
-            
+
             callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                          SaveWeightsAfterSteps(0, dir_output, _config)]
             if save_interval:
@@ -544,7 +546,7 @@ def run(_config,
                       epochs=n_epochs,
                       callbacks=callbacks,
                       initial_epoch=index_start)
-            
+
         elif task=='classification':
             if continue_training:
                 model = load_model(dir_of_start_model, compile=False)
@@ -573,7 +575,7 @@ def run(_config,
                 monitor='val_f1',
                 #save_best_only=True, # we need all for ensembling
                 mode='max')]
-            
+
             history = model.fit(trainXY,
                                 #class_weight=weights)
                                 validation_data=testXY,
@@ -586,28 +588,12 @@ def run(_config,
                                               f1_threshold_classification)
             if len(usable_checkpoints) >= 1:
                 _log.info("averaging over usable checkpoints: %s", str(usable_checkpoints))
-                all_weights = []
-                for epoch in usable_checkpoints:
-                    cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
-                    assert os.path.isdir(cp_path), cp_path
-                    model = load_model(cp_path, compile=False)
-                    all_weights.append(model.get_weights())
+                usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
+                                      for epoch in usable_checkpoints]
+                ens_path = os.path.join(dir_output, 'model_ens_avg')
+                run_ensembling(usable_checkpoints, ens_path)
+                _log.info("ensemble model saved under '%s'", ens_path)
 
-                new_weights = []
-                for layer_weights in zip(*all_weights):
-                    layer_weights = np.array([np.array(weights).mean(axis=0)
-                                              for weights in zip(*layer_weights)])
-                    new_weights.append(layer_weights)
-
-                #model = tf.keras.models.clone_model(model)
-                model.set_weights(new_weights)
-
-                cp_path = os.path.join(dir_output, 'model_ens_avg')
-                model.save(cp_path)
-                with open(os.path.join(cp_path, "config.json"), "w") as fp:
-                    json.dump(_config, fp) # encode dict into JSON
-                _log.info("ensemble model saved under '%s'", cp_path)
-
         elif task=='reading_order':
             if continue_training:
                 model = load_model(dir_of_start_model, compile=False)
@@ -618,10 +604,10 @@ def run(_config,
                     input_width,
                     weight_decay,
                     pretraining)
-            
+
             dir_flow_train_imgs = os.path.join(dir_train, 'images')
             dir_flow_train_labels = os.path.join(dir_train, 'labels')
-            
+
             classes = os.listdir(dir_flow_train_labels)
             if augmentation:
                 num_rows = len(classes)*(len(thetha) + 1)
@@ -634,7 +620,7 @@ def run(_config,
                 #optimizer=SGD(learning_rate=0.01, momentum=0.9),
                 optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate?
                 metrics=['accuracy'])
-            
+
             callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                          SaveWeightsAfterSteps(0, dir_output, _config)]
             if save_interval:
@@ -657,5 +643,3 @@ def run(_config,
         model_dir = os.path.join(dir_out,'model_best')
         model.save(model_dir)
     '''
-
-
diff --git a/src/eynollah/training/weights_ensembling.py b/src/eynollah/training/weights_ensembling.py
index 6dce7fd..01532fd 100644
--- a/src/eynollah/training/weights_ensembling.py
+++ b/src/eynollah/training/weights_ensembling.py
@@ -1,136 +1,73 @@
-import sys
-from glob import glob
-from os import environ, devnull
-from os.path import join
-from warnings import catch_warnings, simplefilter
 import os
+from warnings import catch_warnings, simplefilter
+
+import click
 import numpy as np
-from PIL import Image
-import cv2
-environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-stderr = sys.stderr
-sys.stderr = open(devnull, 'w')
+
+os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+from ocrd_utils import tf_disable_interactive_logs
+tf_disable_interactive_logs()
 import tensorflow as tf
 from tensorflow.keras.models import load_model
-from tensorflow.python.keras import backend as tensorflow_backend
-sys.stderr = stderr
-from tensorflow.keras import layers
-import tensorflow.keras.losses
-from tensorflow.keras.layers import *
-import click
-import logging
-
-class Patches(layers.Layer):
-    def __init__(self, patch_size_x, patch_size_y):
-        super(Patches, self).__init__()
-        self.patch_size_x = patch_size_x
-        self.patch_size_y = patch_size_y
-
-    def call(self, images):
-        #print(tf.shape(images)[1],'images')
-        #print(self.patch_size,'self.patch_size')
-        batch_size = tf.shape(images)[0]
-        patches = tf.image.extract_patches(
-            images=images,
-            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
-            strides=[1, self.patch_size_y, self.patch_size_x, 1],
-            rates=[1, 1, 1, 1],
-            padding="VALID",
-        )
-        #patch_dims = patches.shape[-1]
-        patch_dims = tf.shape(patches)[-1]
-        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
-        return patches
-    def get_config(self):
-
-        config = super().get_config().copy()
-        config.update({
-            'patch_size_x': self.patch_size_x,
-            'patch_size_y': self.patch_size_y,
-        })
-        return config
-
-
-
-class PatchEncoder(layers.Layer):
-    def __init__(self, **kwargs):
-        super(PatchEncoder, self).__init__()
-        self.num_patches = num_patches
-        self.projection = layers.Dense(units=projection_dim)
-        self.position_embedding = layers.Embedding(
-            input_dim=num_patches, output_dim=projection_dim
-        )
-
-    def call(self, patch):
-        positions = tf.range(start=0, limit=self.num_patches, delta=1)
-        encoded = self.projection(patch) + self.position_embedding(positions)
-        return encoded
-    def get_config(self):
-
-        config = super().get_config().copy()
-        config.update({
-            'num_patches': self.num_patches,
-            'projection': self.projection,
-            'position_embedding': self.position_embedding,
-        })
-        return config
+from .models import (
+    PatchEncoder,
+    Patches,
+)
 
-
-def start_new_session():
-    ###config = tf.compat.v1.ConfigProto()
-    ###config.gpu_options.allow_growth = True
+def run_ensembling(model_dirs, out_dir):
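+    """
+    Load the checkpoint models stored under each of the given directories
+    (they must all share the same architecture), average their weights
+    variable by variable, and save the result (along with the first
+    checkpoint's config.json) as a new model under out_dir.
+    """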
+    all_weights = []
-    ###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-    ###tensorflow_backend.set_session(self.session)
-
-    config = tf.compat.v1.ConfigProto()
-    config.gpu_options.allow_growth = True
-
-    session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-    tensorflow_backend.set_session(session)
-    return session
-
-def run_ensembling(dir_models, out):
-    ls_models = os.listdir(dir_models)
-
-
-    weights=[]
-
-    for model_name in ls_models:
-        model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
-        weights.append(model.get_weights())
+    for model_dir in model_dirs:
+        assert os.path.isdir(model_dir), model_dir
+        model = load_model(model_dir, compile=False,
+                           custom_objects=dict(PatchEncoder=PatchEncoder,
+                                               Patches=Patches))
+        all_weights.append(model.get_weights())
 
-    new_weights = list()
+    new_weights = []
+    for layer_weights in zip(*all_weights):
+        layer_weights = np.array([np.array(weights).mean(axis=0)
+                                  for weights in zip(*layer_weights)])
+        new_weights.append(layer_weights)
 
-    for weights_list_tuple in zip(*weights):
-        new_weights.append(
-            [np.array(weights_).mean(axis=0)\
-             for weights_ in zip(*weights_list_tuple)])
-
-
-
-    new_weights = [np.array(x) for x in new_weights]
-
+    #model = tf.keras.models.clone_model(model)
     model.set_weights(new_weights)
-    model.save(out)
-    os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
+
+    model.save(out_dir)
+    os.system('cp ' + os.path.join(model_dirs[0], 'config.json') + ' ' + out_dir + '/')
 
 @click.command()
 @click.option(
-    "--dir_models",
-    "-dm",
-    help="directory of models",
+    "--in",
+    "-i",
+    "in_",  # explicit parameter name: "--in" alone would map to the reserved keyword "in"
+    help="directory of a checkpoint model to be read (can be given multiple times)",
+    multiple=True,
+    required=True,
     type=click.Path(exists=True, file_okay=False),
 )
 @click.option(
     "--out",
     "-o",
     help="output directory where ensembled model will be written.",
+    required=True,
     type=click.Path(exists=False, file_okay=False),
 )
+def ensemble_cli(in_, out):
+    """
+    mix multiple model weights
+
+    Load a sequence of models and mix them into a single ensemble model
+    by averaging their weights. Write the resulting model.
+    """
+    run_ensembling(in_, out)
-def main(dir_models, out):
-    run_ensembling(dir_models, out)
-