training.gt_gen_utils: fix+simplify cropping…

When parsing `PrintSpace` or `Border` from PAGE-XML:
- use `lxml` XPath instead of nested loops
- convert the `Coords` points to a bounding box directly via `bbox_from_points`
  (instead of painting polygons on a canvas and retrieving contours)
- pass the resulting bbox in slice notation, i.e. x0, y0, x1, y1
  (instead of xywh)
This commit is contained in:
Robert Sachunsky 2026-01-28 13:42:59 +01:00
parent acda9c84ee
commit 0372fd7a1e
2 changed files with 51 additions and 118 deletions

View file

@ -1,15 +1,18 @@
import os
import numpy as np
import warnings
import xml.etree.ElementTree as ET
from lxml import etree as ET
from tqdm import tqdm
import cv2
from shapely import geometry
from pathlib import Path
from PIL import ImageFont
from ocrd_utils import bbox_from_points
KERNEL = np.ones((5, 5), np.uint8)
NS = { 'pc': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'
}
with warnings.catch_warnings():
warnings.simplefilter("ignore")
@ -664,52 +667,13 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
y_new = int ( x_new * (y_len / float(x_len)) )
if printspace or "printspace_as_class_in_layout" in list(config_params.keys()):
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')])
co_use_case = []
for tag in region_tags:
tag_endings = ['}PrintSpace','}Border']
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) +
root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS))
if len(ps):
points = ps[0].find('pc:Coords', NS).get('points')
ps_bbox = bbox_from_points(points)
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
img = np.zeros((y_len, x_len, 3))
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
img_poly = img_poly.astype(np.uint8)
imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
cnt = contours[np.argmax(cnt_size)]
x, y, w, h = cv2.boundingRect(cnt)
bb_xywh = [x, y, w, h]
ps_bbox = [0, 0, None, None]
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
@ -791,7 +755,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if printspace and config_params['use_case']!='printspace':
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
@ -815,7 +780,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace and config_params['use_case']!='printspace':
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
img_org = img_org[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_org = resize_image(img_org, y_new, x_new)
@ -1194,7 +1160,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if "printspace_as_class_in_layout" in list(config_params.keys()):
printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1]))
printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1
printspace_mask[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2]] = 1
img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_rgb_color[0]
img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_rgb_color[1]
@ -1252,7 +1219,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if "printspace_as_class_in_layout" in list(config_params.keys()):
printspace_mask = np.zeros((img_poly.shape[0], img_poly.shape[1]))
printspace_mask[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2]] = 1
printspace_mask[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2]] = 1
img_poly[:,:,0][printspace_mask[:,:] == 0] = printspace_class_label
img_poly[:,:,1][printspace_mask[:,:] == 0] = printspace_class_label
@ -1261,7 +1229,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if printspace:
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
if 'columns_width' in list(config_params.keys()) and num_col:
img_poly = resize_image(img_poly, y_new, x_new)
@ -1285,7 +1254,8 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace:
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
img_org = img_org[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
if 'columns_width' in list(config_params.keys()) and num_col:
img_org = resize_image(img_org, y_new, x_new)
@ -1326,6 +1296,7 @@ def find_new_features_of_contours(contours_main):
y_max_main = np.array([np.max(contours_main[j][:, 1]) for j in range(len(contours_main))])
return cx_main, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, y_corr_x_min_from_argmin
def read_xml(xml_file):
file_name = Path(xml_file).stem
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
@ -1344,57 +1315,13 @@ def read_xml(xml_file):
index_tot_regions.append(jj.attrib['index'])
tot_region_ref.append(jj.attrib['regionRef'])
if (link+'PrintSpace' in alltags) or (link+'Border' in alltags):
co_printspace = []
if link+'PrintSpace' in alltags:
region_tags_printspace = np.unique([x for x in alltags if x.endswith('PrintSpace')])
elif link+'Border' in alltags:
region_tags_printspace = np.unique([x for x in alltags if x.endswith('Border')])
for tag in region_tags_printspace:
if link+'PrintSpace' in alltags:
tag_endings_printspace = ['}PrintSpace','}printspace']
elif link+'Border' in alltags:
tag_endings_printspace = ['}Border','}border']
if tag.endswith(tag_endings_printspace[0]) or tag.endswith(tag_endings_printspace[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
ps = (root1.xpath('/pc:PcGts/pc:Page/pc:Border', namespaces=NS) +
root1.xpath('/pc:PcGts/pc:Page/pc:PrintSpace', namespaces=NS))
if len(ps):
points = ps[0].find('pc:Coords', NS).get('points')
ps_bbox = bbox_from_points(points)
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_printspace.append(np.array(c_t_in))
img_printspace = np.zeros( (y_len,x_len,3) )
img_printspace=cv2.fillPoly(img_printspace, pts =co_printspace, color=(1,1,1))
img_printspace = img_printspace.astype(np.uint8)
imgray = cv2.cvtColor(img_printspace, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
cnt = contours[np.argmax(cnt_size)]
x, y, w, h = cv2.boundingRect(cnt)
bb_coord_printspace = [x, y, w, h]
else:
bb_coord_printspace = None
ps_bbox = [0, 0, None, None]
region_tags=np.unique([x for x in alltags if x.endswith('Region')])
co_text_paragraph=[]
@ -1749,11 +1676,19 @@ def read_xml(xml_file):
img_poly=cv2.fillPoly(img, pts =co_img, color=(4,4,4))
img_poly=cv2.fillPoly(img, pts =co_sep, color=(5,5,5))
return tree1, root1, bb_coord_printspace, file_name, id_paragraph, id_header+id_heading, co_text_paragraph, co_text_header+co_text_heading,\
tot_region_ref,x_len, y_len,index_tot_regions, img_poly
return (tree1,
root1,
ps_bbox,
file_name,
id_paragraph,
id_header + id_heading,
co_text_paragraph,
co_text_header + co_text_heading,
tot_region_ref,
x_len,
y_len,
index_tot_regions,
img_poly)
def bounding_box(cnt,color, corr_order_index ):
x, y, w, h = cv2.boundingRect(cnt)

View file

@ -196,7 +196,7 @@ class SBBPredict:
img_height = self.config_params_model['input_height']
img_width = self.config_params_model['input_width']
tree_xml, root_xml, bb_coord_printspace, file_name, \
tree_xml, root_xml, ps_bbox, file_name, \
id_paragraph, id_header, \
co_text_paragraph, co_text_header, \
tot_region_ref, x_len, y_len, index_tot_regions, \
@ -236,15 +236,13 @@ class SBBPredict:
img_label=cv2.fillPoly(img_label, pts =[co_text_all[i]], color=(1,1,1))
labels_con[:,:,i] = img_label[:,:,0]
if bb_coord_printspace:
#bb_coord_printspace[x,y,w,h,_,_]
x = bb_coord_printspace[0]
y = bb_coord_printspace[1]
w = bb_coord_printspace[2]
h = bb_coord_printspace[3]
labels_con = labels_con[y:y+h, x:x+w, :]
img_poly = img_poly[y:y+h, x:x+w, :]
img_header_and_sep = img_header_and_sep[y:y+h, x:x+w]
if ps_bbox:
labels_con = labels_con[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
img_poly = img_poly[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2], :]
img_header_and_sep = img_header_and_sep[ps_bbox[1]:ps_bbox[3],
ps_bbox[0]:ps_bbox[2]]