mirror of
https://github.com/qurator-spk/sbb_pixelwise_segmentation.git
synced 2025-06-14 14:20:06 +02:00
inference updated
This commit is contained in:
commit
5fbe941f53
2 changed files with 12 additions and 8 deletions
|
@ -219,7 +219,7 @@ class sbb_predict:
|
||||||
|
|
||||||
added_image = cv2.addWeighted(img,0.5,output,0.1,0)
|
added_image = cv2.addWeighted(img,0.5,output,0.1,0)
|
||||||
|
|
||||||
return added_image
|
return added_image, output
|
||||||
|
|
||||||
def predict(self):
|
def predict(self):
|
||||||
self.start_new_session_and_model()
|
self.start_new_session_and_model()
|
||||||
|
@ -444,7 +444,7 @@ class sbb_predict:
|
||||||
|
|
||||||
if img.shape[1] < self.img_width:
|
if img.shape[1] < self.img_width:
|
||||||
img = cv2.resize(img, (self.img_height, img.shape[0]), interpolation=cv2.INTER_NEAREST)
|
img = cv2.resize(img, (self.img_height, img.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||||
margin = int(0 * self.img_width)
|
margin = int(0.1 * self.img_width)
|
||||||
width_mid = self.img_width - 2 * margin
|
width_mid = self.img_width - 2 * margin
|
||||||
height_mid = self.img_height - 2 * margin
|
height_mid = self.img_height - 2 * margin
|
||||||
img = img / float(255.0)
|
img = img / float(255.0)
|
||||||
|
@ -562,9 +562,10 @@ class sbb_predict:
|
||||||
print(self.save)
|
print(self.save)
|
||||||
cv2.imwrite(self.save,res)
|
cv2.imwrite(self.save,res)
|
||||||
else:
|
else:
|
||||||
img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
|
img_seg_overlayed, only_prediction = self.visualize_model_output(res, self.img_org, self.task)
|
||||||
if self.save:
|
if self.save:
|
||||||
cv2.imwrite(self.save,img_seg_overlayed)
|
cv2.imwrite(self.save,img_seg_overlayed)
|
||||||
|
cv2.imwrite('./layout.png', only_prediction)
|
||||||
|
|
||||||
if self.ground_truth:
|
if self.ground_truth:
|
||||||
gt_img=cv2.imread(self.ground_truth)
|
gt_img=cv2.imread(self.ground_truth)
|
||||||
|
|
13
utils.py
13
utils.py
|
@ -599,12 +599,15 @@ def provide_patches(imgs_list_train, segs_list_train, dir_img, dir_seg, dir_flow
|
||||||
indexer += 1
|
indexer += 1
|
||||||
if brightening:
|
if brightening:
|
||||||
for factor in brightness:
|
for factor in brightness:
|
||||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
try:
|
||||||
(resize_image(do_brightening(dir_img + '/' +im, factor), input_height, input_width)))
|
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||||
|
(resize_image(do_brightening(dir_img + '/' +im, factor), input_height, input_width)))
|
||||||
|
|
||||||
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
|
cv2.imwrite(dir_flow_train_labels + '/img_' + str(indexer) + '.png',
|
||||||
resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
|
resize_image(cv2.imread(dir_of_label_file), input_height, input_width))
|
||||||
indexer += 1
|
indexer += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
if binarization:
|
if binarization:
|
||||||
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
cv2.imwrite(dir_flow_train_imgs + '/img_' + str(indexer) + '.png',
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue