Mirror of https://github.com/qurator-spk/eynollah.git (synced 2026-02-20 16:32:03 +01:00)
training follow-up:
- use relative imports
- use tf.keras everywhere (and ensure v2)
- `weights_ensembling`:
* use `Patches` and `PatchEncoder` from .models
* drop TF1 stuff
* make function / CLI more flexible (expect list of
checkpoint dirs instead of single top-level directory)
- train for `classification`: delegate to `weights_ensembling.run_ensembling`
This commit is contained in:
parent 27f43c175f
commit bd282a594d
6 changed files with 112 additions and 183 deletions
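
The `weights_ensembling` rework and the `classification` delegation both boil down to plain checkpoint averaging: load every usable checkpoint, take the element-wise mean of the weight tensors layer by layer, and write the result back as a single model. A minimal sketch of that idea, assuming Keras SavedModel checkpoint directories (the helper name and paths here are illustrative, not the repository's exact API):

    import os
    import numpy as np
    import tensorflow as tf

    def average_checkpoints(checkpoint_dirs, out_dir):
        # illustrative sketch of weight averaging over several saved Keras models
        all_weights = []
        model = None
        for ckpt in checkpoint_dirs:
            assert os.path.isdir(ckpt), ckpt
            # eynollah additionally passes custom_objects for its Patches/PatchEncoder layers
            model = tf.keras.models.load_model(ckpt, compile=False)
            all_weights.append(model.get_weights())

        # element-wise mean across checkpoints, one weight tensor at a time
        averaged = [np.mean(np.stack(tensors, axis=0), axis=0)
                    for tensors in zip(*all_weights)]

        model.set_weights(averaged)  # reuse the last loaded model as the container
        model.save(out_dir)

In the commit itself this logic lives in `run_ensembling` (see the weights_ensembling hunk below), which the classification branch of `train.run` now calls with the list of checkpoints that pass the F1 threshold.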
@@ -1,6 +1,9 @@
 """
 Load libraries with possible race conditions once. This must be imported as the first module of eynollah.
 """
+import os
+os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
+
 from ocrd_utils import tf_disable_interactive_logs
 from torch import *
 tf_disable_interactive_logs()

@@ -9,7 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli
 from .inference import main as inference_cli
 from .train import ex
 from .extract_line_gt import linegt_cli
-from .weights_ensembling import main as ensemble_cli
+from .weights_ensembling import ensemble_cli
 
 @click.command(context_settings=dict(
     ignore_unknown_options=True,
@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
 import cv2
 import numpy as np
 
-from eynollah.training.gt_gen_utils import (
+from .gt_gen_utils import (
     filter_contours_area_of_image,
     find_format_of_given_filename_in_dir,
     find_new_features_of_contours,

@@ -26,6 +26,9 @@ from eynollah.training.gt_gen_utils import (
 
 @click.group()
 def main():
+    """
+    extract GT data suitable for model training for various tasks
+    """
     pass
 
 @main.command()

@@ -74,6 +77,9 @@ def main():
 )
 
 def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, dir_out_images):
+    """
+    extract PAGE-XML GT data suitable for model training for segmentation tasks
+    """
     if config:
         with open(config) as f:
             config_params = json.load(f)

@@ -110,6 +116,9 @@ def pagexml2label(dir_xml,dir_out,type_output,config, printspace, dir_images, di
     type=click.Path(exists=True, dir_okay=False),
 )
 def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
+    """
+    extract image GT data suitable for model training for image enhancement tasks
+    """
     ls_imgs = os.listdir(dir_imgs)
     with open(scales) as f:
         scale_dict = json.load(f)

@@ -175,6 +184,9 @@ def image_enhancement(dir_imgs, dir_out_images, dir_out_labels, scales):
 )
 
 def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, input_height, input_width, min_area_size, min_area_early):
+    """
+    extract PAGE-XML GT data suitable for model training for reading-order task
+    """
     xml_files_ind = os.listdir(dir_xml)
     xml_files_ind = [ind_xml for ind_xml in xml_files_ind if ind_xml.endswith('.xml')]
     input_height = int(input_height)
@@ -33,9 +33,9 @@ from .metrics import (
     soft_dice_loss,
     weighted_categorical_crossentropy,
 )
+from .utils import scale_padd_image_for_ocr
+from ..utils.utils_ocr import decode_batch_predictions
 
-from .utils import (scale_padd_image_for_ocr)
-from eynollah.utils.utils_ocr import (decode_batch_predictions)
 
 with warnings.catch_warnings():
     warnings.simplefilter("ignore")
@@ -3,32 +3,8 @@ import sys
 import json
 
 import requests
-import click
 
-from eynollah.training.metrics import (
-    soft_dice_loss,
-    weighted_categorical_crossentropy
-)
-from eynollah.training.models import (
-    PatchEncoder,
-    Patches,
-    machine_based_reading_order_model,
-    resnet50_classifier,
-    resnet50_unet,
-    vit_resnet50_unet,
-    vit_resnet50_unet_transformer_before_cnn,
-    cnn_rnn_ocr_model,
-    RESNET50_WEIGHTS_PATH,
-    RESNET50_WEIGHTS_URL
-)
-from eynollah.training.utils import (
-    data_gen,
-    generate_arrays_from_folder_reading_order,
-    get_one_hot,
-    preprocess_imgs,
-)
-
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
 import tensorflow as tf
 from tensorflow.keras.optimizers import SGD, Adam

@@ -43,6 +19,31 @@ from sacred.config import create_captured_function
 import numpy as np
 import cv2
+
+from .metrics import (
+    soft_dice_loss,
+    weighted_categorical_crossentropy
+)
+from .models import (
+    PatchEncoder,
+    Patches,
+    machine_based_reading_order_model,
+    resnet50_classifier,
+    resnet50_unet,
+    vit_resnet50_unet,
+    vit_resnet50_unet_transformer_before_cnn,
+    cnn_rnn_ocr_model,
+    RESNET50_WEIGHTS_PATH,
+    RESNET50_WEIGHTS_URL
+)
+from .utils import (
+    data_gen,
+    generate_arrays_from_folder_reading_order,
+    get_one_hot,
+    preprocess_imgs,
+)
+from .weights_ensembling import run_ensembling
+
 
 class SaveWeightsAfterSteps(ModelCheckpoint):
     def __init__(self, save_interval, save_path, _config, **kwargs):
         if save_interval:
@@ -65,9 +66,7 @@ class SaveWeightsAfterSteps(ModelCheckpoint):
             super()._save_handler(filepath)
         with open(os.path.join(filepath, "config.json"), "w") as fp:
             json.dump(self._config, fp) # encode dict into JSON
-
-
 
 def configuration():
     try:
         for device in tf.config.list_physical_devices('GPU'):

@@ -272,6 +271,9 @@ def run(_config,
         skewing_amplitudes=None,
         max_len=None,
        ):
+    """
+    run configured experiment via sacred
+    """
 
     if pretraining and not os.path.isfile(RESNET50_WEIGHTS_PATH):
         _log.info("downloading RESNET50 pretrained weights to %s", RESNET50_WEIGHTS_PATH)
@@ -312,7 +314,7 @@ def run(_config,
 
     imgs_list = list(os.listdir(dir_img))
     segs_list = list(os.listdir(dir_seg))
 
     imgs_list_test = list(os.listdir(dir_img_val))
     segs_list_test = list(os.listdir(dir_seg_val))
 

@@ -380,7 +382,7 @@ def run(_config,
         num_patches_x = transformer_num_patches_xy[0]
         num_patches_y = transformer_num_patches_xy[1]
         num_patches = num_patches_x * num_patches_y
 
         if transformer_cnn_first:
             model_builder = vit_resnet50_unet
             multiple_of_32 = True

@@ -413,13 +415,13 @@ def run(_config,
     model_builder.config = _config
     model_builder.logger = _log
     model = model_builder(num_patches)
 
     assert model is not None
     #if you want to see the model structure just uncomment model summary.
     #model.summary()
 
     if task in ["segmentation", "binarization"]:
         if is_loss_soft_dice:
             loss = soft_dice_loss
         elif weighted_loss:
             loss = weighted_categorical_crossentropy(weights)

@@ -434,7 +436,7 @@ def run(_config,
                                ignore_class=0,
                                sparse_y_true=False,
                                sparse_y_pred=False)])
 
         # generating train and evaluation data
         gen_kwargs = dict(batch_size=n_batch,
                           input_height=input_height,

@@ -447,7 +449,7 @@ def run(_config,
         ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
         ##score_best=[]
         ##score_best.append(0)
 
         callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                      SaveWeightsAfterSteps(0, dir_output, _config)]
         if save_interval:

@@ -471,7 +473,7 @@ def run(_config,
         #os.system('rm -rf '+dir_eval_flowing)
 
         #model.save(dir_output+'/'+'model'+'.h5')
 
     elif task=="cnn-rnn-ocr":
 
         dir_img, dir_lab = get_dirs_or_files(dir_train)

@@ -480,7 +482,7 @@ def run(_config,
         labs_list = list(os.listdir(dir_lab))
         imgs_list_val = list(os.listdir(dir_img_val))
         labs_list_val = list(os.listdir(dir_lab_val))
 
         with open(characters_txt_file, 'r') as char_txt_f:
             characters = json.load(char_txt_f)
         padding_token = len(characters) + 5

@@ -533,7 +535,7 @@ def run(_config,
         #tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps, alpha)
         opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
         model.compile(optimizer=opt) # rs: loss seems to be (ctc_batch_cost) in last layer
 
         callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                      SaveWeightsAfterSteps(0, dir_output, _config)]
         if save_interval:

@@ -544,7 +546,7 @@ def run(_config,
                             epochs=n_epochs,
                             callbacks=callbacks,
                             initial_epoch=index_start)
 
     elif task=='classification':
         if continue_training:
             model = load_model(dir_of_start_model, compile=False)

@@ -573,7 +575,7 @@ def run(_config,
                                      monitor='val_f1',
                                      #save_best_only=True, # we need all for ensembling
                                      mode='max')]
 
         history = model.fit(trainXY,
                             #class_weight=weights)
                             validation_data=testXY,
@@ -586,28 +588,12 @@ def run(_config,
                                                              f1_threshold_classification)
         if len(usable_checkpoints) >= 1:
             _log.info("averaging over usable checkpoints: %s", str(usable_checkpoints))
-            all_weights = []
-            for epoch in usable_checkpoints:
-                cp_path = os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
-                assert os.path.isdir(cp_path), cp_path
-                model = load_model(cp_path, compile=False)
-                all_weights.append(model.get_weights())
-
-            new_weights = []
-            for layer_weights in zip(*all_weights):
-                layer_weights = np.array([np.array(weights).mean(axis=0)
-                                          for weights in zip(*layer_weights)])
-                new_weights.append(layer_weights)
-
-            #model = tf.keras.models.clone_model(model)
-            model.set_weights(new_weights)
-
-            cp_path = os.path.join(dir_output, 'model_ens_avg')
-            model.save(cp_path)
-            with open(os.path.join(cp_path, "config.json"), "w") as fp:
-                json.dump(_config, fp) # encode dict into JSON
-            _log.info("ensemble model saved under '%s'", cp_path)
+            usable_checkpoints = [os.path.join(dir_output, 'model_{epoch:02d}'.format(epoch=epoch + 1))
+                                  for epoch in usable_checkpoints]
+            ens_path = os.path.join(dir_output, 'model_ens_avg')
+            run_ensembling(usable_checkpoints, ens_path)
+            _log.info("ensemble model saved under '%s'", ens_path)
 
     elif task=='reading_order':
         if continue_training:
             model = load_model(dir_of_start_model, compile=False)
@@ -618,10 +604,10 @@ def run(_config,
                        input_width,
                        weight_decay,
                        pretraining)
 
         dir_flow_train_imgs = os.path.join(dir_train, 'images')
         dir_flow_train_labels = os.path.join(dir_train, 'labels')
 
         classes = os.listdir(dir_flow_train_labels)
         if augmentation:
             num_rows = len(classes)*(len(thetha) + 1)

@@ -634,7 +620,7 @@ def run(_config,
                       #optimizer=SGD(learning_rate=0.01, momentum=0.9),
                       optimizer=Adam(learning_rate=0.0001), # rs: why not learning_rate?
                       metrics=['accuracy'])
 
         callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
                      SaveWeightsAfterSteps(0, dir_output, _config)]
         if save_interval:

@@ -657,5 +643,3 @@ def run(_config,
        model_dir = os.path.join(dir_out,'model_best')
        model.save(model_dir)
    '''
-
-
@@ -1,136 +1,66 @@
-import sys
-from glob import glob
-from os import environ, devnull
-from os.path import join
-from warnings import catch_warnings, simplefilter
 import os
+from warnings import catch_warnings, simplefilter
 
+import click
 import numpy as np
-from PIL import Image
-import cv2
-environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-stderr = sys.stderr
-sys.stderr = open(devnull, 'w')
+
+os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+from ocrd_utils import tf_disable_interactive_logs
+tf_disable_interactive_logs()
 import tensorflow as tf
 from tensorflow.keras.models import load_model
-from tensorflow.python.keras import backend as tensorflow_backend
-sys.stderr = stderr
-from tensorflow.keras import layers
-import tensorflow.keras.losses
-from tensorflow.keras.layers import *
-import click
-import logging
-
-
-class Patches(layers.Layer):
-    def __init__(self, patch_size_x, patch_size_y):
-        super(Patches, self).__init__()
-        self.patch_size_x = patch_size_x
-        self.patch_size_y = patch_size_y
-
-    def call(self, images):
-        #print(tf.shape(images)[1],'images')
-        #print(self.patch_size,'self.patch_size')
-        batch_size = tf.shape(images)[0]
-        patches = tf.image.extract_patches(
-            images=images,
-            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
-            strides=[1, self.patch_size_y, self.patch_size_x, 1],
-            rates=[1, 1, 1, 1],
-            padding="VALID",
-        )
-        #patch_dims = patches.shape[-1]
-        patch_dims = tf.shape(patches)[-1]
-        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
-        return patches
-    def get_config(self):
-
-        config = super().get_config().copy()
-        config.update({
-            'patch_size_x': self.patch_size_x,
-            'patch_size_y': self.patch_size_y,
-        })
-        return config
-
-
-
-class PatchEncoder(layers.Layer):
-    def __init__(self, **kwargs):
-        super(PatchEncoder, self).__init__()
-        self.num_patches = num_patches
-        self.projection = layers.Dense(units=projection_dim)
-        self.position_embedding = layers.Embedding(
-            input_dim=num_patches, output_dim=projection_dim
-        )
-
-    def call(self, patch):
-        positions = tf.range(start=0, limit=self.num_patches, delta=1)
-        encoded = self.projection(patch) + self.position_embedding(positions)
-        return encoded
-    def get_config(self):
-
-        config = super().get_config().copy()
-        config.update({
-            'num_patches': self.num_patches,
-            'projection': self.projection,
-            'position_embedding': self.position_embedding,
-        })
-        return config
-
-def start_new_session():
-    ###config = tf.compat.v1.ConfigProto()
-    ###config.gpu_options.allow_growth = True
-
-    ###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-    ###tensorflow_backend.set_session(self.session)
-
-    config = tf.compat.v1.ConfigProto()
-    config.gpu_options.allow_growth = True
-
-    session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-    tensorflow_backend.set_session(session)
-    return session
-
-def run_ensembling(dir_models, out):
-    ls_models = os.listdir(dir_models)
-
-    weights=[]
-
-    for model_name in ls_models:
-        model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
-        weights.append(model.get_weights())
-
-    new_weights = list()
-
-    for weights_list_tuple in zip(*weights):
-        new_weights.append(
-            [np.array(weights_).mean(axis=0)\
-             for weights_ in zip(*weights_list_tuple)])
-
-    new_weights = [np.array(x) for x in new_weights]
-
+
+from .models import (
+    PatchEncoder,
+    Patches,
+)
+
+
+def run_ensembling(model_dirs, out_dir):
+    all_weights = []
+
+    for model_dir in model_dirs:
+        assert os.path.isdir(model_dir), model_dir
+        model = load_model(model_dir, compile=False,
+                           custom_objects=dict(PatchEncoder=PatchEncoder,
+                                               Patches=Patches))
+        all_weights.append(model.get_weights())
+
+    new_weights = []
+    for layer_weights in zip(*all_weights):
+        layer_weights = np.array([np.array(weights).mean(axis=0)
+                                  for weights in zip(*layer_weights)])
+        new_weights.append(layer_weights)
+
+    #model = tf.keras.models.clone_model(model)
     model.set_weights(new_weights)
-    model.save(out)
-    os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
+    model.save(out_dir)
+    os.system('cp ' + os.path.join(model_dirs[0], "config.json ") + out_dir + "/")
 
 @click.command()
 @click.option(
-    "--dir_models",
-    "-dm",
-    help="directory of models",
+    "--in",
+    "-i",
+    help="input directory of checkpoint models to be read",
+    multiple=True,
+    required=True,
     type=click.Path(exists=True, file_okay=False),
 )
 @click.option(
     "--out",
     "-o",
     help="output directory where ensembled model will be written.",
+    required=True,
     type=click.Path(exists=False, file_okay=False),
 )
-
-def main(dir_models, out):
-    run_ensembling(dir_models, out)
+def ensemble_cli(in_, out):
+    """
+    mix multiple model weights
+
+    Load a sequence of models and mix them into a single ensemble model
+    by averaging their weights. Write the resulting model.
+    """
+    run_ensembling(in_, out)
-
-
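
Because the `--in` option is now declared with `multiple=True`, the ensembling CLI expects each checkpoint directory as its own `-i` argument instead of a single top-level models directory. A usage sketch via click's test runner (the module path is assumed from the relative imports above; directory names are purely illustrative):

    from click.testing import CliRunner

    from eynollah.training.weights_ensembling import ensemble_cli

    # average three hypothetical classification checkpoints into one model
    result = CliRunner().invoke(ensemble_cli, [
        "-i", "output/model_01",
        "-i", "output/model_02",
        "-i", "output/model_03",
        "-o", "output/model_ens_avg",
    ])
    print(result.exit_code, result.output)

The same behaviour is available programmatically by calling `run_ensembling(checkpoint_dirs, out_dir)`, which is exactly what the classification training path now does after filtering checkpoints by `val_f1`.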