scaling and cropping of labels and org images

This commit is contained in:
vahidrezanezhad 2024-05-30 16:59:50 +02:00
parent 4640d9f2dc
commit 821290c464
3 changed files with 145 additions and 19 deletions

View file

@ -115,11 +115,15 @@ def update_region_contours(co_text, img_boundary, erosion_rate, dilation_rate, y
img_boundary[:,:][boundary[:,:]==1] =1
return co_text_eroded, img_boundary
def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params):
def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_file, config_params, printspace, dir_images, dir_out_images):
"""
Reading the page xml files and write the ground truth images into given output directory.
"""
## to do: add footnote to text regions
if dir_images:
ls_org_imgs = os.listdir(dir_images)
ls_org_imgs_stem = [item.split('.')[0] for item in ls_org_imgs]
for index in tqdm(range(len(gt_list))):
#try:
tree1 = ET.parse(dir_in+'/'+gt_list[index], parser = ET.XMLParser(encoding = 'iso-8859-5'))
@ -133,6 +137,72 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
y_len=int(jj.attrib['imageHeight'])
x_len=int(jj.attrib['imageWidth'])
if 'columns_width' in list(config_params.keys()):
columns_width_dict = config_params['columns_width']
metadata_element = root1.find(link+'Metadata')
comment_is_sub_element = False
for child in metadata_element:
tag2 = child.tag
if tag2.endswith('}Comments') or tag2.endswith('}comments'):
text_comments = child.text
num_col = int(text_comments.split('num_col')[1])
comment_is_sub_element = True
if not comment_is_sub_element:
num_col = None
if num_col:
x_new = columns_width_dict[str(num_col)]
y_new = int ( x_new * (y_len / float(x_len)) )
if printspace:
region_tags = np.unique([x for x in alltags if x.endswith('PrintSpace') or x.endswith('Border')])
co_use_case = []
for tag in region_tags:
tag_endings = ['}PrintSpace','}Border']
if tag.endswith(tag_endings[0]) or tag.endswith(tag_endings[1]):
for nn in root1.iter(tag):
c_t_in = []
sumi = 0
for vv in nn.iter():
# check the format of coords
if vv.tag == link + 'Coords':
coords = bool(vv.attrib)
if coords:
p_h = vv.attrib['points'].split(' ')
c_t_in.append(
np.array([[int(x.split(',')[0]), int(x.split(',')[1])] for x in p_h]))
break
else:
pass
if vv.tag == link + 'Point':
c_t_in.append([int(float(vv.attrib['x'])), int(float(vv.attrib['y']))])
sumi += 1
elif vv.tag != link + 'Point' and sumi >= 1:
break
co_use_case.append(np.array(c_t_in))
img = np.zeros((y_len, x_len, 3))
img_poly = cv2.fillPoly(img, pts=co_use_case, color=(1, 1, 1))
img_poly = img_poly.astype(np.uint8)
imgray = cv2.cvtColor(img_poly, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
cnt = contours[np.argmax(cnt_size)]
x, y, w, h = cv2.boundingRect(cnt)
bb_xywh = [x, y, w, h]
if config_file and (config_params['use_case']=='textline' or config_params['use_case']=='word' or config_params['use_case']=='glyph' or config_params['use_case']=='printspace'):
keys = list(config_params.keys())
if "artificial_class_label" in keys:
@ -186,7 +256,6 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
co_use_case.append(np.array(c_t_in))
if "artificial_class_label" in keys:
img_boundary = np.zeros((y_len, x_len))
erosion_rate = 1
@ -205,12 +274,32 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
img_poly[:,:,0][img_boundary[:,:]==1] = artificial_class_rgb_color[0]
img_poly[:,:,1][img_boundary[:,:]==1] = artificial_class_rgb_color[1]
img_poly[:,:,2][img_boundary[:,:]==1] = artificial_class_rgb_color[2]
if printspace and config_params['use_case']!='printspace':
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_poly = resize_image(img_poly, y_new, x_new)
try:
cv2.imwrite(output_dir + '/' + gt_list[index].split('-')[1].split('.')[0] + '.png',
img_poly)
xml_file_stem = gt_list[index].split('-')[1].split('.')[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
except:
cv2.imwrite(output_dir + '/' + gt_list[index].split('.')[0] + '.png', img_poly)
xml_file_stem = gt_list[index].split('.')[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
if dir_images:
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace and config_params['use_case']!='printspace':
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
if 'columns_width' in list(config_params.keys()) and num_col and config_params['use_case']!='printspace':
img_org = resize_image(img_org, y_new, x_new)
cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)
if config_file and config_params['use_case']=='layout':
@ -616,11 +705,31 @@ def get_images_of_ground_truth(gt_list, dir_in, output_dir, output_type, config_
if printspace:
img_poly = img_poly[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
try:
cv2.imwrite(output_dir+'/'+gt_list[index].split('-')[1].split('.')[0]+'.png',img_poly )
if 'columns_width' in list(config_params.keys()) and num_col:
img_poly = resize_image(img_poly, y_new, x_new)
try:
xml_file_stem = gt_list[index].split('-')[1].split('.')[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
except:
cv2.imwrite(output_dir+'/'+gt_list[index].split('.')[0]+'.png',img_poly )
xml_file_stem = gt_list[index].split('.')[0]
cv2.imwrite(os.path.join(output_dir, xml_file_stem + '.png'), img_poly)
if dir_images:
org_image_name = ls_org_imgs[ls_org_imgs_stem.index(xml_file_stem)]
img_org = cv2.imread(os.path.join(dir_images, org_image_name))
if printspace:
img_org = img_org[bb_xywh[1]:bb_xywh[1]+bb_xywh[3], bb_xywh[0]:bb_xywh[0]+bb_xywh[2], :]
if 'columns_width' in list(config_params.keys()) and num_col:
img_org = resize_image(img_org, y_new, x_new)
cv2.imwrite(os.path.join(dir_out_images, org_image_name), img_org)