mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-01-31 06:36:58 +01:00
Merge remote-tracking branch 'origin/adding-cnn-rnn-training-script' into 2026-01-29-training
# Conflicts: # src/eynollah/training/inference.py
This commit is contained in:
commit
f13560726e
7 changed files with 290 additions and 40 deletions
|
|
@ -19,7 +19,7 @@ from eynollah.model_zoo import EynollahModelZoo
|
||||||
tf_disable_interactive_logs()
|
tf_disable_interactive_logs()
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.python.keras import backend as tensorflow_backend
|
from tensorflow.python.keras import backend as tensorflow_backend
|
||||||
|
from pathlib import Path
|
||||||
from .utils import is_image_filename
|
from .utils import is_image_filename
|
||||||
|
|
||||||
def resize_image(img_in, input_height, input_width):
|
def resize_image(img_in, input_height, input_width):
|
||||||
|
|
@ -347,7 +347,7 @@ class SbbBinarizer:
|
||||||
self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in)
|
self.logger.info("Found %d image files to binarize in %s", len(ls_imgs), dir_in)
|
||||||
for i, image_path in enumerate(ls_imgs):
|
for i, image_path in enumerate(ls_imgs):
|
||||||
self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_path)
|
self.logger.info('Binarizing [%3d/%d] %s', i + 1, len(ls_imgs), image_path)
|
||||||
image_stem = image_path.split('.')[0]
|
image_stem = Path(image_path).stem
|
||||||
image = cv2.imread(os.path.join(dir_in,image_path) )
|
image = cv2.imread(os.path.join(dir_in,image_path) )
|
||||||
img_last = 0
|
img_last = 0
|
||||||
model_file, model = self.models
|
model_file, model = self.models
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from .generate_gt_for_training import main as generate_gt_cli
|
||||||
from .inference import main as inference_cli
|
from .inference import main as inference_cli
|
||||||
from .train import ex
|
from .train import ex
|
||||||
from .extract_line_gt import linegt_cli
|
from .extract_line_gt import linegt_cli
|
||||||
|
from .weights_ensembling import main as ensemble_cli
|
||||||
|
|
||||||
@click.command(context_settings=dict(
|
@click.command(context_settings=dict(
|
||||||
ignore_unknown_options=True,
|
ignore_unknown_options=True,
|
||||||
|
|
@ -26,3 +27,4 @@ main.add_command(generate_gt_cli, 'generate-gt')
|
||||||
main.add_command(inference_cli, 'inference')
|
main.add_command(inference_cli, 'inference')
|
||||||
main.add_command(train_cli, 'train')
|
main.add_command(train_cli, 'train')
|
||||||
main.add_command(linegt_cli, 'export_textline_images_and_text')
|
main.add_command(linegt_cli, 'export_textline_images_and_text')
|
||||||
|
main.add_command(ensemble_cli, 'ensembling')
|
||||||
|
|
|
||||||
|
|
@ -474,7 +474,7 @@ def visualize_layout_segmentation(xml_file, dir_xml, dir_out, dir_imgs):
|
||||||
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
|
img_file_name_with_format = find_format_of_given_filename_in_dir(dir_imgs, f_name)
|
||||||
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
|
img = cv2.imread(os.path.join(dir_imgs, img_file_name_with_format))
|
||||||
|
|
||||||
co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
|
co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len = get_layout_contours_for_visualization(xml_file)
|
||||||
|
|
||||||
|
|
||||||
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
|
added_image = visualize_image_from_contours_layout(co_text['paragraph'], co_text['header']+co_text['heading'], co_text['drop-capital'], co_sep, co_img, co_text['marginalia'], co_table, img)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
|
|
||||||
|
|
||||||
def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, img):
|
def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_image, co_marginal, co_table, co_map, img):
|
||||||
alpha = 0.5
|
alpha = 0.5
|
||||||
|
|
||||||
blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255
|
blank_image = np.ones( (img.shape[:]), dtype=np.uint8) * 255
|
||||||
|
|
@ -28,6 +28,7 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
|
||||||
col_sep = (255, 0, 0)
|
col_sep = (255, 0, 0)
|
||||||
col_marginal = (106, 90, 205)
|
col_marginal = (106, 90, 205)
|
||||||
col_table = (0, 90, 205)
|
col_table = (0, 90, 205)
|
||||||
|
col_map = (90, 90, 205)
|
||||||
|
|
||||||
if len(co_image)>0:
|
if len(co_image)>0:
|
||||||
cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour
|
cv2.drawContours(blank_image, co_image, -1, col_image, thickness=cv2.FILLED) # Fill the contour
|
||||||
|
|
@ -53,6 +54,9 @@ def visualize_image_from_contours_layout(co_par, co_header, co_drop, co_sep, co_
|
||||||
if len(co_table)>0:
|
if len(co_table)>0:
|
||||||
cv2.drawContours(blank_image, co_table, -1, col_table, thickness=cv2.FILLED) # Fill the contour
|
cv2.drawContours(blank_image, co_table, -1, col_table, thickness=cv2.FILLED) # Fill the contour
|
||||||
|
|
||||||
|
if len(co_map)>0:
|
||||||
|
cv2.drawContours(blank_image, co_map, -1, col_map, thickness=cv2.FILLED) # Fill the contour
|
||||||
|
|
||||||
img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB)
|
img_final =cv2.cvtColor(blank_image, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
added_image = cv2.addWeighted(img,alpha,img_final,1- alpha,0)
|
added_image = cv2.addWeighted(img,alpha,img_final,1- alpha,0)
|
||||||
|
|
@ -231,7 +235,12 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y
|
||||||
con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size )
|
con_eroded = return_contours_of_interested_region(img_boundary_in,pixel, min_size )
|
||||||
|
|
||||||
try:
|
try:
|
||||||
co_text_eroded.append(con_eroded[0])
|
if len(con_eroded)>1:
|
||||||
|
cnt_size = np.array([cv2.contourArea(con_eroded[j]) for j in range(len(con_eroded))])
|
||||||
|
cnt = contours[np.argmax(cnt_size)]
|
||||||
|
co_text_eroded.append(cnt)
|
||||||
|
else:
|
||||||
|
co_text_eroded.append(con_eroded[0])
|
||||||
except:
|
except:
|
||||||
co_text_eroded.append(con)
|
co_text_eroded.append(con)
|
||||||
|
|
||||||
|
|
@ -377,6 +386,7 @@ def get_layout_contours_for_visualization(xml_file):
|
||||||
co_sep=[]
|
co_sep=[]
|
||||||
co_img=[]
|
co_img=[]
|
||||||
co_table=[]
|
co_table=[]
|
||||||
|
co_map=[]
|
||||||
co_noise=[]
|
co_noise=[]
|
||||||
|
|
||||||
types_text = []
|
types_text = []
|
||||||
|
|
@ -594,6 +604,31 @@ def get_layout_contours_for_visualization(xml_file):
|
||||||
break
|
break
|
||||||
co_table.append(np.array(c_t_in))
|
co_table.append(np.array(c_t_in))
|
||||||
|
|
||||||
|
if tag.endswith('}MapRegion') or tag.endswith('}mapregion'):
|
||||||
|
#print('sth')
|
||||||
|
for nn in root1.iter(tag):
|
||||||
|
c_t_in=[]
|
||||||
|
sumi=0
|
||||||
|
for vv in nn.iter():
|
||||||
|
# check the format of coords
|
||||||
|
if vv.tag==link+'Coords':
|
||||||
|
coords=bool(vv.attrib)
|
||||||
|
if coords:
|
||||||
|
p_h=vv.attrib['points'].split(' ')
|
||||||
|
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if vv.tag==link+'Point':
|
||||||
|
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||||
|
sumi+=1
|
||||||
|
#print(vv.tag,'in')
|
||||||
|
elif vv.tag!=link+'Point' and sumi>=1:
|
||||||
|
break
|
||||||
|
co_map.append(np.array(c_t_in))
|
||||||
|
|
||||||
|
|
||||||
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
|
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
|
||||||
#print('sth')
|
#print('sth')
|
||||||
|
|
@ -619,7 +654,7 @@ def get_layout_contours_for_visualization(xml_file):
|
||||||
elif vv.tag!=link+'Point' and sumi>=1:
|
elif vv.tag!=link+'Point' and sumi>=1:
|
||||||
break
|
break
|
||||||
co_noise.append(np.array(c_t_in))
|
co_noise.append(np.array(c_t_in))
|
||||||
return co_text, co_graphic, co_sep, co_img, co_table, co_noise, y_len, x_len
|
return co_text, co_graphic, co_sep, co_img, co_table, co_map, co_noise, y_len, x_len
|
||||||
|
|
||||||
def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images):
|
def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images):
|
||||||
"""
|
"""
|
||||||
|
|
@ -701,9 +736,12 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
||||||
|
|
||||||
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
|
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
|
||||||
|
|
||||||
cnt = contours[np.argmax(cnt_size)]
|
try:
|
||||||
|
cnt = contours[np.argmax(cnt_size)]
|
||||||
|
x, y, w, h = cv2.boundingRect(cnt)
|
||||||
|
except:
|
||||||
|
x, y , w, h = 0, 0, x_len, y_len
|
||||||
|
|
||||||
x, y, w, h = cv2.boundingRect(cnt)
|
|
||||||
bb_xywh = [x, y, w, h]
|
bb_xywh = [x, y, w, h]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -835,7 +873,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
||||||
types_graphic_label = list(types_graphic_dict.values())
|
types_graphic_label = list(types_graphic_dict.values())
|
||||||
|
|
||||||
|
|
||||||
labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0)]
|
labels_rgb_color = [ (0,0,0), (255,0,0), (255,125,0), (255,0,125), (125,255,125), (125,125,0), (0,125,255), (0,125,0), (125,125,125), (255,0,255), (125,0,125), (0,255,0),(0,0,255), (0,255,255), (255,125,125), (0,125,125), (0,255,125), (255,125,255), (125,255,0), (125,255,255)]
|
||||||
|
|
||||||
|
|
||||||
region_tags=np.unique([x for x in alltags if x.endswith('Region')])
|
region_tags=np.unique([x for x in alltags if x.endswith('Region')])
|
||||||
|
|
@ -846,6 +884,7 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
||||||
co_sep=[]
|
co_sep=[]
|
||||||
co_img=[]
|
co_img=[]
|
||||||
co_table=[]
|
co_table=[]
|
||||||
|
co_map=[]
|
||||||
co_noise=[]
|
co_noise=[]
|
||||||
|
|
||||||
for tag in region_tags:
|
for tag in region_tags:
|
||||||
|
|
@ -1057,6 +1096,32 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
||||||
break
|
break
|
||||||
co_table.append(np.array(c_t_in))
|
co_table.append(np.array(c_t_in))
|
||||||
|
|
||||||
|
if 'mapregion' in keys:
|
||||||
|
if tag.endswith('}MapRegion') or tag.endswith('}mapregion'):
|
||||||
|
#print('sth')
|
||||||
|
for nn in root1.iter(tag):
|
||||||
|
c_t_in=[]
|
||||||
|
sumi=0
|
||||||
|
for vv in nn.iter():
|
||||||
|
# check the format of coords
|
||||||
|
if vv.tag==link+'Coords':
|
||||||
|
coords=bool(vv.attrib)
|
||||||
|
if coords:
|
||||||
|
p_h=vv.attrib['points'].split(' ')
|
||||||
|
c_t_in.append( np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] ) )
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if vv.tag==link+'Point':
|
||||||
|
c_t_in.append([ int(float(vv.attrib['x'])) , int(float(vv.attrib['y'])) ])
|
||||||
|
sumi+=1
|
||||||
|
#print(vv.tag,'in')
|
||||||
|
elif vv.tag!=link+'Point' and sumi>=1:
|
||||||
|
break
|
||||||
|
co_map.append(np.array(c_t_in))
|
||||||
|
|
||||||
if 'noiseregion' in keys:
|
if 'noiseregion' in keys:
|
||||||
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
|
if tag.endswith('}NoiseRegion') or tag.endswith('}noiseregion'):
|
||||||
#print('sth')
|
#print('sth')
|
||||||
|
|
@ -1129,6 +1194,10 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
||||||
erosion_rate = 0#2
|
erosion_rate = 0#2
|
||||||
dilation_rate = 3#4
|
dilation_rate = 3#4
|
||||||
co_table, img_boundary = update_region_contours(co_table, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
|
co_table, img_boundary = update_region_contours(co_table, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
|
||||||
|
if "mapregion" in elements_with_artificial_class:
|
||||||
|
erosion_rate = 0#2
|
||||||
|
dilation_rate = 3#4
|
||||||
|
co_map, img_boundary = update_region_contours(co_map, img_boundary, erosion_rate, dilation_rate, y_len, x_len )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1154,6 +1223,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
||||||
img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']])
|
img_poly=cv2.fillPoly(img, pts =co_img, color=labels_rgb_color[ config_params['imageregion']])
|
||||||
if 'tableregion' in keys:
|
if 'tableregion' in keys:
|
||||||
img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']])
|
img_poly=cv2.fillPoly(img, pts =co_table, color=labels_rgb_color[ config_params['tableregion']])
|
||||||
|
if 'mapregion' in keys:
|
||||||
|
img_poly=cv2.fillPoly(img, pts =co_map, color=labels_rgb_color[ config_params['mapregion']])
|
||||||
if 'noiseregion' in keys:
|
if 'noiseregion' in keys:
|
||||||
img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']])
|
img_poly=cv2.fillPoly(img, pts =co_noise, color=labels_rgb_color[ config_params['noiseregion']])
|
||||||
|
|
||||||
|
|
@ -1214,6 +1285,9 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
|
||||||
if 'tableregion' in keys:
|
if 'tableregion' in keys:
|
||||||
color_label = config_params['tableregion']
|
color_label = config_params['tableregion']
|
||||||
img_poly=cv2.fillPoly(img, pts =co_table, color=(color_label,color_label,color_label))
|
img_poly=cv2.fillPoly(img, pts =co_table, color=(color_label,color_label,color_label))
|
||||||
|
if 'mapregion' in keys:
|
||||||
|
color_label = config_params['mapregion']
|
||||||
|
img_poly=cv2.fillPoly(img, pts =co_map, color=(color_label,color_label,color_label))
|
||||||
if 'noiseregion' in keys:
|
if 'noiseregion' in keys:
|
||||||
color_label = config_params['noiseregion']
|
color_label = config_params['noiseregion']
|
||||||
img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label))
|
img_poly=cv2.fillPoly(img, pts =co_noise, color=(color_label,color_label,color_label))
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,9 @@ from .models import (
|
||||||
Patches
|
Patches
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from.utils import (scale_padd_image_for_ocr)
|
||||||
|
from eynollah.utils.utils_ocr import (decode_batch_predictions)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
|
|
||||||
|
|
@ -35,8 +38,7 @@ Tool to load model and predict for given image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class sbb_predict:
|
class sbb_predict:
|
||||||
|
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
|
||||||
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
|
||||||
self.image=image
|
self.image=image
|
||||||
self.dir_in=dir_in
|
self.dir_in=dir_in
|
||||||
self.patches=patches
|
self.patches=patches
|
||||||
|
|
@ -48,6 +50,7 @@ class sbb_predict:
|
||||||
self.config_params_model=config_params_model
|
self.config_params_model=config_params_model
|
||||||
self.xml_file = xml_file
|
self.xml_file = xml_file
|
||||||
self.out = out
|
self.out = out
|
||||||
|
self.cpu = cpu
|
||||||
if min_area:
|
if min_area:
|
||||||
self.min_area = float(min_area)
|
self.min_area = float(min_area)
|
||||||
else:
|
else:
|
||||||
|
|
@ -159,21 +162,19 @@ class sbb_predict:
|
||||||
return mIoU
|
return mIoU
|
||||||
|
|
||||||
def start_new_session_and_model(self):
|
def start_new_session_and_model(self):
|
||||||
|
if self.task == "cnn-rnn-ocr":
|
||||||
|
if self.cpu:
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES']='-1'
|
||||||
|
self.model = load_model(self.model_dir)
|
||||||
|
self.model = tf.keras.models.Model(
|
||||||
|
self.model.get_layer(name = "image").input,
|
||||||
|
self.model.get_layer(name = "dense2").output)
|
||||||
|
else:
|
||||||
|
config = tf.compat.v1.ConfigProto()
|
||||||
|
config.gpu_options.allow_growth = True
|
||||||
|
|
||||||
config = tf.compat.v1.ConfigProto()
|
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||||
config.gpu_options.allow_growth = True
|
tensorflow_backend.set_session(session)
|
||||||
|
|
||||||
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
|
||||||
tensorflow_backend.set_session(session)
|
|
||||||
#tensorflow.keras.layers.custom_layer = PatchEncoder
|
|
||||||
#tensorflow.keras.layers.custom_layer = Patches
|
|
||||||
self.model = load_model(self.model_dir , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
|
||||||
#config = tf.ConfigProto()
|
|
||||||
#config.gpu_options.allow_growth=True
|
|
||||||
|
|
||||||
#self.session = tf.InteractiveSession()
|
|
||||||
#keras.losses.custom_loss = self.weighted_categorical_crossentropy
|
|
||||||
#self.model = load_model(self.model_dir , compile=False)
|
|
||||||
|
|
||||||
|
|
||||||
##if self.weights_dir!=None:
|
##if self.weights_dir!=None:
|
||||||
|
|
@ -250,6 +251,30 @@ class sbb_predict:
|
||||||
index_class = np.argmax(label_p_pred[0])
|
index_class = np.argmax(label_p_pred[0])
|
||||||
|
|
||||||
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
|
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
|
||||||
|
elif self.task == "cnn-rnn-ocr":
|
||||||
|
img=cv2.imread(image_dir)
|
||||||
|
img = scale_padd_image_for_ocr(img, self.config_params_model['input_height'], self.config_params_model['input_width'])
|
||||||
|
|
||||||
|
img = img / 255.
|
||||||
|
|
||||||
|
with open(os.path.join(self.model_dir, "characters_org.txt"), 'r') as char_txt_f:
|
||||||
|
characters = json.load(char_txt_f)
|
||||||
|
|
||||||
|
AUTOTUNE = tf.data.AUTOTUNE
|
||||||
|
|
||||||
|
# Mapping characters to integers.
|
||||||
|
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
|
||||||
|
|
||||||
|
# Mapping integers back to original characters.
|
||||||
|
num_to_char = StringLookup(
|
||||||
|
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
|
||||||
|
)
|
||||||
|
preds = self.model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]), verbose=0)
|
||||||
|
pred_texts = decode_batch_predictions(preds, num_to_char)
|
||||||
|
pred_texts = pred_texts[0].replace("[UNK]", "")
|
||||||
|
return pred_texts
|
||||||
|
|
||||||
|
|
||||||
elif self.task == 'reading_order':
|
elif self.task == 'reading_order':
|
||||||
img_height = self.config_params_model['input_height']
|
img_height = self.config_params_model['input_height']
|
||||||
img_width = self.config_params_model['input_width']
|
img_width = self.config_params_model['input_width']
|
||||||
|
|
@ -580,6 +605,8 @@ class sbb_predict:
|
||||||
elif self.task == 'enhancement':
|
elif self.task == 'enhancement':
|
||||||
if self.save:
|
if self.save:
|
||||||
cv2.imwrite(self.save,res)
|
cv2.imwrite(self.save,res)
|
||||||
|
elif self.task == "cnn-rnn-ocr":
|
||||||
|
print(f"Detected text: {res}")
|
||||||
else:
|
else:
|
||||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
if self.save:
|
if self.save:
|
||||||
|
|
@ -587,9 +614,9 @@ class sbb_predict:
|
||||||
if self.save_layout:
|
if self.save_layout:
|
||||||
cv2.imwrite(self.save_layout, only_layout)
|
cv2.imwrite(self.save_layout, only_layout)
|
||||||
|
|
||||||
if self.ground_truth:
|
if self.ground_truth:
|
||||||
gt_img=cv2.imread(self.ground_truth)
|
gt_img=cv2.imread(self.ground_truth)
|
||||||
self.IoU(gt_img[:,:,0],res[:,:,0])
|
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
ls_images = os.listdir(self.dir_in)
|
ls_images = os.listdir(self.dir_in)
|
||||||
|
|
@ -603,6 +630,8 @@ class sbb_predict:
|
||||||
elif self.task == 'enhancement':
|
elif self.task == 'enhancement':
|
||||||
self.save = os.path.join(self.out, f_name+'.png')
|
self.save = os.path.join(self.out, f_name+'.png')
|
||||||
cv2.imwrite(self.save,res)
|
cv2.imwrite(self.save,res)
|
||||||
|
elif self.task == "cnn-rnn-ocr":
|
||||||
|
print(f"Detected text for file name {f_name} is: {res}")
|
||||||
else:
|
else:
|
||||||
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
self.save = os.path.join(self.out, f_name+'_overlayed.png')
|
self.save = os.path.join(self.out, f_name+'_overlayed.png')
|
||||||
|
|
@ -610,9 +639,9 @@ class sbb_predict:
|
||||||
self.save_layout = os.path.join(self.out, f_name+'_layout.png')
|
self.save_layout = os.path.join(self.out, f_name+'_layout.png')
|
||||||
cv2.imwrite(self.save_layout, only_layout)
|
cv2.imwrite(self.save_layout, only_layout)
|
||||||
|
|
||||||
if self.ground_truth:
|
if self.ground_truth:
|
||||||
gt_img=cv2.imread(self.ground_truth)
|
gt_img=cv2.imread(self.ground_truth)
|
||||||
self.IoU(gt_img[:,:,0],res[:,:,0])
|
self.IoU(gt_img[:,:,0],res[:,:,0])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -668,24 +697,29 @@ class sbb_predict:
|
||||||
"-xml",
|
"-xml",
|
||||||
help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
|
help="xml file with layout coordinates that reading order detection will be implemented on. The result will be written in the same xml file.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--cpu",
|
||||||
|
"-cpu",
|
||||||
|
help="For OCR, the default device is the GPU. If this parameter is set to true, inference will be performed on the CPU",
|
||||||
|
is_flag=True,
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--min_area",
|
"--min_area",
|
||||||
"-min",
|
"-min",
|
||||||
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
|
help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.",
|
||||||
)
|
)
|
||||||
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
def main(image, dir_in, model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area):
|
||||||
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
assert image or dir_in, "Either a single image -i or a dir_in -di is required"
|
||||||
with open(os.path.join(model,'config.json')) as f:
|
with open(os.path.join(model,'config.json')) as f:
|
||||||
config_params_model = json.load(f)
|
config_params_model = json.load(f)
|
||||||
task = config_params_model['task']
|
task = config_params_model['task']
|
||||||
if task != 'classification' and task != 'reading_order':
|
if task != 'classification' and task != 'reading_order' and task != "cnn-rnn-ocr":
|
||||||
if image and not save:
|
if image and not save:
|
||||||
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
|
print("Error: You used one of segmentation or binarization task with image input but not set -s, you need a filename to save visualized output with -s")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
if dir_in and not out:
|
if dir_in and not out:
|
||||||
print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
|
print("Error: You used one of segmentation or binarization task with dir_in but not set -out")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area)
|
x=sbb_predict(image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, cpu, out, min_area)
|
||||||
x.run()
|
x.run()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
|
from pathlib import Path
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
@ -32,6 +32,9 @@ def scale_padd_image_for_ocr(img, height, width):
|
||||||
else:
|
else:
|
||||||
width_new = width
|
width_new = width
|
||||||
|
|
||||||
|
if width_new <= 0:
|
||||||
|
width_new = width
|
||||||
|
|
||||||
img_res= resize_image (img, height, width_new)
|
img_res= resize_image (img, height, width_new)
|
||||||
img_fin = np.ones((height, width, 3))*255
|
img_fin = np.ones((height, width, 3))*255
|
||||||
|
|
||||||
|
|
@ -1335,7 +1338,8 @@ def data_gen_ocr(
|
||||||
# TODO: Why while True + yield, why not return a list?
|
# TODO: Why while True + yield, why not return a list?
|
||||||
while True:
|
while True:
|
||||||
for i in ls_files_images:
|
for i in ls_files_images:
|
||||||
f_name = i.split('.')[0]
|
print(i, 'i')
|
||||||
|
f_name = Path(i).stem#.split('.')[0]
|
||||||
|
|
||||||
txt_inp = open(os.path.join(dir_train, "labels/"+f_name+'.txt'),'r').read().split('\n')[0]
|
txt_inp = open(os.path.join(dir_train, "labels/"+f_name+'.txt'),'r').read().split('\n')[0]
|
||||||
|
|
||||||
|
|
|
||||||
136
src/eynollah/training/weights_ensembling.py
Normal file
136
src/eynollah/training/weights_ensembling.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
import sys
|
||||||
|
from glob import glob
|
||||||
|
from os import environ, devnull
|
||||||
|
from os.path import join
|
||||||
|
from warnings import catch_warnings, simplefilter
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
stderr = sys.stderr
|
||||||
|
sys.stderr = open(devnull, 'w')
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
from tensorflow.python.keras import backend as tensorflow_backend
|
||||||
|
sys.stderr = stderr
|
||||||
|
from tensorflow.keras import layers
|
||||||
|
import tensorflow.keras.losses
|
||||||
|
from tensorflow.keras.layers import *
|
||||||
|
import click
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class Patches(layers.Layer):
|
||||||
|
def __init__(self, patch_size_x, patch_size_y):
|
||||||
|
super(Patches, self).__init__()
|
||||||
|
self.patch_size_x = patch_size_x
|
||||||
|
self.patch_size_y = patch_size_y
|
||||||
|
|
||||||
|
def call(self, images):
|
||||||
|
#print(tf.shape(images)[1],'images')
|
||||||
|
#print(self.patch_size,'self.patch_size')
|
||||||
|
batch_size = tf.shape(images)[0]
|
||||||
|
patches = tf.image.extract_patches(
|
||||||
|
images=images,
|
||||||
|
sizes=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||||
|
strides=[1, self.patch_size_y, self.patch_size_x, 1],
|
||||||
|
rates=[1, 1, 1, 1],
|
||||||
|
padding="VALID",
|
||||||
|
)
|
||||||
|
#patch_dims = patches.shape[-1]
|
||||||
|
patch_dims = tf.shape(patches)[-1]
|
||||||
|
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
|
||||||
|
return patches
|
||||||
|
def get_config(self):
|
||||||
|
|
||||||
|
config = super().get_config().copy()
|
||||||
|
config.update({
|
||||||
|
'patch_size_x': self.patch_size_x,
|
||||||
|
'patch_size_y': self.patch_size_y,
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEncoder(layers.Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(PatchEncoder, self).__init__()
|
||||||
|
self.num_patches = num_patches
|
||||||
|
self.projection = layers.Dense(units=projection_dim)
|
||||||
|
self.position_embedding = layers.Embedding(
|
||||||
|
input_dim=num_patches, output_dim=projection_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
def call(self, patch):
|
||||||
|
positions = tf.range(start=0, limit=self.num_patches, delta=1)
|
||||||
|
encoded = self.projection(patch) + self.position_embedding(positions)
|
||||||
|
return encoded
|
||||||
|
def get_config(self):
|
||||||
|
|
||||||
|
config = super().get_config().copy()
|
||||||
|
config.update({
|
||||||
|
'num_patches': self.num_patches,
|
||||||
|
'projection': self.projection,
|
||||||
|
'position_embedding': self.position_embedding,
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def start_new_session():
|
||||||
|
###config = tf.compat.v1.ConfigProto()
|
||||||
|
###config.gpu_options.allow_growth = True
|
||||||
|
|
||||||
|
###self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||||
|
###tensorflow_backend.set_session(self.session)
|
||||||
|
|
||||||
|
config = tf.compat.v1.ConfigProto()
|
||||||
|
config.gpu_options.allow_growth = True
|
||||||
|
|
||||||
|
session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
|
||||||
|
tensorflow_backend.set_session(session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
def run_ensembling(dir_models, out):
|
||||||
|
ls_models = os.listdir(dir_models)
|
||||||
|
|
||||||
|
|
||||||
|
weights=[]
|
||||||
|
|
||||||
|
for model_name in ls_models:
|
||||||
|
model = load_model(os.path.join(dir_models,model_name) , compile=False, custom_objects={'PatchEncoder':PatchEncoder, 'Patches': Patches})
|
||||||
|
weights.append(model.get_weights())
|
||||||
|
|
||||||
|
new_weights = list()
|
||||||
|
|
||||||
|
for weights_list_tuple in zip(*weights):
|
||||||
|
new_weights.append(
|
||||||
|
[np.array(weights_).mean(axis=0)\
|
||||||
|
for weights_ in zip(*weights_list_tuple)])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
new_weights = [np.array(x) for x in new_weights]
|
||||||
|
|
||||||
|
model.set_weights(new_weights)
|
||||||
|
model.save(out)
|
||||||
|
os.system('cp '+os.path.join(os.path.join(dir_models,model_name) , "config.json ")+out)
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option(
|
||||||
|
"--dir_models",
|
||||||
|
"-dm",
|
||||||
|
help="directory of models",
|
||||||
|
type=click.Path(exists=True, file_okay=False),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--out",
|
||||||
|
"-o",
|
||||||
|
help="output directory where ensembled model will be written.",
|
||||||
|
type=click.Path(exists=False, file_okay=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def main(dir_models, out):
|
||||||
|
run_ensembling(dir_models, out)
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue