training follow-up:

- use relative imports
- use tf.keras everywhere (and ensure v2)
- `weights_ensembling`:
  * use `Patches` and `PatchEncoder` from .models
  * drop TF1 stuff
  * make function / CLI more flexible (expect list of
    checkpoint dirs instead of single top-level directory)
- `train`, for the `classification` task: delegate ensembling to `weights_ensembling.run_ensembling`
Robert Sachunsky 2026-02-07 16:34:55 +01:00
parent 27f43c175f
commit bd282a594d
6 changed files with 112 additions and 183 deletions
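
The `weights_ensembling` refactoring centers on checkpoint weight averaging. For orientation, a minimal standalone sketch of that operation (NumPy only, all names hypothetical):

```python
import numpy as np

def average_weights(all_weights):
    """Average per-model weight lists (one np.ndarray per layer variable).

    All models must share the same architecture, so variables can be
    zipped positionally and averaged elementwise across models.
    """
    return [np.mean(np.stack(variables), axis=0)
            for variables in zip(*all_weights)]

# two toy "models", each with two weight variables
model_a = [np.ones((2, 2)), np.zeros(3)]
model_b = [3 * np.ones((2, 2)), np.ones(3)]
print(average_weights([model_a, model_b]))  # 2x2 array of 2.0, then [0.5 0.5 0.5]
```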


@@ -1,6 +1,9 @@
"""
Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
"""
import os
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
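
Note that `TF_USE_LEGACY_KERAS` is only honored if it is set before TensorFlow is first imported anywhere in the process, hence the "first module" requirement. A minimal sketch of the ordering constraint (module names hypothetical):

```python
# loader.py (hypothetical): import this before anything touching TensorFlow
import os
os.environ['TF_USE_LEGACY_KERAS'] = '1'  # read once, at TensorFlow import time

# consumer.py (hypothetical), imported afterwards:
import tensorflow as tf  # tf.keras now resolves to legacy Keras 2, not Keras 3
```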


@@ -9,7 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli
from .inference import main as inference_cli
from .train import ex
from .extract_line_gt import linegt_cli
from .weights_ensembling import main as ensemble_cli
from .weights_ensembling import ensemble_cli
@click.command(context_settings=dict(
ignore_unknown_options=True,


@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
from eynollah.training.gt_gen_utils import (
from .gt_gen_utils import (
filter_contours_area_of_image,
find_format_of_given_filename_in_dir,
find_new_features_of_contours,
@@ -26,6 +26,9 @@ from eynollah.training.gt_gen_utils import (
@click.group()
def main():
"""
extract GT data suitable for model training for various tasks
"""
pass
@main.command()
@@ -74,6 +77,9 @@ def main():
)
def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
"""
extract PAGE-XML GT data suitable for model training for segmentation tasks
"""
if config:
with open(config) as f:
config_params = json.load(f)
@@ -110,6 +116,9 @@ def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, di
type=click.Path(exists=True, dir_okay=False),
)
def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
"""
extract image GT data suitable for model training for image enhancement tasks
"""
ls_imgs = os.listdir(dir_imgs)
with open(scales) as f:
scale_dict = json.load(f)
@@ -175,6 +184,9 @@ def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
)
def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early):
"""
extract PAGE-XML GT data suitable for model training for reading-order task
"""
xml_files_ind = os.listdir(dir_xml)
xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
input_height = int(input_height)


@@ -33,9 +33,9 @@ from .metrics import (
soft_dice_loss,
weighted_categorical_crossentropy,
)
from .utils import scale_padd_image_for_ocr
from ..utils.utils_ocr import decode_batch_predictions
from .utils import (scale_padd_image_for_ocr)
from eynollah.utils.utils_ocr import (decode_batch_predictions)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
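
For context on the relative-imports bullet: `..` resolves against the importing module's own package, so inside `eynollah/training/` the two forms below name the same module, but only the relative one survives vendoring or renaming the top-level package. A small illustration (module location assumed from the hunk context):

```python
# in a module under eynollah/training/:
from eynollah.utils.utils_ocr import decode_batch_predictions  # absolute
from ..utils.utils_ocr import decode_batch_predictions         # relative, .. == eynollah
```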


@@ -3,30 +3,6 @@ import sys
import json
import requests
import click
from eynollah.training.metrics import (
soft_dice_loss,
weighted_categorical_crossentropy
)
from eynollah.training.models import (
PatchEncoder,
Patches,
machine_based_reading_order_model,
resnet50_classifier,
resnet50_unet,
vit_resnet50_unet,
vit_resnet50_unet_transformer_before_cnn,
cnn_rnn_ocr_model,
RESNET50_WEIGHTS_PATH,
RESNET50_WEIGHTS_URL
)
from eynollah.training.utils import (
data_gen,
generate_arrays_from_folder_reading_order,
get_one_hot,
preprocess_imgs,
)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
@@ -43,6 +19,31 @@ from sacred.config import create_captured_function
import numpy as np
import cv2
from .metrics import (
soft_dice_loss,
weighted_categorical_crossentropy
)
from .models import (
PatchEncoder,
Patches,
machine_based_reading_order_model,
resnet50_classifier,
resnet50_unet,
vit_resnet50_unet,
vit_resnet50_unet_transformer_before_cnn,
cnn_rnn_ocr_model,
RESNET50_WEIGHTS_PATH,
RESNET50_WEIGHTS_URL
)
from .utils import (
data_gen,
generate_arrays_from_folder_reading_order,
get_one_hot,
preprocess_imgs,
)
from .weights_ensembling import run_ensembling
class SaveWeightsAfterSteps(ModelCheckpoint):
def __init__(self, save_interval, save_path, _config, **kwargs):
if save_interval:
@@ -66,8 +67,6 @@ class SaveWeightsAfterSteps(ModelCheckpoint):
with open(os.path.join(filepath, "config.json"), "w") as fp:
json.dump(self._config, fp) # encode dict into JSON
def configuration():
try:
for device in tf.config.list_physical_devices('GPU'):
@@ -272,6 +271,9 @@ def run(_config,
skewing_amplitudes=None,
max_len=None,
):
"""
run configured experiment via sacred
"""
if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
_log.info("downloading RESNET50 pretrained weights to %s", RESNET50_WEIGHTS_PATH)
@@ -586,27 +588,11 @@ def run(_config,
f1_threshold_classification)
if len(usable_checkpoints) >= 1:
_log.info("averaging over usable checkpoints: %s", str(usable_checkpoints))
all_weights = []
for epoch in usable_checkpoints:
cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
assert os.path.isdir(cp_path), cp_path
model = load_model(cp_path, compile=False)
all_weights.append(model.get_weights())
new_weights = []
for layer_weights in zip(*all_weights):
layer_weights = np.array([np.array(weights).mean(axis=0)
for weights in zip(*layer_weights)])
new_weights.append(layer_weights)
#model = tf.keras.models.clone_model(model)
model.set_weights(new_weights)
cp_path = os.path.join(dir_output, 'model_ens_avg')
model.save(cp_path)
with open(os.path.join(cp_path, "config.json"), "w") as fp:
json.dump(_config, fp) # encode dict into JSON
_log.info("ensemble model saved under '%s'", cp_path)
usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
for epoch in usable_checkpoints]
ens_path = os.path.join(dir_output, 'model_ens_avg')
run_ensembling(usable_checkpoints, ens_path)
_log.info("ensemble model saved under '%s'", ens_path)
elif task=='reading_order':
if continue_training:
@@ -657,5 +643,3 @@ def run(_config,
model_dir = os.path.join(dir_out,'model_best')
model.save(model_dir)
'''


@@ -1,136 +1,66 @@
import sys
from glob import glob
from os import environ, devnull
from os.path import join
from warnings import catch_warnings, simplefilter
import os
from warnings import catch_warnings, simplefilter
import click
import numpy as np
from PIL import Image
import cv2
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
sys.stderr = open(devnull, 'w')
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from ocrd_utils import tf_disable_interactive_logs
tf_disable_interactive_logs()
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend
sys.stderr = stderr
from tensorflow.keras import layers
import tensorflow.keras.losses
from tensorflow.keras.layers import *
import click
import logging
from .models import (
PatchEncoder,
Patches,
)
class Patches(layers.Layer):
    def __init__(self, patch_size_x, patch_size_y):
        super(Patches, self).__init__()
        self.patch_size_x = patch_size_x
        self.patch_size_y = patch_size_y
    def call(self, images):
        #print(tf.shape(images)[1],'images')
        #print(self.patch_size,'self.patch_size')
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
            strides=[1, self.patch_size_y, self.patch_size_x, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        #patch_dims = patches.shape[-1]
        patch_dims = tf.shape(patches)[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'patch_size_x': self.patch_size_x,
            'patch_size_y': self.patch_size_y,
        })
        return config
class PatchEncoder(layers.Layer):
    def __init__(self, **kwargs):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )
    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection': self.projection,
            'position_embedding': self.position_embedding,
        })
        return config
def start_new_session():
    ###config = tf.compat.v1.ConfigProto()
    ###config.gpu_options.allow_growth = True
    ###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
    ###tensorflow_backend.set_session(self.session)
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
    tensorflow_backend.set_session(session)
    return session
def run_ensembling(dir_models, out):
    ls_models = os.listdir(dir_models)
    weights = []
    for model_name in ls_models:
        model = load_model(os.path.join(dir_models, model_name), compile=False,
                           custom_objects={'PatchEncoder': PatchEncoder, 'Patches': Patches})
        weights.append(model.get_weights())
    new_weights = list()
    for weights_list_tuple in zip(*weights):
        new_weights.append(
            [np.array(weights_).mean(axis=0)
             for weights_ in zip(*weights_list_tuple)])
    new_weights = [np.array(x) for x in new_weights]
    model.set_weights(new_weights)
    model.save(out)
    os.system('cp '+os.path.join(os.path.join(dir_models, model_name), "config.json ")+out)
def run_ensembling(model_dirs, out_dir):
    all_weights = []
    for model_dir in model_dirs:
        assert os.path.isdir(model_dir), model_dir
        model = load_model(model_dir, compile=False,
                           custom_objects=dict(PatchEncoder=PatchEncoder,
                                               Patches=Patches))
        all_weights.append(model.get_weights())
    new_weights = []
    for layer_weights in zip(*all_weights):
        layer_weights = np.array([np.array(weights).mean(axis=0)
                                  for weights in zip(*layer_weights)])
        new_weights.append(layer_weights)
    #model = tf.keras.models.clone_model(model)
    model.set_weights(new_weights)
    model.save(out_dir)
    # carry over the config of the first checkpoint to the ensemble directory
    os.system('cp ' + os.path.join(model_dirs[0], "config.json") + ' ' + out_dir + '/')
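
A usage sketch of the new signature, mirroring the delegation from `train.run` (checkpoint paths hypothetical):

```python
from eynollah.training.weights_ensembling import run_ensembling

# average two saved checkpoints of the same architecture into one model
run_ensembling(["output/model_01", "output/model_02"], "output/model_ens_avg")
```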
@click.command()
@click.option(
"--dir_models",
"-dm",
help="directory of models",
"--in",
"-i",
help="input directory of checkpoint models to be read",
multiple=True,
required=True,
type=click.Path(exists=True, file_okay=False),
)
@click.option(
"--out",
"-o",
help="output directory where ensembled model will be written.",
required=True,
type=click.Path(exists=False, file_okay=False),
)
def main(dir_models, out):
    """
    mix multiple model weights
    """
    run_ensembling(dir_models, out)
def ensemble_cli(in_, out):
    """
    Load a sequence of models and mix them into a single ensemble model
    by averaging their weights. Write the resulting model.
    """
    run_ensembling(in_, out)
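
Because `--in` is declared with `multiple=True`, it is passed once per checkpoint directory. A hedged invocation sketch via click's test runner (paths hypothetical):

```python
from click.testing import CliRunner
from eynollah.training.weights_ensembling import ensemble_cli

runner = CliRunner()
result = runner.invoke(ensemble_cli, [
    "-i", "output/model_01",
    "-i", "output/model_02",
    "-o", "output/model_ens_avg",
])
print(result.exit_code)  # 0 on success
```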