diff --git a/inference.py b/inference.py index 49bebf8..6054b01 100644 --- a/inference.py +++ b/inference.py @@ -32,6 +32,7 @@ class sbb_predict: self.image=image self.patches=patches self.save=save + self.save_layout=save_layout self.model_dir=model self.ground_truth=ground_truth self.task=task @@ -181,6 +182,7 @@ class sbb_predict: prediction = prediction * -1 prediction = prediction + 1 added_image = prediction * 255 + layout_only = None else: unique_classes = np.unique(prediction[:,:,0]) rgb_colors = {'0' : [255, 255, 255], @@ -200,26 +202,26 @@ class sbb_predict: '14' : [255, 125, 125], '15' : [255, 0, 255]} - output = np.zeros(prediction.shape) + layout_only = np.zeros(prediction.shape) for unq_class in unique_classes: rgb_class_unique = rgb_colors[str(int(unq_class))] - output[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] - output[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] - output[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] + layout_only[:,:,0][prediction[:,:,0]==unq_class] = rgb_class_unique[0] + layout_only[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1] + layout_only[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2] - img = self.resize_image(img, output.shape[0], output.shape[1]) + img = self.resize_image(img, layout_only.shape[0], layout_only.shape[1]) - output = output.astype(np.int32) + layout_only = layout_only.astype(np.int32) img = img.astype(np.int32) - added_image = cv2.addWeighted(img,0.5,output,0.1,0) + added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0) - return added_image, output + return added_image, layout_only def predict(self): self.start_new_session_and_model() @@ -559,13 +561,12 @@ class sbb_predict: pass elif self.task == 'enhancement': if self.save: - print(self.save) cv2.imwrite(self.save,res) else: - img_seg_overlayed, only_prediction = self.visualize_model_output(res, self.img_org, self.task) + img_seg_overlayed, only_layout = self.visualize_model_output(res, self.img_org, self.task) if self.save: cv2.imwrite(self.save,img_seg_overlayed) - cv2.imwrite('./layout.png', only_prediction) + cv2.imwrite(self.save_layout, only_layout) if self.ground_truth: gt_img=cv2.imread(self.ground_truth) @@ -595,6 +596,11 @@ class sbb_predict: "-s", help="save prediction as a png file in current folder.", ) +@click.option( + "--save_layout", + "-sl", + help="save layout prediction only as a png file in current folder.", +) @click.option( "--model", "-m", @@ -618,7 +624,7 @@ class sbb_predict: "-min", help="min area size of regions considered for reading order detection. The default value is zero and means that all text regions are considered for reading order.", ) -def main(image, model, patches, save, ground_truth, xml_file, out, min_area): +def main(image, model, patches, save, save_layout, ground_truth, xml_file, out, min_area): with open(os.path.join(model,'config.json')) as f: config_params_model = json.load(f) task = config_params_model['task'] @@ -626,7 +632,7 @@ def main(image, model, patches, save, ground_truth, xml_file, out, min_area): if not save: print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s") sys.exit(1) - x=sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file, out, min_area) + x=sbb_predict(image, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area) x.run() if __name__=="__main__":