training follow-up:

- use relative imports
- use tf.keras everywhere (and ensure v2)
- `weights_ensembling`:
  * use `Patches` and `PatchEncoder` from .models
  * drop TF1 stuff
  * make function / CLI more flexible (expect list of
    checkpoint dirs instead of single top-level directory)
- train for `classification`: delegate to `weights_ensembling.run_ensembling` (weight-averaging sketch below)
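
For reference, a minimal sketch of the layer-wise weight averaging that `run_ensembling` now encapsulates. This is not the committed implementation: it assumes TF 2.x SavedModel checkpoint directories and omits the `custom_objects` needed for models with `Patches`/`PatchEncoder` layers.

```python
# Sketch only: average several Keras checkpoints into one model, element-wise.
import os
import numpy as np
from tensorflow.keras.models import load_model

def average_checkpoints(model_dirs, out_dir):
    """Load each checkpoint dir and save a model whose weights are the element-wise mean."""
    all_weights = []
    for model_dir in model_dirs:
        assert os.path.isdir(model_dir), model_dir
        model = load_model(model_dir, compile=False)
        all_weights.append(model.get_weights())
    # zip(*all_weights) groups the i-th weight array of every checkpoint together,
    # so each mean is taken across checkpoints, per layer weight.
    new_weights = [np.stack(weights).mean(axis=0) for weights in zip(*all_weights)]
    model.set_weights(new_weights)  # reuse the last loaded model as the container
    model.save(out_dir)

# e.g. average_checkpoints(["model_01", "model_02"], "model_ens_avg")
```
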
Robert Sachunsky 2026-02-07 16:34:55 +01:00
parent 27f43c175f
commit bd282a594d
6 changed files with 112 additions and 183 deletions

View file

@@ -1,6 +1,9 @@
 """
 Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
 """
+import os
+os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
 from ocrd_utils import tf_disable_interactive_logs
 from torch import *
 tf_disable_interactive_logs()

View file

@@ -9,7 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli
 from .inference import main as inference_cli
 from .train import ex
 from .extract_line_gt import linegt_cli
-from .weights_ensembling import main as ensemble_cli
+from .weights_ensembling import ensemble_cli
 @click.command(context_settings=dict(
     ignore_unknown_options=True,

View file

@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
 import cv2
 import numpy as np
-from eynollah.training.gt_gen_utils import (
+from .gt_gen_utils import (
     filter_contours_area_of_image,
     find_format_of_given_filename_in_dir,
     find_new_features_of_contours,
@@ -26,6 +26,9 @@ from eynollah.training.gt_gen_utils import (
 @click.group()
 def main():
+    """
+    extract GT data suitable for model training for various tasks
+    """
     pass
 @main.command()
@@ -74,6 +77,9 @@ def main():
 )
 def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
+    """
+    extract PAGE-XML GT data suitable for model training for segmentation tasks
+    """
     if config:
         with open(config) as f:
             config_params = json.load(f)
@@ -110,6 +116,9 @@ def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
     type=click.Path(exists=True, dir_okay=False),
 )
 def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
+    """
+    extract image GT data suitable for model training for image enhancement tasks
+    """
     ls_imgs = os.listdir(dir_imgs)
     with open(scales) as f:
         scale_dict = json.load(f)
@@ -175,6 +184,9 @@ def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
 )
 def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early):
+    """
+    extract PAGE-XML GT data suitable for model training for reading-order task
+    """
     xml_files_ind = os.listdir(dir_xml)
     xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
     input_height = int(input_height)

View file

@@ -33,9 +33,9 @@ from .metrics import (
     soft_dice_loss,
     weighted_categorical_crossentropy,
 )
-from .utils import (scale_padd_image_for_ocr)
-from eynollah.utils.utils_ocr import (decode_batch_predictions)
+from .utils import scale_padd_image_for_ocr
+from ..utils.utils_ocr import decode_batch_predictions
 with warnings.catch_warnings():
     warnings.simplefilter("ignore")

View file

@@ -3,30 +3,6 @@ import sys
 import json
 import requests
-import click
-from eynollah.training.metrics import (
-    soft_dice_loss,
-    weighted_categorical_crossentropy
-)
-from eynollah.training.models import (
-    PatchEncoder,
-    Patches,
-    machine_based_reading_order_model,
-    resnet50_classifier,
-    resnet50_unet,
-    vit_resnet50_unet,
-    vit_resnet50_unet_transformer_before_cnn,
-    cnn_rnn_ocr_model,
-    RESNET50_WEIGHTS_PATH,
-    RESNET50_WEIGHTS_URL
-)
-from eynollah.training.utils import (
-    data_gen,
-    generate_arrays_from_folder_reading_order,
-    get_one_hot,
-    preprocess_imgs,
-)
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
@@ -43,6 +19,31 @@ from sacred.config import create_captured_function
 import numpy as np
 import cv2
+from .metrics import (
+    soft_dice_loss,
+    weighted_categorical_crossentropy
+)
+from .models import (
+    PatchEncoder,
+    Patches,
+    machine_based_reading_order_model,
+    resnet50_classifier,
+    resnet50_unet,
+    vit_resnet50_unet,
+    vit_resnet50_unet_transformer_before_cnn,
+    cnn_rnn_ocr_model,
+    RESNET50_WEIGHTS_PATH,
+    RESNET50_WEIGHTS_URL
+)
+from .utils import (
+    data_gen,
+    generate_arrays_from_folder_reading_order,
+    get_one_hot,
+    preprocess_imgs,
+)
+from .weights_ensembling import run_ensembling
 class SaveWeightsAfterSteps(ModelCheckpoint):
     def __init__(self, save_interval, save_path, _config, **kwargs):
         if save_interval:
@@ -66,8 +67,6 @@ class SaveWeightsAfterSteps(ModelCheckpoint):
             with open(os.path.join(filepath, "config.json"), "w") as fp:
                 json.dump(self._config, fp) # encode dict into JSON
 def configuration():
     try:
         for device in tf.config.list_physical_devices('GPU'):
@@ -272,6 +271,9 @@ def run(_config,
         skewing_amplitudes=None,
         max_len=None,
 ):
+    """
+    run configured experiment via sacred
+    """
     if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
         _log.info("downloading RESNET50 pretrained weights to %s", RESNET50_WEIGHTS_PATH)
@@ -586,27 +588,11 @@ def run(_config,
                 f1_threshold_classification)
         if len(usable_checkpoints) >= 1:
             _log.info("averaging over usable checkpoints: %s", str(usable_checkpoints))
-            all_weights = []
-            for epoch in usable_checkpoints:
-                cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
-                assert os.path.isdir(cp_path), cp_path
-                model = load_model(cp_path, compile=False)
-                all_weights.append(model.get_weights())
-            new_weights = []
-            for layer_weights in zip(*all_weights):
-                layer_weights = np.array([np.array(weights).mean(axis=0)
-                                          for weights in zip(*layer_weights)])
-                new_weights.append(layer_weights)
-            #model = tf.keras.models.clone_model(model)
-            model.set_weights(new_weights)
-            cp_path = os.path.join(dir_output, 'model_ens_avg')
-            model.save(cp_path)
-            with open(os.path.join(cp_path, "config.json"), "w") as fp:
-                json.dump(_config, fp) # encode dict into JSON
-            _log.info("ensemble model saved under '%s'", cp_path)
+            usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
+                                  for epoch in usable_checkpoints]
+            ens_path = os.path.join(dir_output, 'model_ens_avg')
+            run_ensembling(usable_checkpoints, ens_path)
+            _log.info("ensemble model saved under '%s'", ens_path)
     elif task=='reading_order':
         if continue_training:
@@ -657,5 +643,3 @@ def run(_config,
     model_dir = os.path.join(dir_out,'model_best')
     model.save(model_dir)
     '''

View file

@@ -1,136 +1,66 @@
-import sys
-from glob import glob
-from os import environ, devnull
-from os.path import join
-from warnings import catch_warnings, simplefilter
 import os
+from warnings import catch_warnings, simplefilter
+import click
 import numpy as np
-from PIL import Image
-import cv2
-environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-stderr = sys.stderr
-sys.stderr = open(devnull, 'w')
+os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+from ocrd_utils import tf_disable_interactive_logs
+tf_disable_interactive_logs()
 import tensorflow as tf
 from tensorflow.keras.models import load_model
-from tensorflow.python.keras import backend as tensorflow_backend
-sys.stderr = stderr
-from tensorflow.keras import layers
-import tensorflow.keras.losses
-from tensorflow.keras.layers import *
-import click
-import logging
-class Patches(layers.Layer):
-    def __init__(self, patch_size_x, patch_size_y):
-        super(Patches, self).__init__()
-        self.patch_size_x = patch_size_x
-        self.patch_size_y = patch_size_y
-    def call(self, images):
-        #print(tf.shape(images)[1],'images')
-        #print(self.patch_size,'self.patch_size')
-        batch_size = tf.shape(images)[0]
-        patches = tf.image.extract_patches(
-            images=images,
-            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
-            strides=[1, self.patch_size_y, self.patch_size_x, 1],
-            rates=[1, 1, 1, 1],
-            padding="VALID",
-        )
-        #patch_dims = patches.shape[-1]
-        patch_dims = tf.shape(patches)[-1]
-        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
-        return patches
-    def get_config(self):
-        config = super().get_config().copy()
-        config.update({
-            'patch_size_x': self.patch_size_x,
-            'patch_size_y': self.patch_size_y,
-        })
-        return config
-class PatchEncoder(layers.Layer):
-    def __init__(self, **kwargs):
-        super(PatchEncoder, self).__init__()
-        self.num_patches = num_patches
-        self.projection = layers.Dense(units=projection_dim)
-        self.position_embedding = layers.Embedding(
-            input_dim=num_patches, output_dim=projection_dim
-        )
-    def call(self, patch):
-        positions = tf.range(start=0, limit=self.num_patches, delta=1)
-        encoded = self.projection(patch) + self.position_embedding(positions)
-        return encoded
-    def get_config(self):
-        config = super().get_config().copy()
-        config.update({
-            'num_patches': self.num_patches,
-            'projection': self.projection,
-            'position_embedding': self.position_embedding,
-        })
-        return config
-def start_new_session():
-    ###config = tf.compat.v1.ConfigProto()
-    ###config.gpu_options.allow_growth = True
-    ###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-    ###tensorflow_backend.set_session(self.session)
-    config = tf.compat.v1.ConfigProto()
-    config.gpu_options.allow_growth = True
-    session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-    tensorflow_backend.set_session(session)
-    return session
-def run_ensembling(dir_models, out):
-    ls_models = os.listdir(dir_models)
-    weights=[]
-    for model_name in ls_models:
-        model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
-        weights.append(model.get_weights())
-    new_weights = list()
-    for weights_list_tuple in zip(*weights):
-        new_weights.append(
-            [np.array(weights_).mean(axis=0)\
-             for weights_ in zip(*weights_list_tuple)])
-    new_weights = [np.array(x) for x in new_weights]
-    #model = tf.keras.models.clone_model(model)
+from .models import (
+    PatchEncoder,
+    Patches,
+)
+def run_ensembling(model_dirs, out_dir):
+    all_weights = []
+    for model_dir in model_dirs:
+        assert os.path.isdir(model_dir), model_dir
+        model = load_model(model_dir, compile=False,
+                           custom_objects=dict(PatchEncoder=PatchEncoder,
+                                               Patches=Patches))
+        all_weights.append(model.get_weights())
+    new_weights = []
+    for layer_weights in zip(*all_weights):
+        layer_weights = np.array([np.array(weights).mean(axis=0)
+                                  for weights in zip(*layer_weights)])
+        new_weights.append(layer_weights)
     model.set_weights(new_weights)
-    model.save(out)
-    os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
+    model.save(out_dir)
+    os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/")
 @click.command()
 @click.option(
-    "--dir_models",
-    "-dm",
-    help="directory of models",
+    "--in",
+    "-i",
+    help="input directory of checkpoint models to be read",
+    multiple=True,
+    required=True,
     type=click.Path(exists=True, file_okay=False),
 )
 @click.option(
     "--out",
     "-o",
     help="output directory where ensembled model will be written.",
+    required=True,
     type=click.Path(exists=False, file_okay=False),
 )
-def main(dir_models, out):
-    run_ensembling(dir_models, out)
+def ensemble_cli(in_, out):
+    """
+    mix multiple model weights
+    Load a sequence of models and mix them into a single ensemble model
+    by averaging their weights. Write the resulting model.
+    """
+    run_ensembling(in_, out)
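
A usage example of the refactored function (the checkpoint paths are hypothetical; the `run_ensembling(model_dirs, out_dir)` signature and the `eynollah.training` package come from the diff above):

```python
# Average a few hypothetical epoch checkpoints into a single ensemble model.
from eynollah.training.weights_ensembling import run_ensembling

checkpoints = ["output/model_01", "output/model_02", "output/model_03"]
run_ensembling(checkpoints, "output/model_ens_avg")
```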