training follow-up:

- use relative imports
- use tf.keras everywhere (and ensure v2)
- `weights_ensembling`:
  * use `Patches` and `PatchEncoder` from .models
  * drop TF1 stuff
  * make function / CLI more flexible (expect list of
    checkpoint dirs instead of single top-level directory)
- `train`, for the `classification` task: delegate checkpoint averaging to `weights_ensembling.run_ensembling`
Robert Sachunsky 2026-02-07 16:34:55 +01:00
parent 27f43c175f
commit bd282a594d
6 changed files with 112 additions and 183 deletions
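
The refactored `weights_ensembling.run_ensembling` and the inline code it replaces in `train.py` both do the same thing: load a set of checkpoints and average their weights element-wise, layer by layer. Below is a minimal sketch of just that averaging step, assuming each checkpoint's weights are already available as a list of numpy arrays; the helper name `average_weights` and the toy data are illustrative, not part of this commit:

    import numpy as np

    def average_weights(per_model_weights):
        # per_model_weights: one list of arrays per checkpoint, all with the
        # same structure (same layers in the same order, same shapes)
        averaged = []
        for weight_group in zip(*per_model_weights):
            # stack the i-th array of every checkpoint and average element-wise
            averaged.append(np.mean(np.stack(weight_group, axis=0), axis=0))
        return averaged

    # toy example with two "checkpoints" of two weight arrays each
    ckpt_a = [np.zeros((2, 2)), np.full(3, 1.0)]
    ckpt_b = [np.full((2, 2), 2.0), np.full(3, 3.0)]
    avg = average_weights([ckpt_a, ckpt_b])
    print(avg[0])  # [[1. 1.] [1. 1.]]
    print(avg[1])  # [2. 2. 2.]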

View file

@@ -1,6 +1,9 @@
 """
 Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
 """
+import os
+os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
 from ocrd_utils import tf_disable_interactive_logs
 from torch import *
 tf_disable_interactive_logs()

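The `TF_USE_LEGACY_KERAS` setting added here only takes effect if it is set before TensorFlow is imported for the first time, which is why this loader module has to stay the very first eynollah import. A tiny, hypothetical check of that ordering (not part of the commit):

    import os
    # must happen before the first `import tensorflow`, otherwise it is ignored
    os.environ['TF_USE_LEGACY_KERAS'] = '1'
    import tensorflow as tf
    print(tf.keras.__version__)  # expected to report a 2.x (legacy tf_keras) version, not 3.x
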
View file

@@ -9,7 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli
 from .inference import main as inference_cli
 from .train import ex
 from .extract_line_gt import linegt_cli
-from .weights_ensembling import main as ensemble_cli
+from .weights_ensembling import ensemble_cli
 @click.command(context_settings=dict(
     ignore_unknown_options=True,

View file

@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
 import cv2
 import numpy as np
-from eynollah.training.gt_gen_utils import (
+from .gt_gen_utils import (
     filter_contours_area_of_image,
     find_format_of_given_filename_in_dir,
     find_new_features_of_contours,
@@ -26,6 +26,9 @@ from eynollah.training.gt_gen_utils import (
 @click.group()
 def main():
+    """
+    extract GT data suitable for model training for various tasks
+    """
     pass
 @main.command()
@@ -74,6 +77,9 @@ def main():
 )
 def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
+    """
+    extract PAGE-XML GT data suitable for model training for segmentation tasks
+    """
     if config:
         with open(config) as f:
             config_params = json.load(f)
@@ -110,6 +116,9 @@ def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, di
     type=click.Path(exists=True, dir_okay=False),
 )
 def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
+    """
+    extract image GT data suitable for model training for image enhancement tasks
+    """
     ls_imgs = os.listdir(dir_imgs)
     with open(scales) as f:
         scale_dict = json.load(f)
@@ -175,6 +184,9 @@ def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
 )
 def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early):
+    """
+    extract PAGE-XML GT data suitable for model training for reading-order task
+    """
     xml_files_ind = os.listdir(dir_xml)
     xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
     input_height = int(input_height)

View file

@@ -33,9 +33,9 @@ from .metrics import (
     soft_dice_loss,
     weighted_categorical_crossentropy,
 )
-from .utils import (scale_padd_image_for_ocr)
-from eynollah.utils.utils_ocr import (decode_batch_predictions)
+from .utils import scale_padd_image_for_ocr
+from ..utils.utils_ocr import decode_batch_predictions
 with warnings.catch_warnings():
     warnings.simplefilter("ignore")

View file

@@ -3,32 +3,8 @@ import sys
 import json
 import requests
-import click
-from eynollah.training.metrics import (
-    soft_dice_loss,
-    weighted_categorical_crossentropy
-)
-from eynollah.training.models import (
-    PatchEncoder,
-    Patches,
-    machine_based_reading_order_model,
-    resnet50_classifier,
-    resnet50_unet,
-    vit_resnet50_unet,
-    vit_resnet50_unet_transformer_before_cnn,
-    cnn_rnn_ocr_model,
-    RESNET50_WEIGHTS_PATH,
-    RESNET50_WEIGHTS_URL
-)
-from eynollah.training.utils import (
-    data_gen,
-    generate_arrays_from_folder_reading_order,
-    get_one_hot,
-    preprocess_imgs,
-)
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
 import tensorflow as tf
 from tensorflow.keras.optimizers import SGD, Adam
@@ -43,6 +19,31 @@ from sacred.config import create_captured_function
 import numpy as np
 import cv2
+from .metrics import (
+    soft_dice_loss,
+    weighted_categorical_crossentropy
+)
+from .models import (
+    PatchEncoder,
+    Patches,
+    machine_based_reading_order_model,
+    resnet50_classifier,
+    resnet50_unet,
+    vit_resnet50_unet,
+    vit_resnet50_unet_transformer_before_cnn,
+    cnn_rnn_ocr_model,
+    RESNET50_WEIGHTS_PATH,
+    RESNET50_WEIGHTS_URL
+)
+from .utils import (
+    data_gen,
+    generate_arrays_from_folder_reading_order,
+    get_one_hot,
+    preprocess_imgs,
+)
+from .weights_ensembling import run_ensembling
 class SaveWeightsAfterSteps(ModelCheckpoint):
     def __init__(self, save_interval, save_path, _config, **kwargs):
         if save_interval:
@@ -65,9 +66,7 @@ class SaveWeightsAfterSteps(ModelCheckpoint):
         super()._save_handler(filepath)
         with open(os.path.join(filepath, "config.json"), "w") as fp:
             json.dump(self._config, fp) # encode dict into JSON
 def configuration():
     try:
         for device in tf.config.list_physical_devices('GPU'):
@@ -272,6 +271,9 @@ def run(_config,
         skewing_amplitudes=None,
         max_len=None,
         ):
+    """
+    run configured experiment via sacred
+    """
     if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
         _log.info("downloading RESNET50 pretrained weights to %s", RESNET50_WEIGHTS_PATH)
@@ -312,7 +314,7 @@ def run(_config,
     imgs_list = list(os.listdir(dir_img))
     segs_list = list(os.listdir(dir_seg))
     imgs_list_test = list(os.listdir(dir_img_val))
     segs_list_test = list(os.listdir(dir_seg_val))
@@ -380,7 +382,7 @@ def run(_config,
     num_patches_x = transformer_num_patches_xy[0]
     num_patches_y = transformer_num_patches_xy[1]
     num_patches = num_patches_x * num_patches_y
     if transformer_cnn_first:
         model_builder = vit_resnet50_unet
         multiple_of_32 = True
@@ -413,13 +415,13 @@ def run(_config,
     model_builder.config = _config
     model_builder.logger = _log
     model = model_builder(num_patches)
     assert model is not None
     #if you want to see the model structure just uncomment model summary.
     #model.summary()
     if task in ["segmentation", "binarization"]:
         if is_loss_soft_dice:
             loss = soft_dice_loss
         elif weighted_loss:
             loss = weighted_categorical_crossentropy(weights)
@@ -434,7 +436,7 @@ def run(_config,
                 ignore_class=0,
                 sparse_y_true=False,
                 sparse_y_pred=False)])
         # generating train and evaluation data
         gen_kwargs = dict(batch_size=n_batch,
                           input_height=input_height,
@@ -447,7 +449,7 @@ def run(_config,
         ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
         ##score_best=[]
         ##score_best.append(0)
         callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                      SaveWeightsAfterSteps(0, dir_output, _config)]
         if save_interval:
@@ -471,7 +473,7 @@ def run(_config,
         #os.system('rm -rf '+dir_eval_flowing)
         #model.save(dir_output+'/'+'model'+'.h5')
     elif task=="cnn-rnn-ocr":
         dir_img, dir_lab = get_dirs_or_files(dir_train)
@@ -480,7 +482,7 @@ def run(_config,
         labs_list = list(os.listdir(dir_lab))
         imgs_list_val = list(os.listdir(dir_img_val))
         labs_list_val = list(os.listdir(dir_lab_val))
         with open(characters_txt_file, 'r') as char_txt_f:
             characters = json.load(char_txt_f)
         padding_token = len(characters) + 5
@@ -533,7 +535,7 @@ def run(_config,
         #tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha)
         opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
         model.compile(optimizer=opt) # rs: loss seems to be (ctc_batch_cost) in last layer
         callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                      SaveWeightsAfterSteps(0, dir_output, _config)]
         if save_interval:
@@ -544,7 +546,7 @@ def run(_config,
                             epochs=n_epochs,
                             callbacks=callbacks,
                             initial_epoch=index_start)
     elif task=='classification':
         if continue_training:
             model = load_model(dir_of_start_model, compile=False)
@@ -573,7 +575,7 @@ def run(_config,
                      monitor='val_f1',
                      #save_best_only=True, # we need all for ensembling
                      mode='max')]
         history = model.fit(trainXY,
                             #class_weight=weights)
                             validation_data=testXY,
@@ -586,28 +588,12 @@ def run(_config,
                             f1_threshold_classification)
         if len(usable_checkpoints) >= 1:
             _log.info("averaging over usable checkpoints: %s", str(usable_checkpoints))
-            all_weights = []
-            for epoch in usable_checkpoints:
-                cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
-                assert os.path.isdir(cp_path), cp_path
-                model = load_model(cp_path, compile=False)
-                all_weights.append(model.get_weights())
-            new_weights = []
-            for layer_weights in zip(*all_weights):
-                layer_weights = np.array([np.array(weights).mean(axis=0)
-                                          for weights in zip(*layer_weights)])
-                new_weights.append(layer_weights)
-            #model = tf.keras.models.clone_model(model)
-            model.set_weights(new_weights)
-            cp_path = os.path.join(dir_output, 'model_ens_avg')
-            model.save(cp_path)
-            with open(os.path.join(cp_path, "config.json"), "w") as fp:
-                json.dump(_config, fp) # encode dict into JSON
-            _log.info("ensemble model saved under '%s'", cp_path)
+            usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
+                                  for epoch in usable_checkpoints]
+            ens_path = os.path.join(dir_output, 'model_ens_avg')
+            run_ensembling(usable_checkpoints, ens_path)
+            _log.info("ensemble model saved under '%s'", ens_path)
     elif task=='reading_order':
         if continue_training:
             model = load_model(dir_of_start_model, compile=False)
@@ -618,10 +604,10 @@ def run(_config,
                 input_width,
                 weight_decay,
                 pretraining)
         dir_flow_train_imgs = os.path.join(dir_train, 'images')
         dir_flow_train_labels = os.path.join(dir_train, 'labels')
         classes = os.listdir(dir_flow_train_labels)
         if augmentation:
             num_rows = len(classes)*(len(thetha) + 1)
@@ -634,7 +620,7 @@ def run(_config,
             #optimizer=SGD(learning_rate=0.01, momentum=0.9),
             optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate?
             metrics=['accuracy'])
         callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                      SaveWeightsAfterSteps(0, dir_output, _config)]
         if save_interval:
@@ -657,5 +643,3 @@ def run(_config,
     model_dir = os.path.join(dir_out,'model_best')
     model.save(model_dir)
     '''

View file

@@ -1,136 +1,66 @@
-import sys
-from glob import glob
-from os import environ, devnull
-from os.path import join
-from warnings import catch_warnings, simplefilter
 import os
+from warnings import catch_warnings, simplefilter
+import click
 import numpy as np
-from PIL import Image
-import cv2
-environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-stderr = sys.stderr
-sys.stderr = open(devnull, 'w')
+os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+from ocrd_utils import tf_disable_interactive_logs
+tf_disable_interactive_logs()
 import tensorflow as tf
 from tensorflow.keras.models import load_model
-from tensorflow.python.keras import backend as tensorflow_backend
-sys.stderr = stderr
-from tensorflow.keras import layers
-import tensorflow.keras.losses
-from tensorflow.keras.layers import *
-import click
-import logging
-class Patches(layers.Layer):
-    def __init__(self, patch_size_x, patch_size_y):
-        super(Patches, self).__init__()
-        self.patch_size_x = patch_size_x
-        self.patch_size_y = patch_size_y
-    def call(self, images):
-        #print(tf.shape(images)[1],'images')
-        #print(self.patch_size,'self.patch_size')
-        batch_size = tf.shape(images)[0]
-        patches = tf.image.extract_patches(
-            images=images,
-            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
-            strides=[1, self.patch_size_y, self.patch_size_x, 1],
-            rates=[1, 1, 1, 1],
-            padding="VALID",
-        )
-        #patch_dims = patches.shape[-1]
-        patch_dims = tf.shape(patches)[-1]
-        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
-        return patches
-    def get_config(self):
-        config = super().get_config().copy()
-        config.update({
-            'patch_size_x': self.patch_size_x,
-            'patch_size_y': self.patch_size_y,
-        })
-        return config
-class PatchEncoder(layers.Layer):
-    def __init__(self, **kwargs):
-        super(PatchEncoder, self).__init__()
-        self.num_patches = num_patches
-        self.projection = layers.Dense(units=projection_dim)
-        self.position_embedding = layers.Embedding(
-            input_dim=num_patches, output_dim=projection_dim
-        )
-    def call(self, patch):
-        positions = tf.range(start=0, limit=self.num_patches, delta=1)
-        encoded = self.projection(patch) + self.position_embedding(positions)
-        return encoded
-    def get_config(self):
-        config = super().get_config().copy()
-        config.update({
-            'num_patches': self.num_patches,
-            'projection': self.projection,
-            'position_embedding': self.position_embedding,
-        })
-        return config
-def start_new_session():
-    ###config = tf.compat.v1.ConfigProto()
-    ###config.gpu_options.allow_growth = True
-    ###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-    ###tensorflow_backend.set_session(self.session)
-    config = tf.compat.v1.ConfigProto()
-    config.gpu_options.allow_growth = True
-    session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-    tensorflow_backend.set_session(session)
-    return session
-def run_ensembling(dir_models, out):
-    ls_models = os.listdir(dir_models)
-    weights=[]
-    for model_name in ls_models:
-        model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
-        weights.append(model.get_weights())
-    new_weights = list()
-    for weights_list_tuple in zip(*weights):
-        new_weights.append(
-            [np.array(weights_).mean(axis=0)\
-                for weights_ in zip(*weights_list_tuple)])
-    new_weights = [np.array(x) for x in new_weights]
+from .models import (
+    PatchEncoder,
+    Patches,
+)
+def run_ensembling(model_dirs, out_dir):
+    all_weights = []
+    for model_dir in model_dirs:
+        assert os.path.isdir(model_dir), model_dir
+        model = load_model(model_dir, compile=False,
+                           custom_objects=dict(PatchEncoder=PatchEncoder,
+                                               Patches=Patches))
+        all_weights.append(model.get_weights())
+    new_weights = []
+    for layer_weights in zip(*all_weights):
+        layer_weights = np.array([np.array(weights).mean(axis=0)
+                                  for weights in zip(*layer_weights)])
+        new_weights.append(layer_weights)
+    #model = tf.keras.models.clone_model(model)
     model.set_weights(new_weights)
-    model.save(out)
-    os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
+    model.save(out_dir)
+    os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/")
 @click.command()
 @click.option(
-    "--dir_models",
-    "-dm",
-    help="directory of models",
+    "--in",
+    "-i",
+    help="input directory of checkpoint models to be read",
+    multiple=True,
+    required=True,
     type=click.Path(exists=True, file_okay=False),
 )
 @click.option(
     "--out",
     "-o",
     help="output directory where ensembled model will be written.",
+    required=True,
     type=click.Path(exists=False, file_okay=False),
 )
-def main(dir_models, out):
-    run_ensembling(dir_models, out)
+def ensemble_cli(in_, out):
+    """
+    mix multiple model weights
+    Load a sequence of models and mix them into a single ensemble model
+    by averaging their weights. Write the resulting model.
+    """
+    run_ensembling(in_, out)
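
With this change, the same averaging can also be reused programmatically, as the classification branch of `train.py` now does. A hedged usage sketch, assuming the package layout implied by the absolute imports replaced in this commit (`eynollah.training.*`); the checkpoint paths are placeholders:

    # hypothetical driver script; the paths are placeholders, not from the commit
    from eynollah.training.weights_ensembling import run_ensembling

    checkpoints = [
        "/path/to/run/model_03",
        "/path/to/run/model_05",
        "/path/to/run/model_07",
    ]
    # loads each SavedModel directory, averages the weights layer by layer,
    # and writes a single ensembled model to the output directory
    run_ensembling(checkpoints, "/path/to/run/model_ens_avg")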