Merge remote-tracking branch 'origin/adding-cnn-rnn-training-script' into 2026-01-29-training

# Conflicts:
#	src/eynollah/training/inference.py
kba 2026-01-29 17:32:08 +01:00
commit f13560726e
7 changed files with 290 additions and 40 deletions

View file

@@ -19,7 +19,7 @@ from eynollah.model_zoo import EynollahModelZoo
 tf_disable_interactive_logs()
 import tensorflow as tf
 from tensorflow.python.keras import backend as tensorflow_backend
+from pathlib import Path
 from .utils import is_image_filename

 def resize_image(img_in, input_height, input_width):
@@ -347,7 +347,7 @@ class SbbBinarizer:
         self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in)
         for i, image_path in enumerate(ls_imgs):
             self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_path)
-            image_stem = image_path.split('.')[0]
+            image_stem = Path(image_path).stem
             image = cv2.imread(os.path.join(dir_in,image_path) )
             img_last = 0
             model_file, model = self.models
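
A quick aside on the image_stem change above: split('.') truncates filenames that contain more than one dot, while Path(...).stem only strips the final extension. A standalone illustration (the filename is made up):

    from pathlib import Path

    name = "page_0001.bin.tif"    # hypothetical filename with a dot inside the stem
    print(name.split('.')[0])     # page_0001     -- silently loses part of the stem
    print(Path(name).stem)        # page_0001.bin -- only the extension is removed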

View file

@@ -9,6 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli
 from .inference import main as inference_cli
 from .train import ex
 from .extract_line_gt import linegt_cli
+from .weights_ensembling import main as ensemble_cli

 @click.command(context_settings=dict(
     ignore_unknown_options=True,
@@ -26,3 +27,4 @@ main.add_command(generate_gt_cli, 'generate-gt')
 main.add_command(inference_cli, 'inference')
 main.add_command(train_cli, 'train')
 main.add_command(linegt_cli, 'export_textline_images_and_text')
+main.add_command(ensemble_cli, 'ensembling')
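
For context, main here is the click group assembled in the training package's __main__, so after this registration the new subcommand appears alongside the existing ones. A minimal smoke test (a sketch; the import path is an assumption based on the conflict path src/eynollah/training/):

    from click.testing import CliRunner
    from eynollah.training.__main__ import main

    # The group help should now list the ensembling subcommand.
    result = CliRunner().invoke(main, ['--help'])
    assert 'ensembling' in result.output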

View file

@@ -474,7 +474,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
     img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
     img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
-    co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
+    co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
     added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)

View file

@@ -15,7 +15,7 @@ with warnings.catch_warnings():
     warnings.simplefilter("ignore")

-def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, img):
+def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, img):
     alpha = 0.5
     blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255
@@ -28,6 +28,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
     col_sep = (255, 0, 0)
     col_marginal = (106, 90, 205)
     col_table = (0, 90, 205)
+    col_map = (90, 90, 205)

     if len(co_image)>0:
         cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour
@@ -52,6 +53,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
     if len(co_table)>0:
         cv2.drawContours(blank_image, co_table, -1, col_table, thickness=cv2.FILLED) # Fill the contour
+    if len(co_map)>0:
+        cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour

     img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB)
@@ -231,7 +235,12 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y
         con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size )
         try:
-            co_text_eroded.append(con_eroded[0])
+            if len(con_eroded)>1:
+                # erosion can split a region: keep only the largest resulting contour
+                cnt_size = np.array([cv2.contourArea(con_eroded[j]) for j in range(len(con_eroded))])
+                cnt = con_eroded[np.argmax(cnt_size)]
+                co_text_eroded.append(cnt)
+            else:
+                co_text_eroded.append(con_eroded[0])
         except:
             co_text_eroded.append(con)
@@ -377,6 +386,7 @@ def get_layout_contours_for_visualization(xml_file):
     co_sep=[]
     co_img=[]
     co_table=[]
+    co_map=[]
     co_noise=[]

     types_text = []
@@ -593,6 +603,31 @@ def get_layout_contours_for_visualization(xml_file):
                 elif vv.tag!=link+'Point' and sumi>=1:
                     break
             co_table.append(np.array(c_t_in))
+
+    if tag.endswith('}MapRegion') or tag.endswith('}mapregion'):
+        for nn in root1.iter(tag):
+            c_t_in=[]
+            sumi=0
+            for vv in nn.iter():
+                # check the format of coords
+                if vv.tag==link+'Coords':
+                    coords=bool(vv.attrib)
+                    if coords:
+                        p_h=vv.attrib['points'].split(' ')
+                        c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
+                        break
+                    else:
+                        pass
+                if vv.tag==link+'Point':
+                    c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
+                    sumi+=1
+                elif vv.tag!=link+'Point' and sumi>=1:
+                    break
+            co_map.append(np.array(c_t_in))

     if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
@@ -619,7 +654,7 @@ def get_layout_contours_for_visualization(xml_file):
                 elif vv.tag!=link+'Point' and sumi>=1:
                     break
             co_noise.append(np.array(c_t_in))
-    return co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len
+    return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len

 def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images):
     """
@@ -698,12 +733,15 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
     _, thresh = cv2.threshold(imgray, 0, 255, 0)
     contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
     cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
-    cnt = contours[np.argmax(cnt_size)]
-    x, y, w, h = cv2.boundingRect(cnt)
+    try:
+        cnt = contours[np.argmax(cnt_size)]
+        x, y, w, h = cv2.boundingRect(cnt)
+    except:
+        # no contour found: fall back to the full page extent
+        x, y, w, h = 0, 0, x_len, y_len

     bb_xywh = [x, y, w, h]
@@ -835,7 +873,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
     types_graphic_label = list(types_graphic_dict.values())

-    labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0)]
+    labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255)]

     region_tags=np.unique([x for x in alltags if x.endswith('Region')])
@@ -846,6 +884,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
     co_sep=[]
     co_img=[]
     co_table=[]
+    co_map=[]
     co_noise=[]

     for tag in region_tags:
@@ -1056,6 +1095,32 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
                 elif vv.tag!=link+'Point' and sumi>=1:
                     break
             co_table.append(np.array(c_t_in))
+
+if 'mapregion' in keys:
+    if tag.endswith('}MapRegion') or tag.endswith('}mapregion'):
+        for nn in root1.iter(tag):
+            c_t_in=[]
+            sumi=0
+            for vv in nn.iter():
+                # check the format of coords
+                if vv.tag==link+'Coords':
+                    coords=bool(vv.attrib)
+                    if coords:
+                        p_h=vv.attrib['points'].split(' ')
+                        c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
+                        break
+                    else:
+                        pass
+                if vv.tag==link+'Point':
+                    c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
+                    sumi+=1
+                elif vv.tag!=link+'Point' and sumi>=1:
+                    break
+            co_map.append(np.array(c_t_in))

 if 'noiseregion' in keys:
     if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
@@ -1129,6 +1194,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
     erosion_rate = 0#2
     dilation_rate = 3#4
     co_table, img_boundary = update_region_contours(co_table, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
+if "mapregion" in elements_with_artificial_class:
+    erosion_rate = 0#2
+    dilation_rate = 3#4
+    co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
@@ -1154,6 +1223,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
     img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']])
 if 'tableregion' in keys:
     img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']])
+if 'mapregion' in keys:
+    img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']])
 if 'noiseregion' in keys:
     img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']])
@@ -1214,6 +1285,9 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
 if 'tableregion' in keys:
     color_label = config_params['tableregion']
     img_poly=cv2.fillPoly(img, pts =co_table, color=(color_label,color_label,color_label))
+if 'mapregion' in keys:
+    color_label = config_params['mapregion']
+    img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label))
 if 'noiseregion' in keys:
     color_label = config_params['noiseregion']
     img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label))
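
Both fill styles above follow the same convention: the integer stored under the 'mapregion' key in config_params doubles as an index into labels_rgb_color (RGB output) and as the gray value itself (the "2d" label output). A self-contained sketch with made-up label ids:

    import numpy as np
    import cv2

    labels_rgb_color = [(0, 0, 0), (255, 0, 0), (0, 90, 205), (90, 90, 205)]
    config_params = {"mapregion": 3}   # assumed label id for the new class

    img = np.zeros((100, 100, 3), dtype=np.uint8)
    co_map = [np.array([[10, 10], [90, 10], [90, 90], [10, 90]], dtype=np.int32)]

    # RGB ground truth: the id indexes the color table.
    img_rgb = cv2.fillPoly(img.copy(), pts=co_map, color=labels_rgb_color[config_params["mapregion"]])

    # "2d" ground truth: the id itself becomes the pixel value in every channel.
    label = config_params["mapregion"]
    img_2d = cv2.fillPoly(img.copy(), pts=co_map, color=(label, label, label))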

View file

@@ -26,6 +26,9 @@ from .models import (
     Patches
 )
+
+from .utils import scale_padd_image_for_ocr
+from eynollah.utils.utils_ocr import decode_batch_predictions

 with warnings.catch_warnings():
     warnings.simplefilter("ignore")
@@ -35,8 +38,7 @@ Tool to load model and predict for given image.
 """

 class sbb_predict:
-    def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
+    def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
         self.image=image
         self.dir_in=dir_in
         self.patches=patches
@@ -48,6 +50,7 @@ class sbb_predict:
         self.config_params_model=config_params_model
         self.xml_file = xml_file
         self.out = out
+        self.cpu = cpu
         if min_area:
             self.min_area = float(min_area)
         else:
@@ -159,21 +162,19 @@ class sbb_predict:
         return mIoU

     def start_new_session_and_model(self):
-        config = tf.compat.v1.ConfigProto()
-        config.gpu_options.allow_growth = True
-        session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
-        tensorflow_backend.set_session(session)
-        #tensorflow.keras.layers.custom_layer = PatchEncoder
-        #tensorflow.keras.layers.custom_layer = Patches
-        self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
-        #config = tf.ConfigProto()
-        #config.gpu_options.allow_growth=True
-        #self.session = tf.InteractiveSession()
-        #keras.losses.custom_loss = self.weighted_categorical_crossentropy
-        #self.model = load_model(self.model_dir , compile=False)
+        if self.task == "cnn-rnn-ocr":
+            if self.cpu:
+                os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+            self.model = load_model(self.model_dir)
+            self.model = tf.keras.models.Model(
+                self.model.get_layer(name = "image").input,
+                self.model.get_layer(name = "dense2").output)
+        else:
+            config = tf.compat.v1.ConfigProto()
+            config.gpu_options.allow_growth = True
+            session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
+            tensorflow_backend.set_session(session)
+            self.model = load_model(self.model_dir, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches})

         ##if self.weights_dir!=None:
@@ -250,6 +251,30 @@ class sbb_predict:
             index_class = np.argmax(label_p_pred[0])
             print("Predicted Class: {}".format(classes_names[str(int(index_class))]))

+        elif self.task == "cnn-rnn-ocr":
+            img=cv2.imread(image_dir)
+            img = scale_padd_image_for_ocr(img, self.config_params_model['input_height'], self.config_params_model['input_width'])
+            img = img / 255.
+
+            with open(os.path.join(self.model_dir, "characters_org.txt"), 'r') as char_txt_f:
+                characters = json.load(char_txt_f)
+
+            # Mapping characters to integers.
+            char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
+
+            # Mapping integers back to original characters.
+            num_to_char = StringLookup(
+                vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
+            )
+
+            preds = self.model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]), verbose=0)
+            pred_texts = decode_batch_predictions(preds, num_to_char)
+            pred_texts = pred_texts[0].replace("[UNK]", "")
+            return pred_texts
+
         elif self.task == 'reading_order':
             img_height = self.config_params_model['input_height']
             img_width = self.config_params_model['input_width']
@@ -580,6 +605,8 @@ class sbb_predict:
             elif self.task == 'enhancement':
                 if self.save:
                     cv2.imwrite(self.save,res)
+            elif self.task == "cnn-rnn-ocr":
+                print(f"Detected text: {res}")
             else:
                 img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
                 if self.save:
@@ -587,9 +614,9 @@ class sbb_predict:
                 if self.save_layout:
                     cv2.imwrite(self.save_layout, only_layout)

             if self.ground_truth:
                 gt_img=cv2.imread(self.ground_truth)
                 self.IoU(gt_img[:,:,0],res[:,:,0])
         else:
             ls_images = os.listdir(self.dir_in)
@@ -603,6 +630,8 @@ class sbb_predict:
                 elif self.task == 'enhancement':
                     self.save = os.path.join(self.out, f_name+'.png')
                     cv2.imwrite(self.save,res)
+                elif self.task == "cnn-rnn-ocr":
+                    print(f"Detected text for file name {f_name} is: {res}")
                 else:
                     img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
                     self.save = os.path.join(self.out, f_name+'_overlayed.png')
@@ -610,9 +639,9 @@ class sbb_predict:
                     self.save_layout = os.path.join(self.out, f_name+'_layout.png')
                     cv2.imwrite(self.save_layout, only_layout)

                 if self.ground_truth:
                     gt_img=cv2.imread(self.ground_truth)
                     self.IoU(gt_img[:,:,0],res[:,:,0])
@@ -668,24 +697,29 @@ class sbb_predict:
     "-xml",
     help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
 )
+@click.option(
+    "--cpu",
+    "-cpu",
+    help="For OCR, the default device is the GPU. If this flag is set, inference is performed on the CPU instead.",
+    is_flag=True,
+)
 @click.option(
     "--min_area",
     "-min",
     help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
 )
-def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
+def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
     assert image or dir_in, "Either a single image -i or a dir_in -di is required"
     with open(os.path.join(model,'config.json')) as f:
         config_params_model = json.load(f)
     task = config_params_model['task']
-    if task != 'classification' and task != 'reading_order':
+    if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
         if image and not save:
             print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
             sys.exit(1)
         if dir_in and not out:
             print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
             sys.exit(1)
-    x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
+    x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area)
     x.run()
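
decode_batch_predictions is imported from eynollah.utils.utils_ocr and not shown in this diff; as a rough sketch, a greedy CTC decoder of this kind typically looks like the following (patterned on the standard Keras OCR example, so the real helper may differ in detail):

    import numpy as np
    import tensorflow as tf

    def decode_batch_predictions_sketch(pred, num_to_char, max_text_len=128):
        # pred: (batch, timesteps, vocab) softmax output of the recognition head.
        input_len = np.ones(pred.shape[0]) * pred.shape[1]
        # Greedy CTC decoding collapses repeated symbols and drops the blank token.
        results = tf.keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_text_len]
        return [tf.strings.reduce_join(num_to_char(r)).numpy().decode("utf-8") for r in results]

The replace("[UNK]", "") step in the diff then strips whatever the inverted StringLookup emits for blank and out-of-vocabulary ids.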

View file

@@ -1,7 +1,7 @@
 import os
 import math
 import random
+from pathlib import Path

 import cv2
 import numpy as np
 import seaborn as sns
@@ -32,6 +32,9 @@ def scale_padd_image_for_ocr(img, height, width):
     else:
         width_new = width
+
+    if width_new <= 0:
+        width_new = width

     img_res= resize_image (img, height, width_new)
     img_fin = np.ones((height, width, 3))*255
@@ -1335,7 +1338,8 @@ def data_gen_ocr(
     # TODO: Why while True + yield, why not return a list?
     while True:
         for i in ls_files_images:
-            f_name = i.split('.')[0]
+            f_name = Path(i).stem
             txt_inp = open(os.path.join(dir_train, "labels/"+f_name+'.txt'),'r').read().split('\n')[0]
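
The new width_new guard in scale_padd_image_for_ocr above protects against a degenerate computed width. A self-contained sketch of the scale-and-pad idea (my reconstruction, not the exact function):

    import numpy as np
    import cv2

    def scale_pad_sketch(img, height, width):
        # Scale to the target height while preserving the aspect ratio.
        width_new = int(img.shape[1] * height / img.shape[0])
        if width_new > width:   # wider than the model input: clamp
            width_new = width
        if width_new <= 0:      # the guard added above: degenerate inputs
            width_new = width
        img_res = cv2.resize(img, (width_new, height))
        # Pad with white on the right up to the fixed model input width.
        img_fin = np.ones((height, width, 3), dtype=np.uint8) * 255
        img_fin[:, :width_new, :] = img_res
        return img_fin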

View file

@@ -0,0 +1,136 @@
+import sys
+from glob import glob
+from os import environ, devnull
+from os.path import join
+from warnings import catch_warnings, simplefilter
+import os
+import shutil
+
+import numpy as np
+from PIL import Image
+import cv2
+
+environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+stderr = sys.stderr
+sys.stderr = open(devnull, 'w')
+import tensorflow as tf
+from tensorflow.keras.models import load_model
+from tensorflow.python.keras import backend as tensorflow_backend
+sys.stderr = stderr
+from tensorflow.keras import layers
+import tensorflow.keras.losses
+from tensorflow.keras.layers import *
+
+import click
+import logging
+
+
+class Patches(layers.Layer):
+    def __init__(self, patch_size_x, patch_size_y, **kwargs):
+        super(Patches, self).__init__(**kwargs)
+        self.patch_size_x = patch_size_x
+        self.patch_size_y = patch_size_y
+
+    def call(self, images):
+        # Cut each image into non-overlapping patches and flatten them.
+        batch_size = tf.shape(images)[0]
+        patches = tf.image.extract_patches(
+            images=images,
+            sizes=[1, self.patch_size_y, self.patch_size_x, 1],
+            strides=[1, self.patch_size_y, self.patch_size_x, 1],
+            rates=[1, 1, 1, 1],
+            padding="VALID",
+        )
+        patch_dims = tf.shape(patches)[-1]
+        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
+        return patches
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'patch_size_x': self.patch_size_x,
+            'patch_size_y': self.patch_size_y,
+        })
+        return config
+
+
+class PatchEncoder(layers.Layer):
+    def __init__(self, num_patches, projection_dim, **kwargs):
+        super(PatchEncoder, self).__init__(**kwargs)
+        self.num_patches = num_patches
+        self.projection_dim = projection_dim
+        self.projection = layers.Dense(units=projection_dim)
+        self.position_embedding = layers.Embedding(
+            input_dim=num_patches, output_dim=projection_dim
+        )
+
+    def call(self, patch):
+        # Project each patch and add a learned position embedding.
+        positions = tf.range(start=0, limit=self.num_patches, delta=1)
+        encoded = self.projection(patch) + self.position_embedding(positions)
+        return encoded
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'num_patches': self.num_patches,
+            'projection_dim': self.projection_dim,
+        })
+        return config
+
+
+def start_new_session():
+    config = tf.compat.v1.ConfigProto()
+    config.gpu_options.allow_growth = True
+    session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
+    tensorflow_backend.set_session(session)
+    return session
+
+
+def run_ensembling(dir_models, out):
+    ls_models = os.listdir(dir_models)
+
+    # Collect the weight tensors of every model in the directory.
+    weights = []
+    for model_name in ls_models:
+        model = load_model(os.path.join(dir_models, model_name), compile=False,
+                           custom_objects={'PatchEncoder': PatchEncoder, 'Patches': Patches})
+        weights.append(model.get_weights())
+
+    # Average each weight tensor element-wise across all models; this assumes
+    # every model shares exactly the same architecture.
+    new_weights = list()
+    for weights_list_tuple in zip(*weights):
+        new_weights.append(
+            [np.array(weights_).mean(axis=0)
+             for weights_ in zip(*weights_list_tuple)])
+    new_weights = [np.array(x) for x in new_weights]
+
+    # Reuse the last loaded model as the container for the averaged weights.
+    model.set_weights(new_weights)
+    model.save(out)
+    shutil.copy(os.path.join(dir_models, model_name, "config.json"), out)
+
+
+@click.command()
+@click.option(
+    "--dir_models",
+    "-dm",
+    help="directory of models",
+    type=click.Path(exists=True, file_okay=False),
+)
+@click.option(
+    "--out",
+    "-o",
+    help="output directory where the ensembled model will be written.",
+    type=click.Path(exists=False, file_okay=False),
+)
+def main(dir_models, out):
+    run_ensembling(dir_models, out)
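
run_ensembling averages corresponding weight tensors across checkpoints, so it is only meaningful for models that share an identical architecture (which is also why a single config.json is copied over for the result). Usage sketch (paths are placeholders):

    from eynollah.training.weights_ensembling import run_ensembling

    # dir_models contains one saved model per subdirectory, all with the
    # same architecture; the averaged model is written to the output path.
    run_ensembling("checkpoints/", "ensembled_model/")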